简介

AVL 是最早发明的平衡树,通过旋转维护任意点的左右子树高度差小于 11,从而实现 O(logn)O(\log{n}) 的时间复杂度。

模板题链接

洛谷 P3369 【模板】普通平衡树

左旋与右旋

维护二叉搜索树平衡,左旋 (zag) 和右旋 (zig) 是其基本操作。

1
2
3
4
5
6
     |                   |                    |
A B A
/ \ Zig(A) / \ Zag(B) / \
B [3] ========> [1] A ========> B [3]
/ \ / \ / \
[1] [2] [2] [3] [1] [2]

核心思想

二叉搜索树只有在插入和删除节点时才会改变节点之间关系,假设之前始终维护 AVL 树性质,每次改变后最多使左右子树差为 22

AVL44 类处理这个问题,由于左右对称只需讨论其中 22 类。

情况一:

若节点 AA 不满足条件,且最深子树为 LeftLeftLeft-Left ,右旋 AA 一次即可。

1
2
3
4
5
6
    |                   |
A B
/ Zig(A) / \
B ========> C A
/
C

情况二:

若节点 AA 不满足条件,且最深子树为 LeftRightLeft-Right,左旋 BB 一次,再右旋 AA 一次即可。

1
2
3
4
5
6
  |                   |                   |
A A C
/ Zag(B) / Zig(A) / \
B ========> C ========> B A
\ /
C B

时间复杂度:

AVL 维护了左右子树深度差小于 11,由此对于深度为 xxAVL 树其最少节点数为 fib(x+2)1\text{fib}(x+2)-1,其中 fib\text{fib} 为斐波那契数列,fib(0)=fib(1)=1\text{fib}(0)=\text{fib}(1)=1 。斐波那契数列是指数增长的,因此 AVL 数的最大树深为对数级别。

完整代码

二叉搜索树无法直接应用,但其多数代码片段依然可以在多数平衡树中使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
// Jamhus Tao / GreatLiangpi
// Start at 2022/10/6
// Please using int32_t and int64_t to replace the int and long long.
#include <bits/stdc++.h>
#define int int64_t
#define endl '\n'
#pragma GCC optimize(3, "Ofast", "inline")
using namespace std;

// AVL
const int MAX = 100005;
const int INF = 0x3f3f3f3f3f3f3f3fll;

// 代码实现约定:
// tr[0].sum, tr[0].depth 始终为 0, 其余随意; tr[root].father 始终为 0; root == 0 表示空树, next == 0 表示无儿子
struct AVL {
int key; // 键
int cnt; // 计数
int father; // 父节点
int next[2]; // 左右子节点
int sum; // 子树求和
int depth; // 子树深度
} tr[MAX];
int cnt_tr = 1;
int root = 0;

#define fa(idx) tr[idx].father
#define ls(idx) tr[idx].next[0]
#define rs(idx) tr[idx].next[1]
#define who(idx) ((idx) == rs(fa(idx))) // 是否为父的右儿子

// 清空
void clear() {
root = 0;
memset(tr, 0, sizeof(AVL) * cnt_tr);
cnt_tr = 1;
}

// 维护子树求和. x != 0
inline void up(int idx) {
tr[idx].sum = tr[ls(idx)].sum + tr[rs(idx)].sum + tr[idx].cnt;
tr[idx].depth = max(tr[ls(idx)].depth, tr[rs(idx)].depth) + 1;
}

// 旋转 x 为根的子树, 同时更新 x, w == 0 表示右旋 (使左儿子上升), w == 1 表示左旋. x != 0, tr[x].next[w] != 0
inline void rotate(int &x, bool w) {
int f = fa(x), s = tr[x].next[w], ss = tr[s].next[not w];
tr[f].next[who(x)] = s;
fa(s) = f;
tr[x].next[w] = ss;
fa(ss) = x;
tr[s].next[not w] = x;
fa(x) = s;
swap(x, s);
up(s);
up(x);
}

