简介

平衡树种类众多,各有优劣,其公共功能是完成 O(logn)O(\log{n}) 的增删改查。本文将介绍 Splay,其优点是代码短,缺点是常数较大。此外,Splay 还可以处理区间翻转问题。Splay 不属于严格的平衡树,但与平衡树有相同的均摊复杂度。

模板题链接

洛谷 P3369 【模板】普通平衡树

核心思想

Splay 的核心为 splay 也是重要的基本操作,只要在完成任何搜索时通过 splay 操作将搜索到的节点旋转到根 ,即可非常神奇地得到 O(logn)O(\log{n}) 的时间复杂度。

splay 讨论 66 种情况:zig(右旋)、zag(左旋)、zig-zigzag-zagzig-zagzag-zig

由于左右对称只需讨论其中 33 种情况:

  • zig:完成一次右旋,当且仅当父节点为根,否则考虑后面两种情况。
  • zig-zig:对祖父节点完成一次右旋,再对父节点完成一次右旋。
  • zig-zag:对父节点完成一次左旋,再对祖父节点完成一次右旋。

代码虽短,涵盖了上述三种情况。对于左右对称的讨论在 rotate 中完成。

1
2
3
4
5
6
7
8
9
// splay 核心操作, 分 zig / zig-zig / zig-zag, 所有指针移动后都要 splay
inline void splay(int x) {
for (int f = tr[x].father; f; f = tr[x].father) {
if (tr[f].father)
rotate(who(x) == who(f) ? f : x);
rotate(x);
}
root = x;
}

通过势能分析法可以得到,上述操作能保证 O(logn)O(\log{n}) 的均摊复杂度。

基于旋转的操作

由于 Splay 本身基于大量旋转(虽然这使其常数很大),而其本身并不关心树的平衡因子。因此对于 Splay 而言,基于旋转的增删改查更适合它。另外,此处未提及的其他一些操作可以与 BST(二分搜索树)思路一致。

删除节点

对于 BST,删除节点通常将节点与其后继互换后删除。但这并不适合基于旋转的 Splay

删除节点时,搜索到待删除节点后 splay 为根,然后移除根节点。此时树分裂为左子树和右子树,且满足左子树节点均小于右子树节点。此时搜索到左子树的最大值 splay 为根,以保证根没有右儿子。将右子树的根作为左子树根的右儿子即可。

考虑到左子树可能为空等情况,又要分为若干分支。删除节点是 Splay 中最繁琐的操作,当然你也可以不移除点,将其作为濒死节点始终保留在树中,这对于算法竞赛来说通常无妨。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// 删除元素, 元素不存在返回 -1. 时间: O(logn)
bool del(int x) {
// case 1: 空树
if (root == 0)
return false;
int idx = root;
while (true) {
if (x == tr[idx].key)
break;
// case 2: 元素不存在
if (tr[idx].next[x > tr[idx].key] == 0) {
splay(idx);
return false;
}
idx = tr[idx].next[x > tr[idx].key];
}
tr[idx].cnt--;
up(idx);
splay(idx);
// case 3: 无需移除点
if (tr[idx].cnt)
return true;
// case 4: 移除后左子树为空
if (tr[idx].next[0] == 0) {
root = tr[idx].next[1];
tr[root].father = 0;
if (root)
up(root);
return true;
}
// case 5: 移除后左子树非空
root = tr[idx].next[0];
tr[root].father = 0;
maxi();
int s = tr[idx].next[1];
tr[root].next[1] = s;
tr[s].father = root;
up(root);
return true;
}

前驱后继

以搜索前驱为例,只需将原节点 splay 为根,然后搜索左子树的最大值节点即可。别忘了所有搜索结束时都要 splay

1
2
3
4
5
6
7
8
9
10
11
12
13
// 元素前驱索引. 特别地, 无前驱返回 -1, -1 前驱返回最大元素索引. 时间: O(logn)
iter pre(iter idx) {
if (idx == -1)
return maxi();
splay(idx);
idx = tr[idx].next[0];
if (idx == 0)
return -1;
while (tr[idx].next[1])
idx = tr[idx].next[1];
splay(idx);
return idx;
}

