链接

常规解法

容易想到使用记忆化搜索:

type TreeAncestor struct {
    memo [][]int // memo[node][k]
    parent []int
    n int
}


func Constructor(n int, parent []int) TreeAncestor {
    memo := make([][]int, n)
    for i := 0; i < n; i++{
        memo[i] = make([]int, n)
        for j := 0; j < n; j++ {
            memo[i][j] = -2
        }
    }

    return TreeAncestor{
        memo: memo,
        parent: parent,
        n: n,
    }
}


func (this *TreeAncestor) GetKthAncestor(node int, k int) int {
    if k >= this.n {
        return -1
    }
    if k == 0 {
        return node
    }
    if this.memo[node][k] != -2 {
        return this.memo[node][k]
    }

    ancestor := -1
    if father := this.parent[node]; father != -1 {
        ancestor = this.GetKthAncestor(father, k - 1)
    }

    this.memo[node][k] = ancestor

    return ancestor
}

这种做法会 OOM,即使优化内存,也会 TLE

倍增法

使用倍增法,思路如下:

预处理,存储距离每个节点为 1、2、4… 的节点

func Constructor(n int, parent []int) TreeAncestor {
    memo := make([][]int, n)
    for i, father := range parent {
        memo[i] = make([]int, 0)
        memo[i] = append(memo[i], father) // memo[i][0] = father, dis = 2^0 = 1
    }

    // 预处理
    dis := 1
    allNegative := false
    for !allNegative {
        allNegative = true

        // memo[node][dis] = memo[memo[node][dis - 1]][dis - 1] <= 先找到距离 node 为 2^(dis-1) 的节点 v1
        // 再找距离节点 v1 为 2^(dis-1) 的节点 v2
        for node := 0; node < n; node++ {
            v1 := memo[node][dis - 1]
            v2 := -1
            if v1 != -1 {
                v2 = memo[v1][dis - 1]
            }
            if v2 != -1 {
                allNegative = false
            }
            memo[node] = append(memo[node], v2)
        }

        dis++
    }

    return TreeAncestor{
        memo: memo,
    }
}

GetKthAncestor 时,可以按照 k 的二进制位来分解查询

例如,k = 7 = 1 + 2 + 4,二进制表示为:0000 0111

那么,可以:

  • 先寻找距离当前 node 为 1 的节点 v1
  • 再寻找距离 v1 为 2 的节点 v2
  • 最后寻找距离 v2 为 4 的节点 v3

那么,v3 就是我们要寻找的节点

相较于原来的暴力搜索(最坏需要 7 次查询),使用倍增法,只需要 3 次查询

这种方式,每次查询最多 32 次(当然本题的数据范围,最多查询 16 次,取决于 k 中 1 的个数)

完整代码如下:

type TreeAncestor struct {
    memo [][]int // memo[node][dis]: 距离 node 2^dis 的节点
}

func Constructor(n int, parent []int) TreeAncestor {
    memo := make([][]int, n)
    for i, father := range parent {
        memo[i] = make([]int, 0)
        memo[i] = append(memo[i], father) // memo[i][0] = father, dis = 2^0 = 1
    }

    // 预处理
    dis := 1
    allNegative := false
    for !allNegative {
        allNegative = true

        // memo[node][dis] = memo[memo[node][dis - 1]][dis - 1] <= 先找到距离 node 为 2^(dis-1) 的节点 v1
        // 再找距离节点 v1 为 2^(dis-1) 的节点 v2
        for node := 0; node < n; node++ {
            v1 := memo[node][dis - 1]
            v2 := -1
            if v1 != -1 {
                v2 = memo[v1][dis - 1]
            }
            if v2 != -1 {
                allNegative = false
            }
            memo[node] = append(memo[node], v2)
        }

        dis++
    }

    return TreeAncestor{
        memo: memo,
    }
}

func (this *TreeAncestor) GetKthAncestor(node int, k int) int {
    res := node
    dis := 0
    for k != 0 && res != -1 {
        if dis >= len(this.memo[res]) {
            return -1
        }
        if k & 1 != 0 {
            res = this.memo[res][dis]
        }
        dis++
        k >>= 1
    }
    return res
}

整体还是 动态规划 的思想

拓展

假设给你很多查询,每个查询都要寻找任意两个节点的 LCA,怎么办?

如果按照 LC.236 的方式,每次都去查询一次,肯定超时

