跳转至

红黑树

这是一颗红黑树。所有图都是从 OI-Wiki 上抄的

红黑树

性质

  • 结点为红色或黑色。
  • NIL 结点(空叶子结点)为黑色。
  • 红色结点的子结点为黑色。
  • 从根结点到 NIL 结点的每条路径上的黑色结点数量相同。
  • 根结点为黑色(部分资料不要求这条性质(比如oi-wiki),这里要求这条性质是为了方便在插入和删除时判断)。

结点维护的信息

变量名 解释
ls 左子结点
rs 右子结点
fa 父结点
col 颜色
cnt 当前结点大小(维护相同大小的关键字)
siz 以当前结点为根的子树大小
num 关键字

旋转

旋转

旋转可以在不改变平衡树性质的同时改变结点深度,因此用来维护红黑树的性质。

从左到右称为左旋,从右到左称为右旋。在下文中,将使用其中的子结点描述旋转,即图中左旋中的绿色结点,右旋中的黄色结点,这样,如果被旋转的结点是父结点的左子结点,那么一定是右旋,如果是右子结点,一定是左旋。

插入

首先找到待插入结点的位置,和普通的二叉搜索树类似。如果待插入的关键字原本存在,则将该结点的 cnt\(1\)。如果不存在则插入结点,插入结点的颜色为红色,然后进行讨论,维护红黑性质。

在实现的时候需要注意处理 siz,在寻找插入结点的位置时就可以将经过结点的 siz\(1\)

Case 1

当前结点为根结点,将根结点变成黑色,结束修正。

Case 2

当前结点的父结点为黑色,满足性质,不用修正。

Case 3

当前结点 N 的叔结点 U 和父结点 P 都是红色。此时祖父结点 G 一定是黑色,则将父结点 P、叔结点 U、祖父结点 G 的颜色反转,同时祖父结点 G 变成红色,可能违反性质,因此继续维护祖父结点 G。

Case2

Case 4

当前结点 N 的父结点 P 为红色,叔结点 U 为黑色(或不存在),且当前结点 N、父结点 P、祖父结点 G 不共线。旋转当前结点 N,转化为 Case 5。

Case3

Case 5

当前结点 N 的父结点 P 为红色,叔结点 U 为黑色(或不存在),且当前结点 N、父结点 P、祖父结点 G 共线。旋转父结点 P,再将父结点 P 变成黑色,祖父结点 G 变成红色。这样满足了性质,结束修正。

Case4

删除

首先找到待删除结点的位置,如果待删除结点的 cnt 大于 \(1\),则将 cnt\(1\) 即可,否则要将该结点删除。

在实现的时候需要注意处理 siz,在寻找插入结点的位置时就可以将经过结点的 siz\(1\)

Case 1

当前结点有两个子结点,需要找到该结点的后继结点进行替换(只替换 keyvalcnt,不改变结点的父子关系),由于后继结点一定没有左子结点 (因为如果有左子结点,那么这个后继结点一定是假的),这样就转化为了 Case 2,3,4,5。

在实现的时候需要注意处理 siz,在找后继结点替换时需要将 从待删除结点到后继结点的路径上的所有点(除了待删除结点)siz 减去 后继结点的 cnt 别问我怎么知道的

Case 2

当前结点为红色,且有一个子结点。直接将子结点取代当前结点即可。

Case 3

当前结点为红色,且无子结点,直接删。

Case 4

当前结点为黑色,且有一个子结点。由于红黑树的性质,这个子结点一定是红色而且这个子结点没有子结点,因此交换当前结点与子结点的颜色,转化为 Case 2。

Case 5

当前结点为黑色,且无子结点。这种情况最麻烦,分为以下 5 种情况。为了方便维护,我们先不删除待删除结点,而是将性质修正之后再删除。

Case 5.0

当前结点为根结点,结束修正。

Case 5.1

兄弟结点 S 为红色。则父结点 P 一定为黑色。旋转兄弟结点 S,再交换父结点 P,兄弟结点 S 的颜色,转化为 Case 5.3、5.4、5.5。

Case5.1

Case 5.2

兄弟结点 S 为黑色,父结点 P 为黑色,侄结点 C、D 为黑色或不存在。将兄弟结点 S 变为红色,再维护父结点 P,转化为 Case 5 的任意一种。