查询 x 的排名

Splay 查询排名非常优雅,只需将 x 对于的节点 splay 为根,返回左子树求和 +1+1 即可。

1
2
3
4
5
6
// 元素 x 排名, x 不存在返回 -1. 1-index, 时间: O(logn)
int rk(int x) {
if (find(x) == -1)
return -1;
return tr[tr[root].next[0]].sum + 1;
}

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
// Jamhus Tao / GreatLiangpi
// Start at 2022/10/6
// Please using int32_t and int64_t to replace the int and long long.
#include <bits/stdc++.h>
#define int int64_t
#define endl '\n'
#pragma GCC optimize(3, "Ofast", "inline")
using namespace std;

// Splay
const int MAX = 100005;
const int INF = 0x3f3f3f3f3f3f3f3fll;

// 代码实现约定:
// tr[0].sum 始终为 0, 其余随意; tr[root].father 始终为 0; root == 0 表示空树, next == 0 表示无儿子
struct Splay {
int key; // 键
int cnt; // 计数
int father; // 父节点
int next[2]; // 左右子节点
int sum; // 子树求和
} tr[MAX]; // 1-index
int root = 0;
int cnt_tr = 1;

// 清空
void clear() {
root = 0;
memset(tr, 0, sizeof(Splay) * cnt_tr);
cnt_tr = 1;
}

// 维护子树求和, idx != 0
inline void up(int idx) { tr[idx].sum = tr[tr[idx].next[0]].sum + tr[tr[idx].next[1]].sum + tr[idx].cnt; }

// 为父的右儿子, idx != root and idx != 0
inline bool who(int idx) { return idx == tr[tr[idx].father].next[1]; }

// 旋转使 idx 上升, idx != root and idx != 0
inline void rotate(int idx) {
bool w = who(idx);
int f = tr[idx].father, ff = tr[f].father, s = tr[idx].next[not w];
tr[ff].next[who(f)] = idx;
tr[idx].father = ff;
tr[f].next[w] = s;
tr[s].father = f;
tr[idx].next[not w] = f;
tr[f].father = idx;
up(f);
up(idx);
}

// splay 核心操作, 分 zig / zig-zig / zig-zag, 所有指针移动后都要 splay. tar 表示 splay 到变为谁的子节点
inline void splay(int idx, int tar = 0) {
for (int f = tr[idx].father; f != tar; f = tr[idx].father) {
if (tr[f].father != tar)
rotate(who(idx) == who(f) ? f : idx);
rotate(idx);
}
if (tar == 0)
root = idx;
}

typedef int iter; // 索引, 避免与值混淆

// 建树, 0-index. 时间: O(nlogn)
int a[MAX];
void build(int n) {
if (n == 0) return;
sort(a, a + n); // 瓶颈
int pre = -INF - 1;
for (int i = 0; i < n; i++) {
if (a[i] == pre) {
tr[cnt_tr - 1].cnt++;
} else {
tr[cnt_tr].key = a[i];
tr[cnt_tr].cnt = 1;
tr[cnt_tr - 1].next[1] = cnt_tr;
tr[cnt_tr].father = cnt_tr - 1;
cnt_tr++;
}
pre = a[i];
}
for (int i = cnt_tr - 1; i >= 1; i--)
up(i);
root = 1;
}

// 遍历. 时间: O(n)
iter walk_list[MAX];
int walk_cnt = 0;
void walk(iter idx) {
if (idx == 0) return;
walk(tr[idx].next[0]);
walk_list[walk_cnt++] = idx;
walk(tr[idx].next[1]);
}

// 元素索引, 元素不存在返回 -1. 时间: O(logn)
iter find(int x) {
if (root == 0)
return -1;
int idx = root;
while (true) {
if (x == tr[idx].key) {
splay(idx);
return idx;
}
if (tr[idx].next[x > tr[idx].key] == 0) {
splay(idx);
return -1;
}
idx = tr[idx].next[x > tr[idx].key];
}
}

// 最小元素索引, 空树返回 -1. 时间: O(logn)
iter mini() {
if (root == 0)
return -1;
int idx = root;
while (tr[idx].next[0])
idx = tr[idx].next[0];
splay(idx);
return idx;
}

