ACM
February 4, 2021

LCA

LCA

定义

对于有根树 T 的两个结点 u、v,最近公共祖先 LCA(T, u, v) 表示一个结点 x,满足 x 是 u 和 v 的祖先且 x 的深度尽可能大。在这里,一个节点也可以是它自己的祖先。

——百度百科

记某点集 $S \ = \ v_{1}, \ v_{2}, \ \cdots, \ v_{n}$ 的最近公共祖先为 $LCA(v_{1}, \ v_{2}, \ \cdots, \ v_{n})$ 或 $LCA(S)$

性质

  1. $LCA (u) \ = \ u$;
  2. u 是 v 的祖先,当且仅当 $LCA(u, \ v) \ = \ u$;
  3. 如果 u 不为 v 的祖先并且 v 不为 u 的祖先,那么 u, v 分别处于 $LCA(u, \ v)$ 的两棵不同子树中;
  4. 前序遍历中,$LCA(S)$ 出现在所有 S 中元素之前,后序遍历中 LCA(S) 则出现在所有 S 中元素之后;
  5. 两点集并的最近公共祖先为两点集分别的最近公共祖先的最近公共祖先,即 $LCA(A \ \cup \ B) \ = \ LCA(LCA(A), \ LCA(B))$;
  6. 两点的最近公共祖先必定处在树上两点间的最短路上;
  7. $d(u, \ v) \ = \ h(u) \ + \ h(v) \ - \ 2h(LCA(u, \ v))$,其中 d 是树上两点间的距离,h 代表某点到树根的距离。

求法

朴素算法

可以每次找深度比较大的那个点,让它向上跳。显然在树上,这两个点最后一定会相遇,相遇的位置就是想要求的 LCA。 或者先向上调整深度较大的点,令他们深度相同,然后再共同向上跳转,最后也一定会相遇。

朴素算法预处理时需要 dfs 整棵树,时间复杂度为 $O(n)$,单次查询时间复杂度为 $O(n)$。但由于随机树高为 $O(\log n)$,所以朴素算法在随机树上的单次查询时间复杂度为 $O(\log n)$。

倍增算法

倍增算法是最经典的 LCA 求法,他是朴素算法的改进算法。通过预处理 数组,游标可以快速移动,大幅减少了游标跳转次数。 表示点 的第 个祖先。 数组可以通过 dfs 预处理出来。

现在我们看看如何优化这些跳转: 在调整游标的第一阶段中,我们要将 两点跳转到同一深度。我们可以计算出 两点的深度之差,设其为 。通过将 进行二进制拆分,我们将 次游标跳转优化为「 的二进制表示所含 1 的个数」次游标跳转。 在第二阶段中,我们从最大的 开始循环尝试,一直尝试到 (包括 ),如果 ,则 ,那么最后的 LCA 为 。

倍增算法的预处理时间复杂度为 ,单次查询时间复杂度为 。 另外倍增算法可以通过交换 fa 数组的两维使较小维放在前面。这样可以减少 cache miss 次数,提高程序效率。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#define MXN 50007
using namespace std;
std::vector<int> v[MXN];
std::vector<int> w[MXN];

