常规解法
容易想到使用记忆化搜索:
type TreeAncestor struct {
memo [][]int // memo[node][k]
parent []int
n int
}
func Constructor(n int, parent []int) TreeAncestor {
memo := make([][]int, n)
for i := 0; i < n; i++{
memo[i] = make([]int, n)
for j := 0; j < n; j++ {
memo[i][j] = -2
}
}
return TreeAncestor{
memo: memo,
parent: parent,
n: n,
}
}
func (this *TreeAncestor) GetKthAncestor(node int, k int) int {
if k >= this.n {
return -1
}
if k == 0 {
return node
}
if this.memo[node][k] != -2 {
return this.memo[node][k]
}
ancestor := -1
if father := this.parent[node]; father != -1 {
ancestor = this.GetKthAncestor(father, k - 1)
}
this.memo[node][k] = ancestor
return ancestor
}
这种做法会 OOM,即使优化内存,也会 TLE
倍增法
使用倍增法,思路如下:
预处理,存储距离每个节点为 1、2、4… 的节点
func Constructor(n int, parent []int) TreeAncestor {
memo := make([][]int, n)
for i, father := range parent {
memo[i] = make([]int, 0)
memo[i] = append(memo[i], father) // memo[i][0] = father, dis = 2^0 = 1
}
// 预处理
dis := 1
allNegative := false
for !allNegative {
allNegative = true
// memo[node][dis] = memo[memo[node][dis - 1]][dis - 1] <= 先找到距离 node 为 2^(dis-1) 的节点 v1
// 再找距离节点 v1 为 2^(dis-1) 的节点 v2
for node := 0; node < n; node++ {
v1 := memo[node][dis - 1]
v2 := -1
if v1 != -1 {
v2 = memo[v1][dis - 1]
}
if v2 != -1 {
allNegative = false
}
memo[node] = append(memo[node], v2)
}
dis++
}
return TreeAncestor{
memo: memo,
}
}
GetKthAncestor 时,可以按照 k 的二进制位来分解查询
例如,k = 7 = 1 + 2 + 4,二进制表示为:0000 0111
那么,可以:
- 先寻找距离当前 node 为 1 的节点 v1
- 再寻找距离 v1 为 2 的节点 v2
- 最后寻找距离 v2 为 4 的节点 v3
那么,v3 就是我们要寻找的节点
相较于原来的暴力搜索(最坏需要 7 次查询),使用倍增法,只需要 3 次查询
这种方式,每次查询最多 32 次(当然本题的数据范围,最多查询 16 次,取决于 k 中 1 的个数)
完整代码如下:
type TreeAncestor struct {
memo [][]int // memo[node][dis]: 距离 node 2^dis 的节点
}
func Constructor(n int, parent []int) TreeAncestor {
memo := make([][]int, n)
for i, father := range parent {
memo[i] = make([]int, 0)
memo[i] = append(memo[i], father) // memo[i][0] = father, dis = 2^0 = 1
}
// 预处理
dis := 1
allNegative := false
for !allNegative {
allNegative = true
// memo[node][dis] = memo[memo[node][dis - 1]][dis - 1] <= 先找到距离 node 为 2^(dis-1) 的节点 v1
// 再找距离节点 v1 为 2^(dis-1) 的节点 v2
for node := 0; node < n; node++ {
v1 := memo[node][dis - 1]
v2 := -1
if v1 != -1 {
v2 = memo[v1][dis - 1]
}
if v2 != -1 {
allNegative = false
}
memo[node] = append(memo[node], v2)
}
dis++
}
return TreeAncestor{
memo: memo,
}
}
func (this *TreeAncestor) GetKthAncestor(node int, k int) int {
res := node
dis := 0
for k != 0 && res != -1 {
if dis >= len(this.memo[res]) {
return -1
}
if k & 1 != 0 {
res = this.memo[res][dis]
}
dis++
k >>= 1
}
return res
}
整体还是 动态规划 的思想
拓展
假设给你很多查询,每个查询都要寻找任意两个节点的 LCA,怎么办?
如果按照 LC.236 的方式,每次都去查询一次,肯定超时
我们还是可以利用倍增法来简化查询过程:
- 先预处理一个 depth 数组,记录每个节点的深度
- 对于每一个查询
(x, y),我们保证 y 的深度大于 x(可以 swap 一下) - 利用 GetKthAncestor 来得到距离 y 为
depth[y] - depth[x]的节点 z - 这样,x、z 的深度相同,我们利用倍增法快速寻找 x、z 的 LCA:
- dis 从一个较大值开始
- 分别寻找距离 x、z 为 2^dis 的节点 v1、v2
- 如果 v1 == -1,说明 dis 太大,我们让 dis–
- 如果 v1 == v2,说明 v1 就是 x、z 的共同祖先,但不一定是最近的,我们让 dis–,看看能不能得到更近的
- 如果 v1 != v2,说明 dis 太小,共同祖先还在上面,我们让 dis++(当然代码实现与这里有点不同)
可以发现,还有一点 二分 的思想
完整实现代码如下:
type TreeAncestor struct {
memo [][]int // memo[node][dis]: 距离 node 2^dis 的节点
depth []int // depth[node]: node 的深度
}
func Construct(n int, parent []int) TreeAncestor {
return ConstructByParent(n, parent)
}
func ConstructByParent(n int, parent []int) TreeAncestor {
edges := make([][]int, 0, n-1)
for i := 0; i < n; i++ {
if parent[i] != -1 {
edges = append(edges, []int{parent[i], i})
}
}
return ConstructByEdges(edges)
}
func ConstructByEdges(edges [][]int) TreeAncestor {
n := len(edges) + 1
graph := make([][]int, n)
for _, edge := range edges {
graph[edge[0]] = append(graph[edge[0]], edge[1])
graph[edge[1]] = append(graph[edge[1]], edge[0])
}
// 初始化 depth 和 memo
depth := make([]int, n)
memo := make([][]int, n)
var dfs func(cur, father int)
dfs = func(cur, father int) {
memo[cur] = make([]int, 0)
memo[cur] = append(memo[cur], father) // memo[i][0] = father, dis = 2^0 = 1
for _, next := range graph[cur] {
if next != father {
depth[next] = depth[cur] + 1
dfs(next, cur)
}
}
}
dfs(0, -1)
// 预处理 memo
dis := 1
allNegative := false
for !allNegative {
allNegative = true
// memo[node][dis] = memo[memo[node][dis - 1]][dis - 1] <= 先找到距离 node 为 2^(dis-1) 的节点 v1
// 再找距离节点 v1 为 2^(dis-1) 的节点 v2
for node := 0; node < n; node++ {
v1 := memo[node][dis-1]
v2 := -1
if v1 != -1 {
v2 = memo[v1][dis-1]
}
if v2 != -1 {
allNegative = false
}
memo[node] = append(memo[node], v2)
}
dis++
}
return TreeAncestor{
depth: depth,
memo: memo,
}
}
func (this *TreeAncestor) GetKthAncestor(node int, k int) int {
res := node
dis := 0
for k != 0 && res != -1 {
if dis >= len(this.memo[res]) {
return -1
}
if k&1 != 0 {
res = this.memo[res][dis]
}
dis++
k >>= 1
}
return res
}
// 寻找任意两个节点的 LCA
func (this *TreeAncestor) GetLCA(x, y int) int {
depthX, depthY := this.depth[x], this.depth[y]
if depthX > depthY {
// swap
x, y = y, x
depthX, depthY = depthY, depthX
}
// 将 x、y 置于同一层
y = this.GetKthAncestor(y, depthY-depthX)
if x == y {
return x
}
// 逐层向上找(倍增)
for dis := len(this.memo[x]) - 1; dis >= 0; dis-- {
dx := this.memo[x][dis]
dy := this.memo[y][dis]
if dx == -1 { // dis 太大,减小 dis 重试(二分思想)
continue
}
if dx == dy {
// dis 太大,LCA 在下面,减小 dis
continue
}
if dx != dy {
// dis 太小,LCA 在上面,增加 dis
// 这里并没有让 dis++,而是同时向上跳 2^dis 步
// 留给下一次循环判断
x, y = dx, dy
}
}
return this.memo[x][0]
}
事实上,GetLCA 方法不仅适用于二叉树,还可以用于多叉树