// 最大元素索引, 空树返回 -1. 时间: O(logn)
iter maxi() {
if (root == 0)
return -1;
int idx = root;
while (tr[idx].next[1])
idx = tr[idx].next[1];
splay(idx);
return idx;
}

// 元素前驱索引. 特别地, 无前驱返回 -1, -1 前驱返回最大元素索引. 时间: O(logn)
iter pre(iter idx) {
if (idx == -1)
return maxi();
splay(idx);
idx = tr[idx].next[0];
if (idx == 0)
return -1;
while (tr[idx].next[1])
idx = tr[idx].next[1];
splay(idx);
return idx;
}

// 元素后继索引. 特别地, 无后继返回 -1, -1 后继返回最小元素索引. 时间: O(logn)
iter nxt(iter idx) {
if (idx == -1)
return mini();
splay(idx);
idx = tr[idx].next[1];
if (idx == 0)
return -1;
while (tr[idx].next[0])
idx = tr[idx].next[0];
splay(idx);
return idx;
}

// 插入元素. 时间: O(logn)
void add(int x) {
// case 1: 空树
if (root == 0) {
int nxt = cnt_tr++;
tr[nxt].key = x;
tr[nxt].cnt = 1;
tr[nxt].father = 0;
up(nxt);
splay(nxt);
return;
}
int idx = root;
while (true) {
// case 2: 元素已存在
if (x == tr[idx].key) {
tr[idx].cnt++;
up(idx);
splay(idx);
return;
}
// case 3: 元素不存在
if (tr[idx].next[x > tr[idx].key] == 0) {
int nxt = cnt_tr++;
tr[idx].next[x > tr[idx].key] = nxt;
tr[nxt].key = x;
tr[nxt].cnt = 1;
tr[nxt].father = idx;
up(nxt);
up(idx);
splay(nxt);
return;
}
idx = tr[idx].next[x > tr[idx].key];
}
}

// 删除元素, 元素不存在返回 false. 时间: O(logn)
bool del(int x) {
// case 1: 元素不存在 (同时 splay 为根)
if (find(x) == -1)
return -1;
tr[root].cnt--;
// case 2: 无需移除点
if (tr[root].cnt)
return true;
// case 3: 移除后左子树为空
if (tr[root].next[0] == 0) {
root = tr[root].next[1];
tr[root].father = 0;
if (root)
up(root);
return true;
}
// case 4: 移除后左子树非空
int ri = tr[root].next[1];
root = tr[root].next[0];
tr[root].father = 0;
maxi();
tr[root].next[1] = ri;
tr[ri].father = root;
up(root);
return true;
}

// 元素 x 排名, x 不存在返回 -1. 1-index, 时间: O(logn)
int rk(int x) {
if (find(x) == -1)
return -1;
return tr[tr[root].next[0]].sum + 1;
}

// 排名 rk 索引, 总数不足 rk 返回 -1. 1-index, 时间: O(logn)
iter kth(int rk) {
if (tr[root].sum < rk)
return -1;
int idx = root, cnt = 0;
while (true) {
if (rk <= tr[tr[idx].next[0]].sum) {
idx = tr[idx].next[0];
} else {
rk -= tr[tr[idx].next[0]].sum + tr[idx].cnt;
if (rk <= 0) {
splay(idx);
return idx;
}
idx = tr[idx].next[1];
}
}
}

// 第一个大于等于 x 的元素索引, 不存在返回 -1. 时间: O(logn)
iter lower_bound(int x) {
if (root == 0)
return -1;
find(x);
if (x <= tr[root].key)
return root;
return nxt(root);
}

// 第一个严格大于 x 的元素索引, 不存在返回 -1. 时间: O(logn)
iter upper_bound(int x) {
if (root == 0)
return -1;
find(x);
if (x < tr[root].key)
return root;
return nxt(root);
}

区间翻转问题

洛谷 P3391 【模板】文艺平衡树

简明题意

你需要提供一种操作:翻转一个区间,例如原有序序列是 5 4 3 2 15\ 4\ 3\ 2\ 1,翻转区间是 [2,4][2,4] 的话,结果是 5 2 3 4 15\ 2\ 3\ 4\ 1