// AVL 维护平衡
inline void maintain(int idx) {
while (idx) {
if (abs(tr[ls(idx)].depth - tr[rs(idx)].depth) < 2) {
up(idx);
} else {
bool w = tr[ls(idx)].depth < tr[rs(idx)].depth;
int s = tr[idx].next[w];
if ((tr[ls(idx)].depth - tr[rs(idx)].depth) * (tr[ls(s)].depth - tr[rs(s)].depth) < 0)
rotate(s, not w);
rotate(idx, w);
if (fa(idx) == 0)
root = idx;
}
idx = fa(idx);
}
}

typedef int iter; // 索引, 避免与值混淆

// 建树, 建立平衡树, 0-index. 时间: O(nlogn)
pair<int, int> a[MAX]; // {key, cnt}
void build(int l, int r, int pre) {
if (l > r) return;
int nxt = cnt_tr++, mid = (l + r) / 2;
tr[pre].next[a[mid].first > tr[pre].key] = nxt;
tr[nxt].key = a[mid].first;
tr[nxt].cnt = a[mid].second;
fa(nxt) = pre;
build(l, mid - 1, nxt);
build(mid + 1, r, nxt);
if (nxt) up(nxt);
}
void build(int n) {
if (n == 0) return;
sort(a, a + n); // 瓶颈
int pre = -INF - 1, cnt = 0;
for (int i = 0; i < n; i++) {
if (a[i].first != pre) {
pre = a[i].first;
a[cnt++] = {pre, 0};
}
a[cnt - 1].second++;
}
build(0, cnt - 1, 0);
root = 1;
}

// 遍历. 时间: O(n)
iter walk_list[MAX];
int walk_cnt = 0;
void walk(iter idx) {
if (idx == 0)
return;
walk(ls(idx));
walk_list[walk_cnt++] = idx;
walk(rs(idx));
}

// 元素索引, 元素不存在返回 -1. 时间: O(logn)
iter find(int x) {
if (root == 0)
return -1;
int idx = root;
while (idx) {
if (x == tr[idx].key)
return idx;
idx = tr[idx].next[x > tr[idx].key];
}
return -1;
}

// 最小元素索引, 空树返回 -1. 时间: O(logn)
iter mini(iter idx = root) {
if (root == 0)
return -1;
while (ls(idx))
idx = ls(idx);
return idx;
}

// 最大元素索引, 空树返回 -1. 时间: O(logn)
iter maxi(iter idx = root) {
if (root == 0)
return -1;
while (rs(idx))
idx = rs(idx);
return idx;
}

// 元素前驱索引. 特别地, 无前驱返回 -1, -1 前驱返回最大元素索引. 时间: O(logn)
iter pre(iter idx) {
if (idx == -1)
return maxi();
if (ls(idx))
return maxi(ls(idx));
while (idx != root) {
if (who(idx) == 1)
return fa(idx);
idx = fa(idx);
}
return -1;
}

// 元素后继索引. 特别地, 无后继返回 -1, -1 后继返回最小元素索引. 时间: O(logn)
iter nxt(iter idx) {
if (idx == -1)
return mini();
if (rs(idx))
return mini(rs(idx));
while (idx != root) {
if (who(idx) == 0)
return fa(idx);
idx = fa(idx);
}
return -1;
}

// 插入元素. 时间: O(logn)
void add(int x) {
// case 1: 空树
if (root == 0) {
int nxt = cnt_tr++;
tr[nxt].key = x;
tr[nxt].cnt = 1;
fa(nxt) = 0;
root = nxt;
maintain(nxt);
return;
}
int idx = root;
while (true) {
// case 2: 元素已存在
if (x == tr[idx].key) {
tr[idx].cnt++;
maintain(idx);
return;
}
// case 3: 元素不存在
if (tr[idx].next[x > tr[idx].key] == 0) {
int nxt = cnt_tr++;
tr[idx].next[x > tr[idx].key] = nxt;
tr[nxt].key = x;
tr[nxt].cnt = 1;
fa(nxt) = idx;
maintain(nxt);
return;
}
idx = tr[idx].next[x > tr[idx].key];
}
}