int fa[MXN][31], cost[MXN][31], dep[MXN];
int n, m;
int a, b, c;
void dfs(int root, int fno) {
fa[root][0] = fno;
dep[root] = dep[fa[root][0]] + 1;
for (int i = 1; i < 31; ++i) {
fa[root][i] = fa[fa[root][i - 1]][i - 1];
cost[root][i] = cost[fa[root][i - 1]][i - 1] + cost[root][i - 1];
}
int sz = v[root].size();
for (int i = 0; i < sz; ++i) {
if (v[root][i] == fno) continue;
cost[v[root][i]][0] = w[root][i];
dfs(v[root][i], root);
}
}
int lca(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
int tmp = dep[y] - dep[x], ans = 0;
for (int j = 0; tmp; ++j, tmp >>= 1)
if (tmp & 1) ans += cost[y][j], y = fa[y][j];
if (y == x) return ans;
for (int j = 30; j >= 0 && y != x; --j) {
if (fa[x][j] != fa[y][j]) {
ans += cost[x][j] + cost[y][j];
x = fa[x][j];
y = fa[y][j];
}
}
ans += cost[x][0] + cost[y][0];
return ans;
}
int main() {
memset(fa, 0, sizeof(fa));
memset(cost, 0, sizeof(cost));
memset(dep, 0, sizeof(dep));
scanf("%d", &n);
for (int i = 1; i < n; ++i) {
scanf("%d %d %d", &a, &b, &c);
++a, ++b;
v[a].push_back(b);
v[b].push_back(a);
w[a].push_back(c);
w[b].push_back(c);
}
dfs(1, 0);
scanf("%d", &m);
for (int i = 0; i < m; ++i) {
scanf("%d %d", &a, &b);
++a, ++b;
printf("%d\n", lca(a, b));
}
return 0;
}

Tarjan · 离线

思路

Tarjan 算法 是一种 离线算法,需要使用 并查集 记录某个结点的祖先结点。做法如下:

  1. 首先接受输入(邻接链表)、查询(存储在另一个邻接链表内)。查询边其实是虚拟加上去的边,为了方便,每次输入查询边的时候,将这个边及其反向边都加入到 queryEdge 数组里。
  2. 然后对其进行一次 DFS 遍历,同时使用 visited 数组进行记录某个结点是否被访问过、parent 记录当前结点的父亲结点。
  3. 其中涉及到了 回溯思想,我们每次遍历到某个结点的时候,认为这个结点的根结点就是它本身。让以这个结点为根节点的 DFS 全部遍历完毕了以后,再将 这个结点的根节点 设置为 这个结点的父一级结点
  4. 回溯的时候,如果以该节点为起点,queryEdge 查询边的另一个结点也恰好访问过了,则直接更新查询边的 LCA 结果。
  5. 最后输出结果。

Tarjan 算法需要初始化并查集,所以预处理的时间复杂度为 $O(n)$,Tarjan 算法处理所有 m 次询问的时间复杂度为 $O(m \ + \ n)$。但是 Tarjan 算法的常数比倍增算法大。