给定数字 n (n105)n~(n \le 10^5)q (q105)q~(q \le 10^5) 分别表示原序列 [1..n][1..n] 和操作次数。输出最终序列。

解题思路

考虑使用 Splay 维护,只需将 l-1 splay 到根节点,将 r+1 splayl-1 的右儿子(添加虚点 0n+1)。根据二叉平衡树的性质,r+1 节点的左子树即为区间 [l, r],此时只需给 r+1 的左儿子打上 lazy 标记,之后 push_down 即可。另外,由于子树交换无法维护键值有序,因此需要使用 第 k 大 二叉搜索。

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Jamhus Tao / GreatLiangpi
// Start at 2022/10/6
// Please using int32_t and int64_t to replace the int and long long.
#include <bits/stdc++.h>
#define int int64_t
#define endl '\n'
#pragma GCC optimize(3, "Ofast", "inline")
using namespace std;

// Splay 平衡树
struct Splay {
int father;
int next[2];
int sum;
bool lazy;
} tr[MAX]; // 1-index
int root = 0;
int cnt_tr = 1;

// 清空
void clear() {
memset(tr, 0, sizeof(Splay) * cnt_tr);
cnt_tr = 1;
}

// 维护子树求和
inline void up(int x) { tr[x].sum = tr[tr[x].next[0]].sum + tr[tr[x].next[1]].sum + 1; }

// 维护 lazy 标记
inline void done(int x) {
if (tr[x].lazy) {
int &li = tr[x].next[0];
int &ri = tr[x].next[1];
tr[li].lazy ^= true;
tr[ri].lazy ^= true;
swap(li, ri);
tr[x].lazy = false;
}
}

// 为父的右儿子
inline bool who(int x) { return x == tr[tr[x].father].next[1]; }

// 旋转使 x 上升
inline void rotate(int x) {
bool w = who(x);
int f = tr[x].father, ff = tr[f].father, s = tr[x].next[not w];
tr[ff].next[who(f)] = x;
tr[x].father = ff;
tr[f].next[w] = s;
tr[s].father = f;
tr[x].next[not w] = f;
tr[f].father = x;
up(f);
up(x);
}

// splay 核心操作
inline void splay(int x, int tar) {
for (int f = tr[x].father; f != tar; f = tr[x].father) {
if (tr[f].father != tar)
rotate(who(x) == who(f) ? f : x);
rotate(x);
}
if (tar == 0)
root = x;
}

typedef int iter;

// 排名 rk 索引
iter kth(int rk, int tar) {
int idx = root, cnt = 0;
while (true) {
done(idx);
if (tr[tr[idx].next[0]].sum >= rk) {
idx = tr[idx].next[0];
} else {
rk -= tr[tr[idx].next[0]].sum + 1;
if (rk <= 0) {
splay(idx, tar);
return idx;
}
idx = tr[idx].next[1];
}
}
}

// 建树, 过程保证数值与索引相等
void build(int n) {
for (int i = 1; i <= n; i++) {
tr[cnt_tr - 1].next[1] = cnt_tr;
tr[cnt_tr].father = cnt_tr - 1;
cnt_tr++;
}
for (int i = cnt_tr - 1; i >= 1; i--)
up(i);
root = 1;
cnt_tr = n + 1;
}

// 导出答案
vector<int> out;
void walk(int idx) {
done(idx);
if (tr[idx].next[0]) walk(tr[idx].next[0]);
out.push_back(idx - 1);
if (tr[idx].next[1]) walk(tr[idx].next[1]);
}

void solve() {
int n, q;
cin >> n >> q;
build(n + 2);
while (q--) {
int l, r;
cin >> l >> r;
kth(l, 0);
kth(r + 2, root);
tr[tr[tr[root].next[1]].next[0]].lazy ^= true;
}
walk(root);
for (int i = 1; i <= n; i++)
cout << out[i] << ' ';
cout << endl;
}

int32_t main() {
cin.tie(0);
cout.tie(0);
ios::sync_with_stdio(false);

solve();

return 0;
}