// 删除索引, 直接删除索引而非计数 -1. 时间: O(logn)
void del_idx(iter idx) {
// case 1: idx 处在一条链 (没有左儿子) / idx 是叶子节点
if (ls(idx) == 0) {
tr[fa(idx)].next[who(idx)] = rs(idx);
fa(rs(idx)) = fa(idx);
if (fa(idx) == 0)
root = rs(idx);
maintain(fa(idx));
return;
}
// case 2: idx 处在一条链 (没有右儿子)
if (rs(idx) == 0) {
tr[fa(idx)].next[who(idx)] = ls(idx);
fa(ls(idx)) = fa(idx);
if (fa(idx) == 0)
root = ls(idx);
maintain(fa(idx));
return;
}
// case 3: 否则, 寻找后继交换后删除
int nxt = rs(idx);
while (ls(nxt))
nxt = ls(nxt);
// 交换两点
tr[fa(idx)].next[who(idx)] = nxt;
fa(ls(idx)) = nxt;
fa(rs(nxt)) = idx;
int &tmp = tr[fa(nxt)].next[who(nxt)];
fa(rs(idx)) = nxt;
tmp = idx;
swap(fa(idx), fa(nxt));
swap(ls(idx), ls(nxt));
swap(rs(idx), rs(nxt));
if (root == idx)
root = nxt;
del_idx(idx);
}

// 删除元素, 元素不存在返回 false. 时间: O(logn)
bool del(int x) {
// case 1: 元素不存在
if (root == 0)
return false;
int idx = root;
while (idx) {
if (x == tr[idx].key) {
tr[idx].cnt--;
if (tr[idx].cnt) {
maintain(idx);
return true;
}
// case 2: 元素存在
del_idx(idx);
return true;
}
idx = tr[idx].next[x > tr[idx].key];
}
return false;
}

// 元素 x 排名, x 不存在返回 -1. 1-index, 时间: O(logn)
int rk(int x) {
if (root == 0)
return -1;
int idx = root, cnt = 0;
while (idx) {
if (x < tr[idx].key) {
idx = ls(idx);
} else {
cnt += tr[ls(idx)].sum;
if (x == tr[idx].key)
return cnt + 1;
cnt += tr[idx].cnt;
idx = rs(idx);
}
}
return -1;
}

// 排名 rk 索引, 总数不足 rk 返回 -1. 1-index, 时间: O(logn)
iter kth(int rk) {
if (tr[root].sum < rk)
return -1;
int idx = root;
while (true) {
if (rk <= tr[ls(idx)].sum) {
idx = ls(idx);
} else {
rk -= tr[ls(idx)].sum + tr[idx].cnt;
if (rk <= 0)
return idx;
idx = rs(idx);
}
}
}

// 第一个大于等于 x 的元素索引, 不存在返回 -1. 时间: O(logn)
iter lower_bound(int x) {
if (root == 0)
return -1;
int idx = root;
while (true) {
if (x == tr[idx].key)
return idx;
if (x < tr[idx].key) {
if (ls(idx) == 0)
return idx;
idx = ls(idx);
} else {
if (rs(idx) == 0)
return nxt(idx);
idx = rs(idx);
}
}
}

// 第一个严格大于 x 的元素索引, 不存在返回 -1. 时间: O(logn)
iter upper_bound(int x) {
if (root == 0)
return -1;
int idx = root;
while (true) {
if (x == tr[idx].key)
return nxt(idx);
if (x < tr[idx].key) {
if (ls(idx) == 0)
return idx;
idx = ls(idx);
} else {
if (rs(idx) == 0)
return nxt(idx);
idx = rs(idx);
}
}
}