跳到主要内容C++ 红黑树封装 map 和 set 详解 | 极客日志C++算法
C++ 红黑树封装 map 和 set 详解
综述由AI生成详细讲解了使用 C++ 红黑树实现标准库中 map 和 set 容器的过程。内容包括红黑树节点结构定义、仿函数 KeyOfT 的设计以支持 map 的 key 比较、插入与旋转逻辑、迭代器的中序遍历实现(++/--操作)、以及 map 的 [] 重载原理。提供了完整的源码示例,涵盖 RBTree、set 和 map 类的模板实现。
技术博主22 浏览 1、map 和 set 的整体框架
因为 map 和 set 的底层都是红黑树,所以我们考虑用一个红黑树的类模版去实例化 map 和 set 对象。不过,map 节点中存储的是一个 pair 对象,而 set 中存储的是一个 key 对象。所以我们第一步就是先调整一下我们之前实现的红黑树的节点结构。
template<class T>
struct RBTreeNode {
RBTreeNode(const T& data) :_data(data), _parent(nullptr), _left(nullptr), _right(nullptr), _col(RED) {}
T _data;
RBTreeNode* _parent;
RBTreeNode* _left;
RBTreeNode* _right;
Colour _col;
};
无论是 set 还是 map 的查找都是根据 key 来进行的,所以我们的红黑树类模版的第一个参数类型是接受上层传递过来的 K 类型。
再有,set 和 map 的 insert(插入),一个是 pair 类型,一个是 key 类型。所以我们的红黑树类模版的第二个参数类型是接受上层传递过来的 T 类型。
最后,我们还要在上层传递一个仿函数用于底层红黑树查找和插入删除时用 key 比较。因为 set 直接插入一个 key 可以用于比较,但是 map 插入一个 pair,而 pair 支持的比较是 first 和 second 一起比较,而我们只希望用 key 比较。又因为底层并不知道上层是 set 还是 map,所以我们要在上层传递一个仿函数给到下层让下层用上层的仿函数拿到 key 来比较大小!
set 整体框架:
namespace my_set {
template<class K>
class set {
struct SetKeyOfT {
const K& operator()(const K& key) { return key; }
};
public:
bool insert(const K& key) { return _t.(key); }
:
RBTree<K, K, SetKeyOfT> ;
};
}
insert
private
_t
namespace my_map {
template<class K,class V>
class map {
struct MapKeyOfT {
const K& operator()(const pair<K,V>& kv) { return kv.first; }
};
public:
bool insert(const pair<K,V>& kv) { return _t.insert(kv); }
private:
RBTree<K, pair<K,V>, MapKeyOfT> _t;
};
}
enum Colour { RED, BLACK };
template<class T>
struct RBTreeNode {
RBTreeNode(const T& data) :_data(data), _parent(nullptr), _left(nullptr), _right(nullptr), _col(RED) {}
T _data;
RBTreeNode* _parent;
RBTreeNode* _left;
RBTreeNode* _right;
Colour _col;
};
template<class K, class T,class KeyOfT>
class RBTree {
typedef RBTreeNode<T> Node;
public:
KeyOfT kot;
RBTree() = default;
RBTree(const RBTree& rbt) { _root = _copy(rbt._root); }
RBTree& operator=(RBTree tmp) { std::swap(_root, tmp._root); return *this; }
~RBTree() { _Destroy(_root); }
Node* Find(const K& key) {
Node* cur = _root;
while (cur) {
if (kot(cur->_data) < kot(key)) {
cur = cur->_right;
} else if (kot(cur->_data) > kot(key)) {
cur = cur->_left;
} else {
return cur;
}
}
return nullptr;
}
bool insert(const T& data) {
if (_root == nullptr) {
_root = new Node(data);
_root->_col = BLACK;
return true;
}
Node* parent = nullptr;
Node* cur = _root;
while (cur) {
if (kot(data) < kot(cur->_data))
{
parent = cur;
cur = parent->_left;
}
else if (kot(data) > kot(cur->_data))
{
parent = cur;
cur = parent->_right;
}
else {
return false;
}
}
cur = new Node(data);
cur->_col = RED;
if (kot(data) < kot(parent->_data))
parent->_left = cur;
else
parent->_right = cur;
cur->_parent = parent;
while (parent && parent->_col == RED) {
Node* grandfather = parent->_parent;
if (parent == grandfather->_right) {
Node* uncle = grandfather->_left;
if (uncle && uncle->_col == RED)
{
uncle->_col = parent->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_right) {
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else {
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
else
{
Node* uncle = grandfather->_right;
if (uncle && uncle->_col == RED)
{
uncle->_col = parent->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_left) {
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else {
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
_root->_col = BLACK;
return true;
}
private:
void RotateR(Node* parent) {
Node* subL = parent->_left;
Node* subLR = subL->_right;
Node* pParent = parent->_parent;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
subL->_right = parent;
parent->_parent = subL;
if (pParent == nullptr) {
_root = subL;
subL->_parent = nullptr;
}
else {
if (pParent->_left == parent) {
pParent->_left = subL;
}
else {
pParent->_right = subL;
}
subL->_parent = pParent;
}
}
void RotateL(Node* parent) {
Node* pParent = parent->_parent;
Node* subR = parent->_right;
Node* subRL = subR->_left;
subR->_left = parent;
parent->_parent = subR;
parent->_right = subRL;
if (subRL)
subRL->_parent = parent;
if (pParent == nullptr) {
_root = subR;
subR->_parent = nullptr;
}
else {
if (pParent->_left == parent) {
pParent->_left = subR;
}
else {
pParent->_right = subR;
}
subR->_parent = pParent;
}
}
Node* _copy(Node* root) {
if (root == nullptr) return nullptr;
Node* newNode = new Node(root->_data);
newNode->_left = _copy(root->_left);
newNode->_right = _copy(root->_right);
return newNode;
}
void _Destroy(Node* root) {
if (root == nullptr) return;
_Destroy(root->_left);
_Destroy(root->_right);
delete root;
}
private:
Node* _root = nullptr;
};
2、map 和 set 迭代器的实现
map 和 set 迭代器的实现其实思路与 list 迭代器的实现几乎一样,它们都是由一个一个的节点组成的。所以我们都是通过封装一个节点的指针,重载*、->、++、--、比较等运算符!
template<class T,class Ref,class Ptr>
struct RBTreeIterator {
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T,Ref,Ptr> Self;
Node* _node;
Node* _root;
RBTreeIterator(Node* node,Node* root) :_node(node), _root(root) {}
Self& operator++() { }
Self& operator--() { }
Ref operator*() { return _node->_data; }
Ptr operator->() { return &_node->_data; }
bool operator!=(const Self& s) const { return _node != s._node; }
bool operator==(const Self& s) const { return _node == s._node; }
};
- iterator 实现的⼤框架跟 list 的 iterator 思路是⼀致的,⽤⼀个类型封装结点的指针,再通过重载运算 符实现,迭代器像指针⼀样访问的⾏为。
- 这⾥的难点是 operator++和operator--的实现。之前使⽤部分,我们分析了,map 和 set 的迭代器⾛ 的是中序遍历,左⼦树->根结点->右⼦树,那么 begin()会返回中序第⼀个结点的 iterator 也就是最左结点的迭代器。
- 迭代器++的核⼼逻辑就是不看全局,只看局部,只考虑当前中序局部要访问的下⼀个结点。
- 迭代器++时,如果 it 指向的结点的右⼦树不为空,代表当前结点已经访问完了,要访问下⼀个结点 是右⼦树的中序第⼀个,⼀棵树中序第⼀个是最左结点,所以直接找右⼦树的最左结点即可。
- 迭代器++时,如果 it 指向的结点的右⼦树空,代表当前结点已经访问完了且当前结点所在的⼦树也 访问完了,要访问的下⼀个结点在当前结点的祖先⾥⾯,所以要沿着当前结点到根的祖先路径向上 找。
- 如果当前结点是⽗亲的左,根据中序左⼦树->根结点->右⼦树,那么下⼀个访问的结点就是当前结 点的⽗亲;
- 如果当前结点是⽗亲的右,根据中序左⼦树->根结点->右⼦树,当前当前结点所在的⼦树访问完 了,当前结点所在⽗亲的⼦树也访问完了,那么下⼀个访问的需要继续往根的祖先中去找,直到找 到孩⼦是⽗亲左的那个祖先就是中序要问题的下⼀个结点。
- end()如何表⽰呢?当 it 指向最右节点时,++it 时,一直往上找,找不到孩子是父亲的左的那个祖先,这时⽗亲为空了,那我们就把 it 中的结点指针 置为 nullptr,我们⽤ nullptr 去充当 end。需要注意的是 stl 源码空,红⿊树增加了⼀个哨兵位头结点 做为 end(),这哨兵位头结点和根互为⽗亲,左指向最左结点,右指向最右结点。相⽐我们⽤ nullptr 作为 end(),差别不⼤,他能实现的,我们也能实现。只是--end()判断到结点时空,特殊处 理⼀下,让迭代器结点指向最右结点。具体参考迭代器--实现。
- 迭代器--的实现跟++的思路完全类似,逻辑正好反过来即可,因为他访问顺序是右⼦树->根结点-> 左⼦树。
- set 的 iterator 也不⽀持修改,我们把 set 的第⼆个模板参数改成 const K 即可。
- map 的 iterator 不⽀持修改 key 但是可以修改 value,我们把 map 的第⼆个模板参数 pair 的第⼀个参 数改成 const K 即可。
Self& operator++() {
if (_node->_right) {
Node* leftMost = _node->_right;
while (leftMost->_left) {
leftMost = leftMost->_left;
}
_node = leftMost;
}
else {
Node* cur = _node;
Node* parent = cur->_parent;
while (parent && cur == parent->_right) {
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
因为我们用空来代表数的末尾,但我们--时,如果此时迭代器恰好在 end()位置,那我们就要单独考虑这种情况。其他逻辑就和++相反!
要找到树的最右节点,那么就还必须知道根节点_root!所以我们在红黑树中构造一个迭代器时,不仅要传递当前位置的节点,还要传根节点!
Self& operator--() {
if (_node == nullptr) {
Node* rightMost = _root;
while (rightMost&&rightMost->_right) {
rightMost = rightMost->_right;
}
_node = rightMost;
}
else if (_node->_left) {
Node* rightMost = _node->_left;
while (rightMost->_right) {
rightMost = rightMost->_right;
}
_node = rightMost;
}
else {
Node* cur = _node;
Node* parent = cur->_parent;
while (parent && cur == parent->_left) {
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T, T&, T*> Iterator;
typedef RBTreeIterator<T, const T&, const T*> ConstIterator;
Iterator Begin() {
Node* leftMost = _root;
while (leftMost && leftMost->_left) {
leftMost = leftMost->_left;
}
return Iterator(leftMost, _root);
}
Iterator End() {
return Iterator(nullptr, _root);
}
ConstIterator CBegin() const {
Node* leftMost = _root;
while (leftMost && leftMost->_left) {
leftMost = leftMost->_left;
}
return Iterator(leftMost, _root);
}
ConstIterator CEnd() const {
return Iterator(nullptr, _root);
}
3、map 支持 []
map 的 [] 重载主要是复用了底层红黑树当中的 insert 接口,insert 充当查找和插入的功能!
我们在使用 [] 时,如果 map 中有和我们外面传递的 key 相同的 key,那 insert 就会插入失败并且返回这个 key 的迭代器。如果不存在,那 insert 就会插入成功并且返回这个 key 的迭代器。所以我们要调整一下 insert 的返回值为 pair<iterator,bool>
而 map 的 [] 还支持修改(key 所映射的 val)的功能,insert 返回的迭代器在修改 val 时发挥作用!
V& operator[](const K& key) {
pair<iterator, bool> ret = _t.Insert(make_pair(key, V()));
return ret.first->second;
}
4、完整源码
set.h
#pragma once
#include"RBTree.h"
namespace my_set {
template<class K>
class set {
struct SetKeyOfT {
const K& operator()(const K& key) { return key; }
};
public:
typedef typename RBTree<K, const K, SetKeyOfT>::Iterator iterator;
typedef typename RBTree<K, const K, SetKeyOfT>::ConstIterator const_iterator;
iterator begin() { return _t.Begin(); }
iterator end() { return _t.End(); }
const_iterator cbegin() const { return _t.CBegin(); }
const_iterator cend() const { return _t.CEnd(); }
pair<iterator,bool> insert(const K& key) { return _t.Insert(key); }
iterator find(const K& key) { return _t.Find(key); }
private:
RBTree<K, const K, SetKeyOfT> _t;
};
}
map.h
#pragma once
#include"RBTree.h"
namespace my_map {
template<class K,class V>
class map {
struct MapKeyOfT {
const K& operator()(const pair<K,V>& kv) { return kv.first; }
};
public:
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::Iterator iterator;
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::ConstIterator const_iterator;
iterator begin() { return _t.Begin(); }
iterator end() { return _t.End(); }
const_iterator cbegin() const { return _t.CBegin(); }
const_iterator cend() const { return _t.CEnd(); }
pair<iterator, bool> insert(const pair<K,V>& kv) { return _t.Insert(kv); }
iterator find(const K& key) { return _t.Find(key); }
V& operator[](const K& key) {
pair<iterator, bool> ret = _t.Insert(make_pair(key, V()));
return ret.first->second;
}
private:
RBTree<K, pair<const K,V>, MapKeyOfT> _t;
};
}
RBTree.h
#pragma once
#include<iostream>
using namespace std;
enum Colour { RED, BLACK };
template<class T>
struct RBTreeNode {
RBTreeNode(const T& data) :_data(data), _parent(nullptr), _left(nullptr), _right(nullptr), _col(RED) {}
T _data;
RBTreeNode* _parent;
RBTreeNode* _left;
RBTreeNode* _right;
Colour _col;
};
template<class T,class Ref,class Ptr>
struct RBTreeIterator {
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T,Ref,Ptr> Self;
Node* _node;
Node* _root;
RBTreeIterator(Node* node,Node* root) :_node(node), _root(root) {}
Self& operator++() {
if (_node->_right) {
Node* leftMost = _node->_right;
while (leftMost->_left) {
leftMost = leftMost->_left;
}
_node = leftMost;
}
else {
Node* cur = _node;
Node* parent = cur->_parent;
while (parent && cur == parent->_right) {
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
Self& operator--() {
if (_node == nullptr) {
Node* rightMost = _root;
while (rightMost&&rightMost->_right) {
rightMost = rightMost->_right;
}
_node = rightMost;
}
else if (_node->_left) {
Node* rightMost = _node->_left;
while (rightMost->_right) {
rightMost = rightMost->_right;
}
_node = rightMost;
}
else {
Node* cur = _node;
Node* parent = cur->_parent;
while (parent && cur == parent->_left) {
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
Ref operator*() { return _node->_data; }
Ptr operator->() { return &_node->_data; }
bool operator!=(const Self& s) const { return _node != s._node; }
bool operator==(const Self& s) const { return _node == s._node; }
};
template<class K, class T,class KeyOfT>
class RBTree {
public:
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T, T&, T*> Iterator;
typedef RBTreeIterator<T, const T&, const T*> ConstIterator;
Iterator Begin() {
Node* leftMost = _root;
while (leftMost && leftMost->_left) {
leftMost = leftMost->_left;
}
return Iterator(leftMost, _root);
}
Iterator End() {
return Iterator(nullptr, _root);
}
ConstIterator CBegin() const {
Node* leftMost = _root;
while (leftMost && leftMost->_left) {
leftMost = leftMost->_left;
}
return Iterator(leftMost, _root);
}
ConstIterator CEnd() const {
return Iterator(nullptr, _root);
}
KeyOfT kot;
RBTree() = default;
RBTree(const RBTree& rbt) { _root = _copy(rbt._root); }
RBTree& operator=(RBTree tmp) { std::swap(_root, tmp._root); return *this; }
~RBTree() { _Destroy(_root); }
Node* Find(const K& key) {
Node* cur = _root;
while (cur) {
if (kot(cur->_data) < kot(key)) {
cur = cur->_right;
}
else if (kot(cur->_data) > kot(key)) {
cur = cur->_left;
}
else {
return cur;
}
}
return nullptr;
}
pair<Iterator,bool> Insert(const T& data) {
if (_root == nullptr) {
_root = new Node(data);
_root->_col = BLACK;
return make_pair(Iterator(_root,_root),true);
}
Node* parent = nullptr;
Node* cur = _root;
while (cur) {
if (kot(data) < kot(cur->_data))
{
parent = cur;
cur = parent->_left;
}
else if (kot(data) > kot(cur->_data))
{
parent = cur;
cur = parent->_right;
}
else {
return make_pair(Iterator(cur, _root), false);
}
}
cur = new Node(data);
Node* newnode = cur;
cur->_col = RED;
if (kot(data) < kot(parent->_data))
parent->_left = cur;
else
parent->_right = cur;
cur->_parent = parent;
while (parent && parent->_col == RED) {
Node* grandfather = parent->_parent;
if (parent == grandfather->_right) {
Node* uncle = grandfather->_left;
if (uncle && uncle->_col == RED)
{
uncle->_col = parent->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_right) {
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else {
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
else
{
Node* uncle = grandfather->_right;
if (uncle && uncle->_col == RED)
{
uncle->_col = parent->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_left) {
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else {
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
_root->_col = BLACK;
return make_pair(Iterator(newnode, _root), true);
}
private:
void RotateR(Node* parent) {
Node* subL = parent->_left;
Node* subLR = subL->_right;
Node* pParent = parent->_parent;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
subL->_right = parent;
parent->_parent = subL;
if (pParent == nullptr) {
_root = subL;
subL->_parent = nullptr;
}
else {
if (pParent->_left == parent) {
pParent->_left = subL;
}
else {
pParent->_right = subL;
}
subL->_parent = pParent;
}
}
void RotateL(Node* parent) {
Node* pParent = parent->_parent;
Node* subR = parent->_right;
Node* subRL = subR->_left;
subR->_left = parent;
parent->_parent = subR;
parent->_right = subRL;
if (subRL)
subRL->_parent = parent;
if (pParent == nullptr) {
_root = subR;
subR->_parent = nullptr;
}
else {
if (pParent->_left == parent) {
pParent->_left = subR;
}
else {
pParent->_right = subR;
}
subR->_parent = pParent;
}
}
Node* _copy(Node* root) {
if (root == nullptr) return nullptr;
Node* newNode = new Node(root->_data);
newNode->_left = _copy(root->_left);
newNode->_right = _copy(root->_right);
return newNode;
}
void _Destroy(Node* root) {
if (root == nullptr) return;
_Destroy(root->_left);
_Destroy(root->_right);
delete root;
}
private:
Node* _root = nullptr;
};
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
- Markdown转HTML
将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML转Markdown 互为补充。 在线工具,Markdown转HTML在线工具,online
- HTML转Markdown
将 HTML 片段转为 GitHub Flavored Markdown,支持标题、列表、链接、代码块与表格等;浏览器内处理,可链接预填。 在线工具,HTML转Markdown在线工具,online