从暴力到倍增 LCA:一题理解树上路径优化
题目描述
LeetCode 3558
给你一棵 n 个节点的无向树,节点编号为 1 ~ n,1 为根。
1
| edges[i] = [u, v] 表示 u 和 v 之间有一条边。
|
每条边可以赋值为 1 或 2。对于每次查询 (u, v):
- 找到 u → v 的唯一路径(树上任意两点之间有且仅有一条路径)
- 计算路径上的边数 dist
- 返回
2^(dist - 1) mod (10^9 + 7)(为什么是这个公式?可以参考 给边赋权值的方案数 I 题解)
举个例子:
查询 (5, 3):
1 2 3
| 路径:5 → 2 → 1 → 3 边数:dist = 3 答案:2^(3-1) = 4
|
每次查询都需要快速求出树上两点之间的距离,这就是 LCA(最近公共祖先)的典型应用。
怎么求距离?关键在于找到 u 和 v 的最近公共祖先——知道了 LCA,距离就是 depth[u] + depth[v] - 2 * depth[lca]。
那先不管优化,最朴素的做法是什么?
暴力思路
从 u 和 v 各自往上爬,谁深谁先走,直到碰面。
1 2 3 4 5 6 7 8 9
| function distance(u, v): while u != v: if depth[u] > depth[v]: u = parent[u] else: v = parent[v] step++
return step
|
问题很明显:单次 O(n),多次查询 O(nq),数据一大直接 TLE。
优化思路:从逐格走到倍增跳跃
观察暴力过程会发现:每次查询,本质都在重复”向上走同一条路径”。
每次都从头走一遍,非常浪费。那能不能一次多跳几步?
如果能一次跳 2 步、4 步、8 步 … 那走 13 步就只需要跳 3 次:
1
| 13 = 8 + 4 + 1 → 跳 3 次就到了
|
这就是二进制拆分——任何步数都能拆成若干个 2 的幂之和。顺着这个想法,只需要预处理出每个节点向上跳 2^k 步到达哪里,查询时就能 O(log n) 拼出任意路径。
这就是倍增(Binary Lifting) 的由来。
倍增 LCA 的核心思想
定义一个数组:
1
| up[u][k] = u 向上跳 2^k 步到达的节点
|
比如 up[5][0] 是 5 的父节点(跳 1 步),up[5][1] 是向上跳 2 步到达的节点,up[5][2] 是跳 4 步……
有了这张表,任意距离的跳跃都可以用二进制拆分完成。
递推关系
1
| up[u][k] = up[ up[u][k-1] ][k-1]
|
意思很简单:想跳 2^k 步,先跳 2^(k-1) 步到中间节点,再从那里跳 2^(k-1) 步。
1 2
| u ──2^(k-1)步──> 中间节点 ──2^(k-1)步──> 目标节点 合计:2^(k-1) + 2^(k-1) = 2^k 步
|
大步拆两小步,小步已经预处理好了。
示例推导
用一棵树来演示 up 表是怎么算的:
graph TD
1((1))
2((2))
3((3))
4((4))
5((5))
6((6))
1 --> 2
1 --> 3
2 --> 4
2 --> 5
3 --> 6
其中 5 → 4 → 3 → 2 → 1 是一条链,看 up[5][k] 怎么递推。
up[5][0]:跳 1 步
up[5][1]:跳 2 步
1 2 3
| up[5][1] = up[ up[5][0] ][0] = up[4][0] = 3
|
从 5 跳 1 步到 4,再从 4 跳 1 步到 3,合计 2 步。
1 2
| 5 → 4 → 3 (1步) (1步) = 2步
|
up[5][2]:跳 4 步
1 2
| up[5][2] = up[ up[5][1] ][1] = up[3][1]
|
而 up[3][1] 是 3 向上跳 2 步:
1 2 3
| up[3][1] = up[ up[3][0] ][0] = up[2][0] = 1
|
所以 up[5][2] = 1,从 5 跳 2 步到 3,再从 3 跳 2 步到 1,合计 4 步。
1 2
| 5 → 4 → 3 → 2 → 1 (2步) (2步) = 4步
|
汇总:
| k |
2^k |
up[5][k] |
拆分方式 |
| 0 |
1 |
4 |
直接跳 1 步 |
| 1 |
2 |
3 |
1 步 + 1 步 |
| 2 |
4 |
1 |
2 步 + 2 步 |
LCA 查询
分两步:
Step 1:拉齐深度
1 2 3 4
| function lift(u, diff): for k from LOG downto 0: if diff >= 2^k: u = up[u][k]
|
Step 2:一起跳
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| function lca(u, v): if depth[u] < depth[v]: swap(u, v)
lift u to same depth as v
if u == v: return u
for k from LOG downto 0: if up[u][k] != up[v][k]: u = up[u][k] v = up[v][k]
return up[u][0]
|
为什么这样就能找到 LCA?
拉齐深度后,u 和 v 到 LCA 的距离为什么相等?
因为 LCA 是公共祖先,到 LCA 的距离 = 自身深度 - LCA 深度。深度相同,距离就相同。
“一起跳”为什么能精确停在 LCA 的下一层?
关键在 up[u][k] != up[v][k] 这个判断:
!=:2^k 级祖先还没汇合,还在 LCA 下面,可以安全跳
==:2^k 级祖先已经相同了,跳过去就越过 LCA 了,不跳
从大到小遍历 k,就是在用二进制逐步逼近 LCA:
1 2 3 4 5
| 距离 LCA 还有 5 步:5 = 4 + 1
k=2: 2^2=4, 跳了不会越过 → 跳 (剩 1 步) k=1: 2^1=2, 跳了会越过 → 不跳 k=0: 2^0=1, 跳了不会越过 → 跳 (剩 0 步)
|
循环结束后,u 和 v 刚好停在 LCA 的正下方,所以 up[u][0] 就是 LCA。
距离计算
1
| dist = depth[u] + depth[v] - 2 * depth[lca]
|
回到题目
每条边赋值 1 或 2,路径上 dist 条边共有 2^dist 种赋值方案。
权值和的奇偶性取决于有多少条边赋值为 1。每条边独立选择,恰好一半方案的权值和为奇数:
1
| 奇数和的方案数 = 2^dist / 2 = 2^(dist - 1)
|
所以只要快速算出 dist,就能 O(log n) 得到答案:
1 2 3
| lca = LCA(u, v) dist = depth[u] + depth[v] - 2 * depth[lca] ans = 2^(dist - 1) mod (10^9 + 7)
|
Go 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
| package leetcode
const MOD int64 = 1_000_000_007
func modPow(a, b int64) int64 { res := int64(1) for b > 0 { if b&1 == 1 { res = res * a % MOD } a = a * a % MOD b >>= 1 } return res }
func assignEdgeWeights(edges [][]int, queries [][]int) []int { n := len(edges) + 1
graph := make([][]int, n+1) for _, edge := range edges { graph[edge[0]] = append(graph[edge[0]], edge[1]) graph[edge[1]] = append(graph[edge[1]], edge[0]) }
const LOG = 18
up := make([][LOG]int, n+1) q := []int{1} depth := make([]int, n+1) depth[1] = 1 for len(q) > 0 { size := len(q) for _, node := range q[:size] { for _, nextNode := range graph[node] { if depth[nextNode] == 0 { up[nextNode][0] = node depth[nextNode] = depth[node] + 1 q = append(q, nextNode) } } } q = q[size:] }
for i := 1; i < LOG; i++ { for u := 1; u <= n; u++ { up[u][i] = up[up[u][i-1]][i-1] } }
result := make([]int, 0, len(queries))
var lca func(u, v int) int
lca = func(u, v int) int { if depth[u] > depth[v] { diff := depth[u] - depth[v] for k := LOG - 1; k >= 0; k-- { if diff&(1<<k) != 0 { u = up[u][k] } } }
if depth[v] > depth[u] { diff := depth[v] - depth[u] for k := LOG - 1; k >= 0; k-- { if diff&(1<<k) != 0 { v = up[v][k] } } }
if u == v { return u }
for k := LOG - 1; k >= 0; k-- { if up[u][k] != up[v][k] { u = up[u][k] v = up[v][k] } }
return up[u][0] }
for _, querie := range queries { u, v := querie[0], querie[1]
if v == u { result = append(result, 0) continue }
distance := depth[u] + depth[v] - 2*depth[lca(u, v)]
result = append(result, int(modPow(2, int64(distance-1)))) }
return result }
|
总结
这道题的核心不在公式,而在如何快速求 dist。
树上两点距离 = 两点到 LCA 的距离之和,暴力求 LCA 每次 O(n),多次查询会 TLE。倍增 LCA 通过预处理 2^k 级别的跳跃表,把查询优化到 O(log n)——本质就是把”逐层爬树”变成”二进制拆分跳跃”。