Case5.2

Case 5.3

兄弟结点 S 为黑色,父结点 P 为红色,侄结点 C、D 为黑色或不存在。将兄弟结点 S、父结点 P 的颜色交换,修正完成。

Case5.3

Case 5.4

兄弟结点 S 为黑色,与当前结点 N 反向的侄结点 D 为红色(父结点、与当前结点 N 同向的侄结点 C 的颜色不关心)。旋转兄弟结点 S,交换父结点 P、兄弟结点 S 的颜色。反向侄结点 D 变为黑色,修正完成。

Case5.4

Case 5.5

兄弟结点 S 为黑色,与当前结点 N 同向的侄结点 C 为红色,与当前结点 N 反向的侄结点 D 为黑色或不存在(父结点的颜色不关心)。旋转同向的侄结点 C,交换兄弟结点 S,同向侄结点 C 的颜色,转化为 Case 5.4。

Case5.5

红黑树黑高的改变

在插入操作中,只有 Case 1 才会增加黑高;在删除操作中,只有当进入 Case 5(包括由 Case 1 变为 Case 5)时,且只使用 Case 5.2 进行修正直到 Case 5.0 时,黑高才会减小。

例题

普通平衡树

数组实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <iostream>
#define LS(x) son[x][0]
#define RS(x) son[x][1]
using namespace std;
const int MAXN = 1e5 + 5;
class RedBlackTree {
private:
    static const bool RED = true, BLACK = false;
    int tot, root, fa[MAXN], cnt[MAXN], siz[MAXN], num[MAXN], son[MAXN][2];
    bool col[MAXN];
    int newNode(int val) {
        num[++tot] = val, siz[tot] = 1, cnt[tot] = 1, col[tot] = RED;
        return fa[tot] = LS(tot) = RS(tot) = 0, tot;
    }
    void rotate(int x) {
        int y = fa[x], z = fa[y];
        bool pos = x == RS(y);
        if (z)
            son[z][y == RS(z)] = x;
        else
            root = x;
        son[y][pos] = son[x][!pos], son[x][!pos] = y;
        fa[son[y][pos]] = y, fa[y] = x, fa[x] = z;
        siz[x] = siz[y], siz[y] = cnt[y] + siz[LS(y)] + siz[RS(y)];
    }

public:
    RedBlackTree() { tot = root = 0; }
    void insert(int val) {
        int x = root, y = 0, z = 0, u = 0;
        while (x && num[x] != val)
            ++siz[x], y = x, x = son[x][val > num[x]];
        if (x)
            return ++cnt[x], ++siz[x], void();
        x = newNode(val);
        if (!y)
            root = x;
        else {
            fa[x] = y, son[y][val > num[y]] = x;
            while (x != root && col[x] == RED && col[fa[x]] == RED) {
                y = fa[x], z = fa[y], u = son[fa[y]][y == LS(fa[y])], col[z] = RED;
                if (u && col[u] == RED)
                    col[y] = BLACK, col[u] = BLACK, x = z;
                else if ((x == LS(y)) ^ (y == LS(z)))
                    rotate(x), rotate(x), col[x] = BLACK;
                else
                    rotate(y), col[y] = BLACK, x = y;
            }
        }
        col[root] = BLACK;
    }
    void erase(int val) {
        int x = root, y = 0, z = 0;
        while (x && num[x] != val)
            x = son[x][val > num[x]];
        for (y = x, --cnt[x]; y; y = fa[y])
            --siz[y];
        if (cnt[x])
            return;
        if (LS(x) && RS(x)) {
            for (y = x, x = RS(x); LS(x);)
                x = LS(x);
            num[y] = num[x], cnt[y] = cnt[x];
            for (z = x; z != y; z = fa[z])
                siz[z] -= cnt[y];
        }
        y = x, z = 0;
        if ((LS(x) > 0 || RS(x) > 0) && col[x] == BLACK)
            col[LS(x) + RS(x)] = BLACK, col[x] = RED;
        if (col[x] == RED) {
            z = son[x][LS(x) == 0];
            if (z)
                fa[z] = fa[x];
        } else
            for (int u = 0, c = 0, d = 0; x != root;) {
                bool pos = x == RS(fa[x]);
                u = son[fa[x]][!pos], c = son[u][pos], d = son[u][!pos];
                if (col[u] == RED)
                    rotate(u), col[u] = BLACK, col[fa[x]] = RED;
                else if (d && col[d] == RED) {
                    rotate(u), swap(col[u], col[fa[x]]), col[d] = BLACK;
                    break;
                } else if (c && col[c] == RED)
                    rotate(c), col[u] = RED, col[c] = BLACK;
                else if (col[fa[x]] == RED) {
                    col[u] = RED, col[fa[x]] = BLACK;
                    break;
                } else
                    col[u] = RED, x = fa[x];
            }
        if (y == root)
            root = z;
        else
            son[fa[y]][y == RS(fa[y])] = z;
    }
    int queryRnk(int val) {
        int ans = 1;
        for (int x = root; x; x = son[x][num[x] <= val]) {
            if (num[x] == val)
                return ans + siz[LS(x)];
            if (num[x] <= val)
                ans += siz[LS(x)] + cnt[x];
        }
        return ans;
    }
    int queryKth(int k) {
        for (int x = root; x;) {
            if (k > siz[LS(x)] && k <= siz[LS(x)] + cnt[x])
                return num[x];
            else if (k <= siz[LS(x)])
                x = LS(x);
            else
                k -= siz[LS(x)] + cnt[x], x = RS(x);
        }
    }
    int queryPre(int val) {
        int x = root, y = 0;
        while (x) {
            if (val <= num[x])
                x = LS(x);
            else
                y = x, x = RS(x);
        }
        return num[y];
    }
    int querySuc(int val) {
        int x = root, y = 0;
        while (x) {
            if (val < num[x])
                y = x, x = LS(x);
            else
                x = RS(x);
        }
        return num[y];
    }
};
指针实现 略有点抽象
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#include <iostream>
using namespace std;
template < typename Key >
class RedBlackTree {
private:
    const static bool RED = true, BLACK = false;
    struct Node {
        int siz, cnt;
        bool col;
        Key key;
        Node *fa, *lc, *rc;
        inline Node(Key key) {
            this->key = key, siz = 1, cnt = 1, col = RED;
            fa = lc = rc = nullptr;
        }
        inline bool left() {
            return this == this->fa->lc;
        }
        inline bool allLeft() {
            return this == this->fa->lc && this->fa == this->fa->fa->lc;
        }
        inline bool allRight() {
            return this == this->fa->rc && this->fa == this->fa->fa->rc;
        }
        inline Node *uncle() {
            Node *fa = this->fa, *grandFa = fa->fa;
            return fa == grandFa->lc ? grandFa->rc : grandFa->lc;
        }
    };
    Node *root;
    inline void rotate(Node *node) {
        Node *oldFa = node->fa, *grandFa = node->fa->fa;
        if (grandFa == nullptr)
            root = node;
        else if (oldFa->left())
            grandFa->lc = node;
        else
            grandFa->rc = node;
        if (node->left()) {
            oldFa->lc = node->rc, node->rc = oldFa, node->siz = oldFa->siz, oldFa->siz -= node->lc == nullptr ? 0 : node->lc->siz;
            if (oldFa->lc != nullptr)
                oldFa->lc->fa = oldFa;
        } else {
            oldFa->rc = node->lc, node->lc = oldFa, node->siz = oldFa->siz, oldFa->siz -= node->rc == nullptr ? 0 : node->rc->siz;
            if (oldFa->rc != nullptr)
                oldFa->rc->fa = oldFa;
        }
        oldFa->fa = node, node->fa = grandFa, oldFa->siz -= node->cnt;
    }

public:
    RedBlackTree() {
        root = nullptr;
    }
    inline void insert(Key key) {
        Node *node = root, *fa = nullptr, *grandFa = nullptr, *uncle = nullptr;
        while (node != nullptr && node->key != key)
            ++node->siz, fa = node, node = key < node->key ? node->lc : node->rc;
        if (node != nullptr)
            return ++node->cnt, ++node->siz, void();
        node = new Node(key);
        if (fa == nullptr)
            root = node;
        else {
            node->fa = fa;
            if (key < fa->key)
                fa->lc = node;
            else
                fa->rc = node;
            for (bool temp; node != root && node->col == RED && node->fa->col == RED;) {
                fa = node->fa, grandFa = fa->fa;
                uncle = node->uncle();
                if (uncle != nullptr && uncle->col == RED)
                    fa->col = BLACK, uncle->col = BLACK, node = grandFa;
                else if (node->allLeft() || node->allRight())
                    rotate(fa), fa->col = BLACK, node = fa;
                else
                    rotate(node), rotate(node), node->col = BLACK;
                grandFa->col = RED;
            }
        }
        root->col = BLACK;
    }
    inline void erase(Key key) {
        Node *node = root, *oldNode = nullptr, *child = nullptr;
        while (node != nullptr && node->key != key)
            node = key < node->key ? node->lc : node->rc;
        for (oldNode = node, --node->cnt; oldNode != nullptr; oldNode = oldNode->fa)
            --oldNode->siz;
        if (node->cnt)
            return;
        if (node->lc != nullptr && node->rc != nullptr) {
            oldNode = node;
            for (node = node->rc; node->lc != nullptr;)
                node = node->lc;
            oldNode->key = node->key, oldNode->cnt = node->cnt;
            for (Node *newNode = node; newNode != oldNode; newNode = newNode->fa)
                newNode->siz -= oldNode->cnt;
        }
        oldNode = node;
        if ((node->lc != nullptr || node->rc != nullptr) && node->col == BLACK) {
            if (node->lc != nullptr)
                node->lc->col = BLACK;
            else
                node->rc->col = BLACK;
            node->col = RED;
        }
        if (node->col == RED) {
            if (node->lc != nullptr)
                child = node->lc;
            else if (node->rc != nullptr)
                child = node->rc;
            if (child != nullptr)
                child->fa = node->fa;
        } else {
            Node *bro = nullptr, *close = nullptr, *distant = nullptr;
            while (node != root) {
                close = node->left() ? (bro = node->fa->rc)->lc : (bro = node->fa->lc)->rc, distant = node->left() ? bro->rc : bro->lc;
                if (bro->col == RED)
                    rotate(bro), bro->col = BLACK, node->fa->col = RED;
                else if (distant != nullptr && distant->col == RED) {
                    rotate(bro), swap(bro->col, node->fa->col), distant->col = BLACK;
                    break;
                } else if (close != nullptr && close->col == RED)
                    rotate(close), bro->col = RED, close->col = BLACK;
                else {
                    bro->col = RED;
                    if (node->fa->col == RED) {
                        node->fa->col = BLACK;
                        break;
                    } else
                        node = node->fa;
                }
            }
        }
        if (oldNode == root)
            root = child;
        else if (oldNode->left())
            oldNode->fa->lc = child;
        else
            oldNode->fa->rc = child;
        return delete (oldNode), void();
    }
    inline int qryRank(Key key) {
        int ans = 1;
        for (Node *node = root; node != nullptr;) {
            if (node->key == key)
                return ans += node->lc == nullptr ? 0 : node->lc->siz;
            else if (node->key > key)
                node = node->lc;
            else
                ans += (node->lc == nullptr ? 0 : node->lc->siz) + node->cnt, node = node->rc;
        }
        return ans;
    }
    inline Key qryKth(int rank) {
        for (Node *node = root; node != nullptr;) {
            if (node->lc == nullptr) {
                if (rank <= node->cnt)
                    return node->key;
                else
                    rank -= node->cnt, node = node->rc;
            } else if (rank > node->lc->siz && rank <= node->lc->siz + node->cnt)
                return node->key;
            else if (rank <= node->lc->siz)
                node = node->lc;
            else
                rank -= node->lc->siz + node->cnt, node = node->rc;
        }
    }
    inline Key qryPre(Key key) {
        Node *node = root, *fa = nullptr;
        while (node != nullptr) {
            if (key <= node->key)
                node = node->lc;
            else
                fa = node, node = node->rc;
        }
        return fa->key;
    }
    inline Key qrySuc(Key key) {
        Node *node = root, *fa = nullptr;
        while (node != nullptr) {
            if (key < node->key)
                fa = node, node = node->lc;
            else
                node = node->rc;
        }
        return fa->key;
    }
};