简介

二叉搜索树(BST)用于快速增删改查数据,对外部可以看作一个集合。BST 的平均时间复杂度是 O(logn)O(\log{n}),但是遗憾的是其最坏时间复杂度是 O(n)O(n) 的。为了解决 BST 的这个问题,之后会介绍多种平衡树。平衡树通过不同的维护方式,保证了其始终有 O(logn)O(\log{n}) 的时间复杂度,BST 是所有平衡树的基础,同时其多数操作在多数平衡树中是通用的。

模板题链接

只有最后一组数据会 TLE。

洛谷 P3369 【模板】普通平衡树

核心思想

定义: 对于一棵有根树的任意节点,如果满足左子树所有节点的权值均小于该节点的权值,且右子树均大于该值,则称之为 BST。即对于任意节点 xx 有权 max{kl}<k<min{kr}\max\{k_l\} < k < \min\{k_r\},其中 kk 表示节点 xx 的权值,{kl}\{k_l\} 表示 xx 左子树权值的集合。一般不允许 BST 包含相同权值的元素。

性质:

  • max{kl}<k<min{kr}\max\{k_l\} < k < \min\{k_r\}

  • BST 的中序遍历即为树上所有元素的升序排列。

  • 对于树上任意节点,其子树表示升序排列中的一段连续区间。

根据 BST 的定义,在 BST 上查找元素时,只需从根搜索,遇到当节点大于时向左儿子搜索,小于时向右儿子搜索,直到等于或没有儿子时返回,时间复杂度即为元素在树上的深度,而最坏时间复杂度即为最大节点深度。由于二叉树每层可以容纳的节点数是指数增长的,因此树的平均深度为 logn\log{n}。但又由于 BST 无法保证树深稳定为 logn\log{n} 因此无法保证稳定的时间复杂度,有此引出了平衡树。

平衡树的定义: 对于一棵二叉搜索树的任意节点,如果其左右子树大小差始终不超过 11,则称之平衡树。由于维护一棵平衡树的时间代价十分高,通常称所有左右子树高度差具有一定约束的都为平衡树。

代码实现

结构定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// 二叉搜索树
struct BST {
int key; // 键
int cnt; // 计数
int father; // 父节点
int next[2]; // 左右子节点
int sum; // 子树求和, 用于求元素排名
} tr[MAX];
int cnt_tr = 1;
int root = 0;

#define fa(idx) tr[idx].father
#define ls(idx) tr[idx].next[0]
#define rs(idx) tr[idx].next[1]
#define who(idx) ((idx) == rs(fa(idx))) // 是否为父的右儿子

// 维护子树求和 (单步)
inline void up(int idx) { tr[idx].sum = tr[ls(idx)].sum + tr[rs(idx)].sum + tr[idx].cnt; }

// 维护子树求和
inline void maintain(int idx) {
while (idx) {
up(idx);
idx = fa(idx);
}
}

查找元素

从根搜索,遇到当节点大于时向左儿子搜索,小于时向右儿子搜索,直到等于或没有儿子时返回。

1
2
3
4
5
6
7
8
9
10
11
12
// 元素索引, 元素不存在返回 -1. 时间: O(logn)
iter find(int x) {
if (root == 0)
return -1;
int idx = root;
while (idx) {
if (x == tr[idx].key)
return idx;
idx = tr[idx].next[x > tr[idx].key];
}
return -1;
}

插入元素

搜索元素在树上的位置,如果元素存在则直接返回,否则当无法继续向下搜索时插入元素到搜索的方向。

其中 maintain 函数用于维护树上的一些值,此处用于维护子树求和,用于求元素排名。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// 插入元素. 时间: O(logn)
void add(int x) {
// case 1: 空树
if (root == 0) {
int nxt = cnt_tr++;
tr[nxt].key = x;
tr[nxt].cnt = 1;
tr[nxt].father = 0;
root = nxt;
maintain(nxt);
return;
}
int idx = root;
while (true) {
// case 2: 元素已存在
if (x == tr[idx].key) {
tr[idx].cnt++;
maintain(idx);
return;
}
// case 3: 元素不存在
if (tr[idx].next[x > tr[idx].key] == 0) {
int nxt = cnt_tr++;
tr[idx].next[x > tr[idx].key] = nxt;
tr[nxt].key = x;
tr[nxt].cnt = 1;
tr[nxt].father = idx;
maintain(nxt);
return;
}
idx = tr[idx].next[x > tr[idx].key];
}
}

删除元素