我们还是可以利用倍增法来简化查询过程:

  • 先预处理一个 depth 数组,记录每个节点的深度
  • 对于每一个查询 (x, y),我们保证 y 的深度大于 x(可以 swap 一下)
  • 利用 GetKthAncestor 来得到距离 y 为 depth[y] - depth[x] 的节点 z
  • 这样,x、z 的深度相同,我们利用倍增法快速寻找 x、z 的 LCA:
    • dis 从一个较大值开始
    • 分别寻找距离 x、z 为 2^dis 的节点 v1、v2
      • 如果 v1 == -1,说明 dis 太大,我们让 dis–
      • 如果 v1 == v2,说明 v1 就是 x、z 的共同祖先,但不一定是最近的,我们让 dis–,看看能不能得到更近的
      • 如果 v1 != v2,说明 dis 太小,共同祖先还在上面,我们让 dis++(当然代码实现与这里有点不同)

可以发现,还有一点 二分 的思想

完整实现代码如下:

type TreeAncestor struct {
	memo  [][]int // memo[node][dis]: 距离 node 2^dis 的节点
	depth []int   // depth[node]: node 的深度
}

func Construct(n int, parent []int) TreeAncestor {
	return ConstructByParent(n, parent)
}

func ConstructByParent(n int, parent []int) TreeAncestor {
	edges := make([][]int, 0, n-1)
	for i := 0; i < n; i++ {
		if parent[i] != -1 {
			edges = append(edges, []int{parent[i], i})
		}
	}
	return ConstructByEdges(edges)
}

func ConstructByEdges(edges [][]int) TreeAncestor {
	n := len(edges) + 1
	graph := make([][]int, n)
	for _, edge := range edges {
		graph[edge[0]] = append(graph[edge[0]], edge[1])
		graph[edge[1]] = append(graph[edge[1]], edge[0])
	}

	// 初始化 depth 和 memo
	depth := make([]int, n)
	memo := make([][]int, n)
	var dfs func(cur, father int)

	dfs = func(cur, father int) {
		memo[cur] = make([]int, 0)
		memo[cur] = append(memo[cur], father) // memo[i][0] = father, dis = 2^0 = 1
		for _, next := range graph[cur] {
			if next != father {
				depth[next] = depth[cur] + 1
				dfs(next, cur)
			}
		}
	}
	dfs(0, -1)

	// 预处理 memo
	dis := 1
	allNegative := false
	for !allNegative {
		allNegative = true

		// memo[node][dis] = memo[memo[node][dis - 1]][dis - 1] <= 先找到距离 node 为 2^(dis-1) 的节点 v1
		// 再找距离节点 v1 为 2^(dis-1) 的节点 v2
		for node := 0; node < n; node++ {
			v1 := memo[node][dis-1]
			v2 := -1
			if v1 != -1 {
				v2 = memo[v1][dis-1]
			}
			if v2 != -1 {
				allNegative = false
			}
			memo[node] = append(memo[node], v2)
		}

		dis++
	}

	return TreeAncestor{
		depth: depth,
		memo:  memo,
	}
}

func (this *TreeAncestor) GetKthAncestor(node int, k int) int {
	res := node
	dis := 0
	for k != 0 && res != -1 {
		if dis >= len(this.memo[res]) {
			return -1
		}
		if k&1 != 0 {
			res = this.memo[res][dis]
		}
		dis++
		k >>= 1
	}
	return res
}

// 寻找任意两个节点的 LCA
func (this *TreeAncestor) GetLCA(x, y int) int {
	depthX, depthY := this.depth[x], this.depth[y]
	if depthX > depthY {
		// swap
		x, y = y, x
		depthX, depthY = depthY, depthX
	}

	// 将 x、y 置于同一层
	y = this.GetKthAncestor(y, depthY-depthX)
	if x == y {
		return x
	}

	// 逐层向上找(倍增)
	for dis := len(this.memo[x]) - 1; dis >= 0; dis-- {
		dx := this.memo[x][dis]
		dy := this.memo[y][dis]
		if dx == -1 { // dis 太大,减小 dis 重试(二分思想)
			continue
		}
        if dx == dy {
            // dis 太大,LCA 在下面,减小 dis
            continue
        }
		if dx != dy {
            // dis 太小,LCA 在上面,增加 dis
            // 这里并没有让 dis++,而是同时向上跳 2^dis 步
            // 留给下一次循环判断
			x, y = dx, dy
		}
	}

	return this.memo[x][0]
}

事实上,GetLCA 方法不仅适用于二叉树,还可以用于多叉树

参考资料