简介
二叉搜索树(BST
)用于快速增删改查数据,对外部可以看作一个集合。BST
的平均时间复杂度是 O ( log n ) O(\log{n}) O ( log n ) ,但是遗憾的是其最坏时间复杂度是 O ( n ) O(n) O ( n ) 的。为了解决 BST
的这个问题,之后会介绍多种平衡树。平衡树通过不同的维护方式,保证了其始终有 O ( log n ) O(\log{n}) O ( log n ) 的时间复杂度,BST
是所有平衡树的基础,同时其多数操作在多数平衡树中是通用的。
模板题链接
只有最后一组数据会 TLE。
洛谷 P3369 【模板】普通平衡树
核心思想
定义: 对于一棵有根树的任意节点,如果满足左子树所有节点的权值均小于该节点的权值,且右子树均大于该值,则称之为 BST
。即对于任意节点 x x x 有权 max { k l } < k < min { k r } \max\{k_l\} < k < \min\{k_r\} max { k l } < k < min { k r } ,其中 k k k 表示节点 x x x 的权值,{ k l } \{k_l\} { k l } 表示 x x x 左子树权值的集合。一般不允许 BST
包含相同权值的元素。
性质:
根据 BST
的定义,在 BST
上查找元素时,只需从根搜索,遇到当节点大于时向左儿子搜索,小于时向右儿子搜索,直到等于或没有儿子时返回,时间复杂度即为元素在树上的深度,而最坏时间复杂度即为最大节点深度。由于二叉树每层可以容纳的节点数是指数增长的,因此树的平均深度为 log n \log{n} log n 。但又由于 BST
无法保证树深稳定为 log n \log{n} log n 因此无法保证稳定的时间复杂度,有此引出了平衡树。
平衡树的定义: 对于一棵二叉搜索树的任意节点,如果其左右子树大小差始终不超过 1 1 1 ,则称之平衡树。由于维护一棵平衡树的时间代价十分高,通常称所有左右子树高度差具有一定约束的都为平衡树。
代码实现
结构定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 struct BST { int key; int cnt; int father; int next[2 ]; int sum; } tr[MAX]; int cnt_tr = 1 ;int root = 0 ;#define fa(idx) tr[idx].father #define ls(idx) tr[idx].next[0] #define rs(idx) tr[idx].next[1] #define who(idx) ((idx) == rs(fa(idx))) inline void up (int idx) { tr[idx].sum = tr[ls (idx)].sum + tr[rs (idx)].sum + tr[idx].cnt; }inline void maintain (int idx) { while (idx) { up (idx); idx = fa (idx); } }
查找元素
从根搜索,遇到当节点大于时向左儿子搜索,小于时向右儿子搜索,直到等于或没有儿子时返回。
1 2 3 4 5 6 7 8 9 10 11 12 iter find (int x) { if (root == 0 ) return -1 ; int idx = root; while (idx) { if (x == tr[idx].key) return idx; idx = tr[idx].next[x > tr[idx].key]; } return -1 ; }
插入元素
搜索元素在树上的位置,如果元素存在则直接返回,否则当无法继续向下搜索时插入元素到搜索的方向。
其中 maintain
函数用于维护树上的一些值,此处用于维护子树求和,用于求元素排名。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 void add (int x) { if (root == 0 ) { int nxt = cnt_tr++; tr[nxt].key = x; tr[nxt].cnt = 1 ; tr[nxt].father = 0 ; root = nxt; maintain (nxt); return ; } int idx = root; while (true ) { if (x == tr[idx].key) { tr[idx].cnt++; maintain (idx); return ; } if (tr[idx].next[x > tr[idx].key] == 0 ) { int nxt = cnt_tr++; tr[idx].next[x > tr[idx].key] = nxt; tr[nxt].key = x; tr[nxt].cnt = 1 ; tr[nxt].father = idx; maintain (nxt); return ; } idx = tr[idx].next[x > tr[idx].key]; } }
删除元素
删除元素是最为复杂的操作。如果元素存在,要将其分为 3 3 3 类:
元素为叶子节点。直接删除即可。
元素为链节点,即没有左儿子或右儿子。删除节点后将其唯一儿子连接到父节点。
元素左右儿子均存在。则其前驱或后继一定在其子树中,与其前驱或后继交换值后,删除其前驱或后继。前驱或后继一定属于上述两种情况。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 void del_idx (iter idx) { if (ls (idx) == 0 ) { tr[fa (idx)].next[who (idx)] = rs (idx); tr[rs (idx)].father = fa (idx); if (fa (idx) == 0 ) root = rs (idx); maintain (fa (idx)); return ; } if (rs (idx) == 0 ) { tr[fa (idx)].next[who (idx)] = ls (idx); tr[ls (idx)].father = fa (idx); if (fa (idx) == 0 ) root = ls (idx); maintain (fa (idx)); return ; } int pre = idx; idx = rs (idx); while (ls (idx)) idx = ls (idx); swap (tr[pre].key, tr[idx].key); swap (tr[pre].cnt, tr[idx].cnt); del_idx (idx); } bool del (int x) { if (root == 0 ) return false ; int idx = root; while (idx) { if (x == tr[idx].key) { tr[idx].cnt--; if (tr[idx].cnt) { maintain (idx); return true ; } del_idx (idx); return true ; } idx = tr[idx].next[x > tr[idx].key]; } return false ; }
前驱后继
以查找后继为例,需要分为 2 2 2 类:
如果节点存在右儿子,则后继一定在右子树上,为右子树最小节点。
否则,后继一定是其祖先节点,不断向父节点搜索,找到第一个大于的点即为后继。
最后一个节点可能没有后继,依然按照上一种情况处理。
前驱同理。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 iter pre (iter idx) { if (idx == -1 ) return maxi (); if (ls (idx)) return maxi (ls (idx)); while (idx != root) { if (who (idx) == 1 ) return fa (idx); idx = fa (idx); } return -1 ; } iter nxt (iter idx) { if (idx == -1 ) return mini (); if (rs (idx)) return mini (rs (idx)); while (idx != root) { if (who (idx) == 0 ) return fa (idx); idx = fa (idx); } return -1 ; }
快速建树
在已知初始状态包含哪些元素时,可以 O ( n ) O(n) O ( n ) 快速建树(除去排序)。每次将区间中位数作为根节点即可建一棵相对平衡的树。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 pair<int , int > a[MAX]; void build (int l, int r, int pre) { if (l > r) return ; int nxt = cnt_tr++, mid = (l + r) / 2 ; tr[pre].next[a[mid].first > tr[pre].key] = nxt; tr[nxt].key = a[mid].first; tr[nxt].cnt = a[mid].second; tr[nxt].father = pre; build (l, mid - 1 , nxt); build (mid + 1 , r, nxt); if (nxt) up (nxt); } void build (int n) { if (n == 0 ) return ; sort (a, a + n); int pre = -INF - 1 , cnt = 0 ; for (int i = 0 ; i < n; i++) { if (a[i].first != pre) { pre = a[i].first; a[cnt++] = {pre, 0 }; } a[cnt - 1 ].second++; } build (0 , cnt - 1 , 0 ); root = 1 ; }
元素排名
求元素排名需要维护子树求和 tr[idx].sum
,通过与查找元素相同的方式查找,只需在当前指针每次向右儿子转移时,累加左子树求和与当前节点大小即可,最终得到的数即为排名。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 int rk (int x) { if (root == 0 ) return -1 ; int idx = root, cnt = 0 ; while (idx) { if (x < tr[idx].key) { idx = ls (idx); } else { cnt += tr[ls (idx)].sum; if (x == tr[idx].key) return cnt + 1 ; cnt += tr[idx].cnt; idx = rs (idx); } } return -1 ; }
第 k 元素
与元素排名相似,需要维护子树求和 tr[idx].sum
。查找时如果 rk
大于左子树与当前节点求和,则向右子树查找,rk
减去已经抵消的排名即可;如果 rk
仅大于左子树求和,则返回当前节点;否则,向左子树查找。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 iter kth (int rk) { if (tr[root].sum < rk) return -1 ; int idx = root; while (true ) { if (rk <= tr[ls (idx)].sum) { idx = ls (idx); } else { rk -= tr[ls (idx)].sum + tr[idx].cnt; if (rk <= 0 ) return idx; idx = rs (idx); } } }
完整代码
二叉搜索树无法直接应用,但其多数代码片段依然可以在多数平衡树中使用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 #include <bits/stdc++.h> #define int int64_t #define endl '\n' #pragma GCC optimize(3, "Ofast" , "inline" ) using namespace std;const int MAX = 100005 ;const int INF = 0x3f3f3f3f3f3f3f3f ll;struct BST { int key; int cnt; int father; int next[2 ]; int sum; } tr[MAX]; int cnt_tr = 1 ;int root = 0 ;#define fa(idx) tr[idx].father #define ls(idx) tr[idx].next[0] #define rs(idx) tr[idx].next[1] #define who(idx) ((idx) == rs(fa(idx))) void clear () { root = 0 ; memset (tr, 0 , sizeof (BST) * cnt_tr); cnt_tr = 1 ; } inline void up (int idx) { tr[idx].sum = tr[ls (idx)].sum + tr[rs (idx)].sum + tr[idx].cnt; }inline void maintain (int idx) { while (idx) { up (idx); idx = fa (idx); } } typedef int iter; pair<int , int > a[MAX]; void build (int l, int r, int pre) { if (l > r) return ; int nxt = cnt_tr++, mid = (l + r) / 2 ; tr[pre].next[a[mid].first > tr[pre].key] = nxt; tr[nxt].key = a[mid].first; tr[nxt].cnt = a[mid].second; fa (nxt) = pre; build (l, mid - 1 , nxt); build (mid + 1 , r, nxt); if (nxt) up (nxt); } void build (int n) { if (n == 0 ) return ; sort (a, a + n); int pre = -INF - 1 , cnt = 0 ; for (int i = 0 ; i < n; i++) { if (a[i].first != pre) { pre = a[i].first; a[cnt++] = {pre, 0 }; } a[cnt - 1 ].second++; } build (0 , cnt - 1 , 0 ); root = 1 ; } iter walk_list[MAX]; int walk_cnt = 0 ;void walk (iter idx) { if (idx == 0 ) return ; walk (ls (idx)); walk_list[walk_cnt++] = idx; walk (rs (idx)); } iter find (int x) { if (root == 0 ) return -1 ; int idx = root; while (idx) { if (x == tr[idx].key) return idx; idx = tr[idx].next[x > tr[idx].key]; } return -1 ; } iter mini (iter idx = root) { if (root == 0 ) return -1 ; while (ls (idx)) idx = ls (idx); return idx; } iter maxi (iter idx = root) { if (root == 0 ) return -1 ; while (rs (idx)) idx = rs (idx); return idx; } iter pre (iter idx) { if (idx == -1 ) return maxi (); if (ls (idx)) return maxi (ls (idx)); while (idx != root) { if (who (idx) == 1 ) return fa (idx); idx = fa (idx); } return -1 ; } iter nxt (iter idx) { if (idx == -1 ) return mini (); if (rs (idx)) return mini (rs (idx)); while (idx != root) { if (who (idx) == 0 ) return fa (idx); idx = fa (idx); } return -1 ; } void add (int x) { if (root == 0 ) { int nxt = cnt_tr++; tr[nxt].key = x; tr[nxt].cnt = 1 ; fa (nxt) = 0 ; root = nxt; maintain (nxt); return ; } int idx = root; while (true ) { if (x == tr[idx].key) { tr[idx].cnt++; maintain (idx); return ; } if (tr[idx].next[x > tr[idx].key] == 0 ) { int nxt = cnt_tr++; tr[idx].next[x > tr[idx].key] = nxt; tr[nxt].key = x; tr[nxt].cnt = 1 ; fa (nxt) = idx; maintain (nxt); return ; } idx = tr[idx].next[x > tr[idx].key]; } } void del_idx (iter idx) { if (ls (idx) == 0 ) { tr[fa (idx)].next[who (idx)] = rs (idx); fa (rs (idx)) = fa (idx); if (fa (idx) == 0 ) root = rs (idx); maintain (fa (idx)); return ; } if (rs (idx) == 0 ) { tr[fa (idx)].next[who (idx)] = ls (idx); fa (ls (idx)) = fa (idx); if (fa (idx) == 0 ) root = ls (idx); maintain (fa (idx)); return ; } int nxt = rs (idx); while (ls (nxt)) nxt = ls (nxt); tr[fa (idx)].next[who (idx)] = nxt; fa (ls (idx)) = nxt; fa (rs (nxt)) = idx; int &tmp = tr[fa (nxt)].next[who (nxt)]; fa (rs (idx)) = nxt; tmp = idx; swap (fa (idx), fa (nxt)); swap (ls (idx), ls (nxt)); swap (rs (idx), rs (nxt)); if (root == idx) root = nxt; del_idx (idx); } bool del (int x) { if (root == 0 ) return false ; int idx = root; while (idx) { if (x == tr[idx].key) { tr[idx].cnt--; if (tr[idx].cnt) { maintain (idx); return true ; } del_idx (idx); return true ; } idx = tr[idx].next[x > tr[idx].key]; } return false ; } int rk (int x) { if (root == 0 ) return -1 ; int idx = root, cnt = 0 ; while (idx) { if (x < tr[idx].key) { idx = ls (idx); } else { cnt += tr[ls (idx)].sum; if (x == tr[idx].key) return cnt + 1 ; cnt += tr[idx].cnt; idx = rs (idx); } } return -1 ; } iter kth (int rk) { if (tr[root].sum < rk) return -1 ; int idx = root; while (true ) { if (rk <= tr[ls (idx)].sum) { idx = ls (idx); } else { rk -= tr[ls (idx)].sum + tr[idx].cnt; if (rk <= 0 ) return idx; idx = rs (idx); } } } iter lower_bound (int x) { if (root == 0 ) return -1 ; int idx = root; while (true ) { if (x == tr[idx].key) return idx; if (x < tr[idx].key) { if (ls (idx) == 0 ) return idx; idx = ls (idx); } else { if (rs (idx) == 0 ) return nxt (idx); idx = rs (idx); } } } iter upper_bound (int x) { if (root == 0 ) return -1 ; int idx = root; while (true ) { if (x == tr[idx].key) return nxt (idx); if (x < tr[idx].key) { if (ls (idx) == 0 ) return idx; idx = ls (idx); } else { if (rs (idx) == 0 ) return nxt (idx); idx = rs (idx); } } }