删除元素是最为复杂的操作。如果元素存在,要将其分为 33 类:

  • 元素为叶子节点。直接删除即可。
  • 元素为链节点,即没有左儿子或右儿子。删除节点后将其唯一儿子连接到父节点。
  • 元素左右儿子均存在。则其前驱或后继一定在其子树中,与其前驱或后继交换值后,删除其前驱或后继。前驱或后继一定属于上述两种情况。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
void del_idx(iter idx) {
// case 1: idx 处在一条链 (没有左儿子) / idx 是叶子节点
if (ls(idx) == 0) {
tr[fa(idx)].next[who(idx)] = rs(idx);
tr[rs(idx)].father = fa(idx);
if (fa(idx) == 0)
root = rs(idx);
maintain(fa(idx));
return;
}
// case 2: idx 处在一条链 (没有右儿子)
if (rs(idx) == 0) {
tr[fa(idx)].next[who(idx)] = ls(idx);
tr[ls(idx)].father = fa(idx);
if (fa(idx) == 0)
root = ls(idx);
maintain(fa(idx));
return;
}
// case 3: 否则, 寻找后继交换后删除
int pre = idx;
idx = rs(idx);
while (ls(idx))
idx = ls(idx);
swap(tr[pre].key, tr[idx].key);
swap(tr[pre].cnt, tr[idx].cnt);
del_idx(idx);
}

// 删除元素, 元素不存在返回 -1. 时间: O(logn)
bool del(int x) {
// case 1: 元素不存在
if (root == 0)
return false;
int idx = root;
while (idx) {
if (x == tr[idx].key) {
tr[idx].cnt--;
if (tr[idx].cnt) {
maintain(idx);
return true;
}
// case 2: 元素存在
del_idx(idx);
return true;
}
idx = tr[idx].next[x > tr[idx].key];
}
return false;
}

前驱后继

以查找后继为例,需要分为 22 类:

  • 如果节点存在右儿子,则后继一定在右子树上,为右子树最小节点。
  • 否则,后继一定是其祖先节点,不断向父节点搜索,找到第一个大于的点即为后继。
  • 最后一个节点可能没有后继,依然按照上一种情况处理。

前驱同理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// 元素前驱索引. 特别地, 无前驱返回 -1, -1 前驱返回最大元素索引. 时间: O(logn)
iter pre(iter idx) {
if (idx == -1)
return maxi();
if (ls(idx))
return maxi(ls(idx));
while (idx != root) {
if (who(idx) == 1)
return fa(idx);
idx = fa(idx);
}
return -1;
}

// 元素后继索引. 特别地, 无后继返回 -1, -1 后继返回最小元素索引. 时间: O(logn)
iter nxt(iter idx) {
if (idx == -1)
return mini();
if (rs(idx))
return mini(rs(idx));
while (idx != root) {
if (who(idx) == 0)
return fa(idx);
idx = fa(idx);
}
return -1;
}

快速建树

在已知初始状态包含哪些元素时,可以 O(n)O(n) 快速建树(除去排序)。每次将区间中位数作为根节点即可建一棵相对平衡的树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// 建树, 将形成一棵平衡树, 0-index. 时间: O(nlogn)
pair<int, int> a[MAX]; // {key, cnt}
void build(int l, int r, int pre) {
if (l > r) return;
int nxt = cnt_tr++, mid = (l + r) / 2;
tr[pre].next[a[mid].first > tr[pre].key] = nxt;
tr[nxt].key = a[mid].first;
tr[nxt].cnt = a[mid].second;
tr[nxt].father = pre;
build(l, mid - 1, nxt);
build(mid + 1, r, nxt);
if (nxt) up(nxt);
}
void build(int n) {
if (n == 0) return;
sort(a, a + n); // 瓶颈
int pre = -INF - 1, cnt = 0;
for (int i = 0; i < n; i++) {
if (a[i].first != pre) {
pre = a[i].first;
a[cnt++] = {pre, 0};
}
a[cnt - 1].second++;
}
build(0, cnt - 1, 0);
root = 1;
}

元素排名

求元素排名需要维护子树求和 tr[idx].sum,通过与查找元素相同的方式查找,只需在当前指针每次向右儿子转移时,累加左子树求和与当前节点大小即可,最终得到的数即为排名。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 元素 x 排名, x 不存在返回 -1. 1-index, 时间: O(logn)
int rk(int x) {
if (root == 0)
return -1;
int idx = root, cnt = 0;
while (idx) {
if (x < tr[idx].key) {
idx = ls(idx);
} else {
cnt += tr[ls(idx)].sum;
if (x == tr[idx].key)
return cnt + 1;
cnt += tr[idx].cnt;
idx = rs(idx);
}
}
return -1;
}

第 k 元素

