从一道算法题理解最近公共祖先 LCA

从暴力到倍增 LCA:一题理解树上路径优化

题目描述

LeetCode 3558

给你一棵 n 个节点的无向树,节点编号为 1 ~ n,1 为根。

1
edges[i] = [u, v] 表示 u 和 v 之间有一条边。

每条边可以赋值为 1 或 2。对于每次查询 (u, v):

  1. 找到 u → v 的唯一路径(树上任意两点之间有且仅有一条路径)
  2. 计算路径上的边数 dist
  3. 返回 2^(dist - 1) mod (10^9 + 7)(为什么是这个公式?可以参考 给边赋权值的方案数 I 题解

举个例子:

1
2
3
4
5
    1
/ \
2 3
/ \
4 5

查询 (5, 3):

1
2
3
路径:5 → 2 → 1 → 3
边数:dist = 3
答案:2^(3-1) = 4

每次查询都需要快速求出树上两点之间的距离,这就是 LCA(最近公共祖先)的典型应用。

怎么求距离?关键在于找到 u 和 v 的最近公共祖先——知道了 LCA,距离就是 depth[u] + depth[v] - 2 * depth[lca]

那先不管优化,最朴素的做法是什么?

暴力思路

从 u 和 v 各自往上爬,谁深谁先走,直到碰面。

1
2
3
4
5
6
7
8
9
function distance(u, v):
while u != v:
if depth[u] > depth[v]:
u = parent[u]
else:
v = parent[v]
step++

return step

问题很明显:单次 O(n),多次查询 O(nq),数据一大直接 TLE。


优化思路:从逐格走到倍增跳跃

观察暴力过程会发现:每次查询,本质都在重复”向上走同一条路径”。

1
5 → 4 → 3 → 2 → 1

每次都从头走一遍,非常浪费。那能不能一次多跳几步?

如果能一次跳 2 步、4 步、8 步 … 那走 13 步就只需要跳 3 次:

1
13 = 8 + 4 + 1   →   跳 3 次就到了

这就是二进制拆分——任何步数都能拆成若干个 2 的幂之和。顺着这个想法,只需要预处理出每个节点向上跳 2^k 步到达哪里,查询时就能 O(log n) 拼出任意路径。

这就是倍增(Binary Lifting) 的由来。


倍增 LCA 的核心思想

定义一个数组:

1
up[u][k] = u 向上跳 2^k 步到达的节点

比如 up[5][0] 是 5 的父节点(跳 1 步),up[5][1] 是向上跳 2 步到达的节点,up[5][2] 是跳 4 步……

有了这张表,任意距离的跳跃都可以用二进制拆分完成。

递推关系

1
up[u][k] = up[ up[u][k-1] ][k-1]

意思很简单:想跳 2^k 步,先跳 2^(k-1) 步到中间节点,再从那里跳 2^(k-1) 步。

1
2
u ──2^(k-1)步──> 中间节点 ──2^(k-1)步──> 目标节点
合计:2^(k-1) + 2^(k-1) = 2^k 步

大步拆两小步,小步已经预处理好了。


示例推导

用一棵树来演示 up 表是怎么算的:

graph TD 1((1)) 2((2)) 3((3)) 4((4)) 5((5)) 6((6)) 1 --> 2 1 --> 3 2 --> 4 2 --> 5 3 --> 6

其中 5 → 4 → 3 → 2 → 1 是一条链,看 up[5][k] 怎么递推。

up[5][0]:跳 1 步

1
up[5][0] = 4     (5 的父节点)

up[5][1]:跳 2 步

1
2
3
up[5][1] = up[ up[5][0] ][0]
= up[4][0]
= 3

从 5 跳 1 步到 4,再从 4 跳 1 步到 3,合计 2 步。

1
2
5 → 4 → 3
(1步) (1步) = 2步

up[5][2]:跳 4 步

1
2
up[5][2] = up[ up[5][1] ][1]
= up[3][1]

up[3][1] 是 3 向上跳 2 步:

1
2
3
up[3][1] = up[ up[3][0] ][0]
= up[2][0]
= 1

所以 up[5][2] = 1,从 5 跳 2 步到 3,再从 3 跳 2 步到 1,合计 4 步。

1
2
5 → 4 → 3 → 2 → 1
(2步) (2步) = 4步

汇总:

k 2^k up[5][k] 拆分方式
0 1 4 直接跳 1 步
1 2 3 1 步 + 1 步
2 4 1 2 步 + 2 步

LCA 查询

分两步:

Step 1:拉齐深度

1
2
3
4
function lift(u, diff):
for k from LOG downto 0:
if diff >= 2^k:
u = up[u][k]

Step 2:一起跳

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
function lca(u, v):
if depth[u] < depth[v]:
swap(u, v)

lift u to same depth as v

if u == v:
return u

for k from LOG downto 0:
if up[u][k] != up[v][k]:
u = up[u][k]
v = up[v][k]

return up[u][0]

为什么这样就能找到 LCA?

拉齐深度后,u 和 v 到 LCA 的距离为什么相等?

因为 LCA 是公共祖先,到 LCA 的距离 = 自身深度 - LCA 深度。深度相同,距离就相同。

“一起跳”为什么能精确停在 LCA 的下一层?

关键在 up[u][k] != up[v][k] 这个判断:

  • !=:2^k 级祖先还没汇合,还在 LCA 下面,可以安全跳
  • ==:2^k 级祖先已经相同了,跳过去就越过 LCA 了,不跳

从大到小遍历 k,就是在用二进制逐步逼近 LCA:

1
2
3
4
5
距离 LCA 还有 5 步:5 = 4 + 1

k=2: 2^2=4, 跳了不会越过 → 跳 (剩 1 步)
k=1: 2^1=2, 跳了会越过 → 不跳
k=0: 2^0=1, 跳了不会越过 → 跳 (剩 0 步)

循环结束后,u 和 v 刚好停在 LCA 的正下方,所以 up[u][0] 就是 LCA。


距离计算

1
dist = depth[u] + depth[v] - 2 * depth[lca]

回到题目

每条边赋值 1 或 2,路径上 dist 条边共有 2^dist 种赋值方案。

权值和的奇偶性取决于有多少条边赋值为 1。每条边独立选择,恰好一半方案的权值和为奇数:

1
奇数和的方案数 = 2^dist / 2 = 2^(dist - 1)

所以只要快速算出 dist,就能 O(log n) 得到答案:

1
2
3
lca = LCA(u, v)
dist = depth[u] + depth[v] - 2 * depth[lca]
ans = 2^(dist - 1) mod (10^9 + 7)

Go 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package leetcode

const MOD int64 = 1_000_000_007

func modPow(a, b int64) int64 {
res := int64(1)
for b > 0 {
if b&1 == 1 {
res = res * a % MOD
}
a = a * a % MOD
b >>= 1
}
return res
}

func assignEdgeWeights(edges [][]int, queries [][]int) []int {
n := len(edges) + 1

graph := make([][]int, n+1)
for _, edge := range edges {
graph[edge[0]] = append(graph[edge[0]], edge[1])
graph[edge[1]] = append(graph[edge[1]], edge[0])
}

const LOG = 18

up := make([][LOG]int, n+1)
q := []int{1}
depth := make([]int, n+1)
depth[1] = 1
for len(q) > 0 {
size := len(q)
for _, node := range q[:size] {
for _, nextNode := range graph[node] {
if depth[nextNode] == 0 {
up[nextNode][0] = node
depth[nextNode] = depth[node] + 1
q = append(q, nextNode)
}
}
}
q = q[size:]
}

for i := 1; i < LOG; i++ {
for u := 1; u <= n; u++ {
up[u][i] = up[up[u][i-1]][i-1]
}
}

result := make([]int, 0, len(queries))

var lca func(u, v int) int

lca = func(u, v int) int {
if depth[u] > depth[v] {
diff := depth[u] - depth[v]
for k := LOG - 1; k >= 0; k-- {
if diff&(1<<k) != 0 {
u = up[u][k]
}
}
}

if depth[v] > depth[u] {
diff := depth[v] - depth[u]
for k := LOG - 1; k >= 0; k-- {
if diff&(1<<k) != 0 {
v = up[v][k]
}
}
}

if u == v {
return u
}

for k := LOG - 1; k >= 0; k-- {
if up[u][k] != up[v][k] {
u = up[u][k]
v = up[v][k]
}
}

return up[u][0]
}

for _, querie := range queries {
u, v := querie[0], querie[1]

if v == u {
result = append(result, 0)
continue
}

distance := depth[u] + depth[v] - 2*depth[lca(u, v)]

result = append(result, int(modPow(2, int64(distance-1))))
}

return result
}

总结

这道题的核心不在公式,而在如何快速求 dist。

树上两点距离 = 两点到 LCA 的距离之和,暴力求 LCA 每次 O(n),多次查询会 TLE。倍增 LCA 通过预处理 2^k 级别的跳跃表,把查询优化到 O(log n)——本质就是把”逐层爬树”变成”二进制拆分跳跃”。