需要注意的是,Tarjan 算法中使用的并查集性质比较特殊,在仅使用路径压缩优化的情况下,单次调用 find() 函数的时间复杂度为均摊 $O(1)$,而不是 $O(\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <algorithm>
#include <iostream>
using namespace std;

class Edge
{
public:
int toVertex, fromVertex;
int next;
int LCA;
Edge() : toVertex(-1), fromVertex(-1), next(-1), LCA(-1) {};
Edge(int u, int v, int n) : fromVertex(u), toVertex(v), next(n), LCA(-1) {};
};

const int MAX = 100;
int head[MAX], queryHead[MAX];
Edge edge[MAX], queryEdge[MAX];
int parent[MAX], visited[MAX];
int vertexCount, edgeCount, queryCount;

void init()
{
for (int i = 0; i <= vertexCount; i++)
{
parent[i] = i;
}
}

int find(int x) {
if (parent[x] == x) {
return x;
}
else {
return find(parent[x]);
}
}

void tarjan(int u) {
parent[u] = u;
visited[u] = 1;

for (int i = head[u]; i != -1; i = edge[i].next) {
Edge& e = edge[i];
if (!visited[e.toVertex]) {
tarjan(e.toVertex);
parent[e.toVertex] = u;
}
}

for (int i = queryHead[u]; i != -1; i = queryEdge[i].next) {
Edge& e = queryEdge[i];
if (visited[e.toVertex]) {
queryEdge[i ^ 1].LCA = e.LCA = find(e.toVertex);
}
}
}

int main() {
memset(head, 0xff, sizeof(head));
memset(queryHead, 0xff, sizeof(queryHead));

cin >> vertexCount >> edgeCount >> queryCount;
int count = 0;
for (int i = 0; i < edgeCount; i++) {
int start = 0, end = 0;
cin >> start >> end;

edge[count] = Edge(start, end, head[start]);
head[start] = count;
count++;

edge[count] = Edge(end, start, head[end]);
head[end] = count;
count++;
}

count = 0;
for (int i = 0; i < queryCount; i++) {
int start = 0, end = 0;
cin >> start >> end;

queryEdge[count] = Edge(start, end, queryHead[start]);
queryHead[start] = count;
count++;

queryEdge[count] = Edge(end, start, queryHead[end]);
queryHead[end] = count;
count++;
}

init();
tarjan(1);

for (int i = 0; i < queryCount; i++) {
Edge& e = queryEdge[i * 2];
cout << "(" << e.fromVertex << "," << e.toVertex << ") " << e.LCA << endl;
}

return 0;
}

在线倍增

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 10;

int n, m, s;
struct lca {
int cnt, head[maxn];
struct edge { int to, next; } e[maxn << 1];
void add(int u, int v) {
e[++cnt] = { v, head[u] }; head[u] = cnt;
e[++cnt] = { u, head[v] }; head[v] = cnt;
}

int dep[maxn];
int lg[maxn];
int fa[maxn][22];

void init() {
for (int i = 1; i <= n; ++i) {
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
// cout << i << " -> " << lg[i] << endl;
}
}

void dfs(int now, int pre) {
fa[now][0] = pre; dep[now] = dep[pre] + 1;
for (int i = 1; i <= lg[dep[now]]; ++i)
fa[now][i] = fa[fa[now][i - 1]][i - 1];
for (int i = head[now]; i; i = e[i].next)
if (e[i].to != pre) dfs(e[i].to, now);
}

int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] > dep[y]) x = fa[x][lg[dep[x] - dep[y]] - 1];
if (x == y) return x;
for (int k = lg[dep[x]] - 1; k >= 0; --k)
if (fa[x][k] != fa[y][k])
x = fa[x][k], y = fa[y][k];
return fa[x][0];
}

} Lca;
int main() {
scanf("%d%d%d", &n, &m, &s);
for (int i = 1; i < n; ++i) {
int x, y;
scanf("%d%d", &x, &y);
Lca.add(x, y);
}
Lca.init();
Lca.dfs(s, 0);
for (int i = 1; i <= m; ++i) {
int x, y;
scanf("%d%d", &x, &y);
printf("%d\n", Lca.LCA(x, y));
}
return 0;
}


Tarjan

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 10;

int n, m, r;
int ans[maxn];
struct LCA {
int vis[maxn];

int s[maxn];
void init() { for (int i = 1; i <= n; ++i) s[i] = i; }
int find(int x) { return s[x] == x ? x : s[x] = find(s[x]); }

int tot, first[maxn];
struct query { int to, next; } q[maxn << 1];
void insert(int x, int y) {
q[++tot] = { y, first[x] }; first[x] = tot;
q[++tot] = { x, first[y] }; first[y] = tot;
}

int cnt, head[maxn];
struct edge { int to, next; } e[maxn << 1];
void add(int u, int v) {
e[++cnt] = { v, head[u] }; head[u] = cnt;
e[++cnt] = { u, head[v] }; head[v] = cnt;
}

void Tarjan(int u, int fa) {
for (int i = head[u]; i; i = e[i].next) {
if (e[i].to == fa) continue;
Tarjan(e[i].to, u);
s[e[i].to] = u;
}
for (int i = first[u]; i; i = q[i].next) {
if (!vis[q[i].to]) continue;
ans[(i + 1) / 2] = find(q[i].to);
}
vis[u] = 1;
}
} Lca;

int main() {
scanf("%d%d%d", &n, &m, &r);
Lca.init();
for (int i = 1; i < n; ++i) {
int x, y;
scanf("%d%d", &x, &y);
Lca.add(x, y);
}
for (int i = 1; i <= m; ++i) {
int x, y;
scanf("%d%d", &x, &y);
Lca.insert(x, y);
}
Lca.Tarjan(r, -1);
for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
return 0;
}


About this Post

This post is written by OwlllOvO, licensed under CC BY-NC 4.0.

#C++#LCA