与元素排名相似,需要维护子树求和 tr[idx].sum。查找时如果 rk 大于左子树与当前节点求和,则向右子树查找,rk 减去已经抵消的排名即可;如果 rk 仅大于左子树求和,则返回当前节点;否则,向左子树查找。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 排名 rk 索引, 总数不足 rk 返回 -1. 1-index, 时间: O(logn)
iter kth(int rk) {
if (tr[root].sum < rk)
return -1;
int idx = root;
while (true) {
if (rk <= tr[ls(idx)].sum) {
idx = ls(idx);
} else {
rk -= tr[ls(idx)].sum + tr[idx].cnt;
if (rk <= 0)
return idx;
idx = rs(idx);
}
}
}

完整代码

二叉搜索树无法直接应用,但其多数代码片段依然可以在多数平衡树中使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
// Jamhus Tao / GreatLiangpi
// Start at 2022/10/6
// Please using int32_t and int64_t to replace the int and long long.
#include <bits/stdc++.h>
#define int int64_t
#define endl '\n'
#pragma GCC optimize(3, "Ofast", "inline")
using namespace std;

// 二叉搜索树
const int MAX = 100005;
const int INF = 0x3f3f3f3f3f3f3f3fll;

// 代码实现约定:
// tr[0].sum 始终为 0, 其余随意; tr[root].father 始终为 0; root == 0 表示空树, next == 0 表示无儿子
struct BST {
int key; // 键
int cnt; // 计数
int father; // 父节点
int next[2]; // 左右子节点
int sum; // 子树求和
} tr[MAX];
int cnt_tr = 1;
int root = 0;

#define fa(idx) tr[idx].father
#define ls(idx) tr[idx].next[0]
#define rs(idx) tr[idx].next[1]
#define who(idx) ((idx) == rs(fa(idx))) // 是否为父的右儿子

// 清空
void clear() {
root = 0;
memset(tr, 0, sizeof(BST) * cnt_tr);
cnt_tr = 1;
}

// 维护子树求和 (单步)
inline void up(int idx) { tr[idx].sum = tr[ls(idx)].sum + tr[rs(idx)].sum + tr[idx].cnt; }

// 维护子树求和
inline void maintain(int idx) {
while (idx) {
up(idx);
idx = fa(idx);
}
}

typedef int iter; // 索引, 避免与值混淆

// 建树, 将形成一棵平衡树, 0-index. 时间: O(nlogn)
pair<int, int> a[MAX]; // {key, cnt}
void build(int l, int r, int pre) {
if (l > r) return;
int nxt = cnt_tr++, mid = (l + r) / 2;
tr[pre].next[a[mid].first > tr[pre].key] = nxt;
tr[nxt].key = a[mid].first;
tr[nxt].cnt = a[mid].second;
fa(nxt) = pre;
build(l, mid - 1, nxt);
build(mid + 1, r, nxt);
if (nxt) up(nxt);
}
void build(int n) {
if (n == 0) return;
sort(a, a + n); // 瓶颈
int pre = -INF - 1, cnt = 0;
for (int i = 0; i < n; i++) {
if (a[i].first != pre) {
pre = a[i].first;
a[cnt++] = {pre, 0};
}
a[cnt - 1].second++;
}
build(0, cnt - 1, 0);
root = 1;
}

// 遍历. 时间: O(n)
iter walk_list[MAX];
int walk_cnt = 0;
void walk(iter idx) {
if (idx == 0) return;
walk(ls(idx));
walk_list[walk_cnt++] = idx;
walk(rs(idx));
}

// 元素索引, 元素不存在返回 -1. 时间: O(logn)
iter find(int x) {
if (root == 0)
return -1;
int idx = root;
while (idx) {
if (x == tr[idx].key)
return idx;
idx = tr[idx].next[x > tr[idx].key];
}
return -1;
}

// 最小元素索引, 空树返回 -1. 时间: O(logn)
iter mini(iter idx = root) {
if (root == 0)
return -1;
while (ls(idx))
idx = ls(idx);
return idx;
}

// 最大元素索引, 空树返回 -1. 时间: O(logn)
iter maxi(iter idx = root) {
if (root == 0)
return -1;
while (rs(idx))
idx = rs(idx);
return idx;
}

// 元素前驱索引. 特别地, 无前驱返回 -1, -1 前驱返回最大元素索引. 时间: O(logn)
iter pre(iter idx) {
if (idx == -1)
return maxi();
if (ls(idx))
return maxi(ls(idx));
while (idx != root) {
if (who(idx) == 1)
return fa(idx);
idx = fa(idx);
}
return -1;
}

// 元素后继索引. 特别地, 无后继返回 -1, -1 后继返回最小元素索引. 时间: O(logn)
iter nxt(iter idx) {
if (idx == -1)
return mini();
if (rs(idx))
return mini(rs(idx));
while (idx != root) {
if (who(idx) == 0)
return fa(idx);
idx = fa(idx);
}
return -1;
}

// 插入元素. 时间: O(logn)
void add(int x) {
// case 1: 空树
if (root == 0) {
int nxt = cnt_tr++;
tr[nxt].key = x;
tr[nxt].cnt = 1;
fa(nxt) = 0;
root = nxt;
maintain(nxt);
return;
}
int idx = root;
while (true) {
// case 2: 元素已存在
if (x == tr[idx].key) {
tr[idx].cnt++;
maintain(idx);
return;
}
// case 3: 元素不存在
if (tr[idx].next[x > tr[idx].key] == 0) {
int nxt = cnt_tr++;
tr[idx].next[x > tr[idx].key] = nxt;
tr[nxt].key = x;
tr[nxt].cnt = 1;
fa(nxt) = idx;
maintain(nxt);
return;
}
idx = tr[idx].next[x > tr[idx].key];
}
}

// 删除索引, 直接删除索引而非计数 -1. 时间: O(logn)
void del_idx(iter idx) {
// case 1: idx 处在一条链 (没有左儿子) / idx 是叶子节点
if (ls(idx) == 0) {
tr[fa(idx)].next[who(idx)] = rs(idx);
fa(rs(idx)) = fa(idx);
if (fa(idx) == 0)
root = rs(idx);
maintain(fa(idx));
return;
}
// case 2: idx 处在一条链 (没有右儿子)
if (rs(idx) == 0) {
tr[fa(idx)].next[who(idx)] = ls(idx);
fa(ls(idx)) = fa(idx);
if (fa(idx) == 0)
root = ls(idx);
maintain(fa(idx));
return;
}
// case 3: 否则, 寻找后继交换后删除
int nxt = rs(idx);
while (ls(nxt))
nxt = ls(nxt);
// 交换两点
tr[fa(idx)].next[who(idx)] = nxt;
fa(ls(idx)) = nxt;
fa(rs(nxt)) = idx;
int &tmp = tr[fa(nxt)].next[who(nxt)];
fa(rs(idx)) = nxt;
tmp = idx;
swap(fa(idx), fa(nxt));
swap(ls(idx), ls(nxt));
swap(rs(idx), rs(nxt));
if (root == idx)
root = nxt;
del_idx(idx);
}

// 删除元素, 元素不存在返回 false. 时间: O(logn)
bool del(int x) {
// case 1: 元素不存在
if (root == 0)
return false;
int idx = root;
while (idx) {
if (x == tr[idx].key) {
tr[idx].cnt--;
if (tr[idx].cnt) {
maintain(idx);
return true;
}
// case 2: 元素存在
del_idx(idx);
return true;
}
idx = tr[idx].next[x > tr[idx].key];
}
return false;
}

// 元素 x 排名, x 不存在返回 -1. 1-index, 时间: O(logn)
int rk(int x) {
if (root == 0)
return -1;
int idx = root, cnt = 0;
while (idx) {
if (x < tr[idx].key) {
idx = ls(idx);
} else {
cnt += tr[ls(idx)].sum;
if (x == tr[idx].key)
return cnt + 1;
cnt += tr[idx].cnt;
idx = rs(idx);
}
}
return -1;
}

// 排名 rk 索引, 总数不足 rk 返回 -1. 1-index, 时间: O(logn)
iter kth(int rk) {
if (tr[root].sum < rk)
return -1;
int idx = root;
while (true) {
if (rk <= tr[ls(idx)].sum) {
idx = ls(idx);
} else {
rk -= tr[ls(idx)].sum + tr[idx].cnt;
if (rk <= 0)
return idx;
idx = rs(idx);
}
}
}

// 第一个大于等于 x 的元素索引, 不存在返回 -1. 时间: O(logn)
iter lower_bound(int x) {
if (root == 0)
return -1;
int idx = root;
while (true) {
if (x == tr[idx].key)
return idx;
if (x < tr[idx].key) {
if (ls(idx) == 0)
return idx;
idx = ls(idx);
} else {
if (rs(idx) == 0)
return nxt(idx);
idx = rs(idx);
}
}
}

// 第一个严格大于 x 的元素索引, 不存在返回 -1. 时间: O(logn)
iter upper_bound(int x) {
if (root == 0)
return -1;
int idx = root;
while (true) {
if (x == tr[idx].key)
return nxt(idx);
if (x < tr[idx].key) {
if (ls(idx) == 0)
return idx;
idx = ls(idx);
} else {
if (rs(idx) == 0)
return nxt(idx);
idx = rs(idx);
}
}
}