C++ 基于红黑树封装 map 与 set 详解
在完成了红黑树的基础实现后,接下来需要处理一些关键细节,主要是 const 迭代器的控制以及 map 特有的 operator[] 实现。这部分内容往往容易踩坑,尤其是涉及模板参数和类型转换时。
一、红黑树封装中的 const 迭代器
1. 核心思路
要实现 const 迭代器,关键在于通过模板参数控制 operator* 和 operator-> 的返回值类型。普通迭代器返回引用或指针,允许修改;const 迭代器则返回 const 引用或指针,禁止修改。
我们需要在迭代器结构体中定义三个模板参数:节点类型、指针类型、引用类型。通过传入不同的类型实例化,就能区分普通迭代器和 const 迭代器。
2. Set 的实现
Set 容器存储的是 key,且 key 本身不可变。因此,无论是普通迭代器还是 const 迭代器,其底层都应该是红黑树的 const 迭代器。这样无论怎么访问,都无法修改 key 的值,符合 Set 的设计语义。
在代码层面,我们将 iterator typedef 为红黑树的 const_iterator。注意这里需要使用 typename 关键字,因为它是依赖类型。
3. Map 的特殊处理
Map 存储的是 pair<key, value>。与 Set 不同,Map 的 key 不可变,但 value 是可变的。如果直接将 Map 的普通迭代器也设为 const 迭代器,那么 value 也无法修改了,这不符合需求。
解决方案是在 Map 内部将存储类型定义为 pair<const K, V>。这样 key 被强制为 const,而 value 保持可变。迭代器逻辑可以复用红黑树的普通迭代器,自然就能满足 key 只读、value 可写的要求。
二、operator[] 的实现
Map 需要提供 operator[] 来支持通过下标访问和插入。它的行为逻辑是:
- 如果 key 存在,返回对应的 value 引用。
- 如果 key 不存在,插入一个默认构造的 value,并返回该 value 引用。
这个功能必须借助 insert 方法来实现。为了配合 operator[],insert 方法的返回值最好是一个 pair<iterator, bool>,其中 iterator 指向新插入或已存在的元素,bool 表示是否插入成功。
1. Insert 的返回值调整
之前的 insert 可能只返回 void 或者简单的状态,现在需要调整为返回 pair。同时,由于 Map 和 Set 对迭代器的定义不同,insert 在 Set 中可能会遇到编译问题。
2. Set 中的编译问题
当我们在 Set 中调用 insert 并尝试获取返回的 iterator 时,会发现编译报错。这是因为 Set 的 iterator 别名实际上是 const_iterator,而 insert 返回的是红黑树内部的普通迭代器(用于内部操作)。
解决思路是接收红黑树内部的迭代器,然后将其转换为 Set 定义的 iterator(即 const_iterator)。这需要提供一个构造函数,允许用普通迭代器初始化 const 迭代器。
// Set 中的 insert 实现
pair<iterator, bool> insert(const K& key) {
// 先调用底层红黑树的 insert
pair<typename RBTree<K, K, SetKeyOfT>::iterator, bool> ret = _t.Insert(key);
// 将底层迭代器转换为 Set 的 iterator (const_iterator)
return pair<iterator, bool>(ret.first, ret.second);
}
同时需要在迭代器类中添加接受普通迭代器的构造函数,以完成隐式转换。
三、Map 的 operator[] 实现
有了上述基础,Map 的 operator[] 实现就很简单了:
V& operator[](const K& key) {
// 尝试插入 key-value,value 默值构造
pair<iterator, bool> ret = insert(make_pair(key, V()));
// 返回 value 的引用
return ret.first->second;
}
测试代码验证了插入和修改的功能均正常。对于 Set,我们同样验证了迭代器遍历和插入功能,确保 key 无法被修改。
四、完整源码参考
以下是整理后的核心代码结构,包含红黑树、Set 和 Map 的封装。
RBTree.h
#pragma once
#include <iostream>
using namespace std;
enum Colour { RED, BLACK };
template<class T>
struct RBTreeNode {
RBTreeNode<T>* _left;
RBTreeNode<T>* _right;
RBTreeNode<T>* _parent;
Colour _col;
T _data;
RBTreeNode(const T& data) : _left(nullptr), _right(nullptr), _parent(nullptr), _col(RED), _data(data) {}
};
template<class T, class Ptr, class Ref>
struct __TreeIterator {
typedef RBTreeNode<T> Node;
typedef __TreeIterator<T, Ptr, Ref> Self;
// 解决 insert 中 set 的问题,提供从普通迭代器到 const 迭代器的转换
typedef __TreeIterator<T, T*, T&> iterator;
__TreeIterator(const iterator& it) : _node(it._node) {};
Node* _node;
__TreeIterator(Node* node) : _node(node) {}
Ref operator*() { return _node->_data; }
Ptr operator->() { return &(_node->_data); }
bool operator!=(const Self& s) { return _node != s._node; }
bool operator==(const Self& s) const { return _node == s._node; }
Self& operator++() {
if (_node->_right != nullptr) {
Node* curleft = _node->_right;
while (curleft->_left) curleft = curleft->_left;
_node = curleft;
} else {
Node* cur = _node;
Node* parent = _node->_parent;
while (parent) {
if (parent->_left == cur) break;
else { cur = parent; parent = parent->_parent; }
}
_node = parent;
}
return *this;
}
Self& operator--() {
if (_node->_left) {
Node* subRight = _node->_left;
while (subRight->_right) subRight = subRight->_right;
_node = subRight;
} else {
Node* cur = _node;
Node* parent = cur->_parent;
while (parent && cur == parent->_left) {
cur = cur->_parent;
parent = parent->_parent;
}
_node = parent;
}
return *this;
}
};
template<class K, class T, class KeyOfT>
class RBTree {
typedef RBTreeNode<T> Node;
public:
typedef __TreeIterator<T, T*, T&> iterator;
typedef __TreeIterator<T, const T*, const T&> const_iterator;
iterator begin() {
Node* leftMin = _root;
while (leftMin && leftMin->_left) leftMin = leftMin->_left;
return iterator(leftMin);
}
iterator end() { return iterator(nullptr); }
const_iterator begin() const {
Node* leftMin = _root;
while (leftMin && leftMin->_left) leftMin = leftMin->_left;
return const_iterator(leftMin);
}
const_iterator end() const { return const_iterator(nullptr); }
Node* Find(const K& key) {
Node* cur = _root;
KeyOfT kot;
while (cur) {
if (kot(cur->_data) < key) cur = cur->_right;
else if (kot(cur->_data) > key) cur = cur->_left;
else return cur;
}
return nullptr;
}
pair<iterator, bool> Insert(const T& data) {
if (_root == nullptr) {
_root = new Node(data);
_root->_col = BLACK;
return make_pair(iterator(_root), true);
}
Node* cur = _root;
Node* parent = nullptr;
KeyOfT kot;
while (cur) {
if (kot(data) < kot(cur->_data)) {
parent = cur; cur = cur->_left;
} else if (kot(data) > kot(cur->_data)) {
parent = cur; cur = cur->_right;
} else {
return make_pair(iterator(cur), false);
}
}
cur = new Node(data);
cur->_col = RED;
Node* newnode = cur;
if (kot(cur->_data) < kot(parent->_data)) parent->_left = cur;
else parent->_right = cur;
cur->_parent = parent;
while (parent && parent->_col == RED) {
Node* grandfather = parent->_parent;
if (parent == grandfather->_left) {
Node* uncle = grandfather->_right;
if (uncle && uncle->_col == RED) {
parent->_col = BLACK;
uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
} else {
if (cur == parent->_left) {
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
} else {
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
} else {
Node* uncle = grandfather->_left;
if (uncle && uncle->_col == RED) {
parent->_col = BLACK;
uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
} else {
if (cur == parent->_right) {
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
} else {
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
_root->_col = BLACK;
return make_pair(iterator(newnode), true);
}
void RotateL(Node* parent) {
Node* cur = parent->_right;
Node* curleft = cur->_left;
parent->_right = curleft;
if (curleft) curleft->_parent = parent;
cur->_left = parent;
Node* ppnode = parent->_parent;
parent->_parent = cur;
if (ppnode == nullptr) {
_root = cur;
cur->_parent = nullptr;
} else {
if (ppnode->_left == parent) ppnode->_left = cur;
else ppnode->_right = cur;
cur->_parent = ppnode;
}
}
void RotateR(Node* parent) {
Node* cur = parent->_left;
Node* curright = cur->_right;
parent->_left = curright;
if (curright) curright->_parent = parent;
cur->_right = parent;
Node* ppnode = parent->_parent;
parent->_parent = cur;
if (ppnode == nullptr) {
cur->_parent = nullptr;
_root = cur;
} else {
if (ppnode->_left == parent) ppnode->_left = cur;
else ppnode->_right = cur;
cur->_parent = ppnode;
}
}
bool CheckColour(Node* root, int blacknum, int benchmark) {
if (root == nullptr) {
if (blacknum != benchmark) return false;
return true;
}
if (root->_col == BLACK) ++blacknum;
if (root->_col == RED && root->_parent && root->_parent->_col == RED) {
cout << "出现连续红色节点" << endl;
return false;
}
return CheckColour(root->_left, blacknum, benchmark) && CheckColour(root->_right, blacknum, benchmark);
}
bool IsBalance() { return IsBalance(_root); }
bool IsBalance(Node* root) {
if (root == nullptr) return true;
if (root->_col != BLACK) return false;
int benchmark = 0;
Node* cur = _root;
while (cur) {
if (cur->_col == BLACK) ++benchmark;
cur = cur->_left;
}
return CheckColour(root, 0, benchmark);
}
private:
Node* _root = nullptr;
};
myset.h
#pragma once
#include "RBTree.h"
namespace jyf {
template<class K>
class set {
struct SetKeyOfT {
const K& operator()(const K& key) { return key; }
};
public:
typedef typename RBTree<K, K, SetKeyOfT>::const_iterator iterator;
typedef typename RBTree<K, K, SetKeyOfT>::const_iterator const_iterator;
iterator begin() const { return _t.begin(); }
iterator end() const { return _t.end(); }
pair<iterator, bool> insert(const K& key) {
pair<typename RBTree<K, K, SetKeyOfT>::iterator, bool> ret = _t.Insert(key);
return pair<iterator, bool>(ret.first, ret.second);
}
private:
RBTree<K, K, SetKeyOfT> _t;
};
}
mymap.h
#pragma once
#include "RBTree.h"
namespace jyf {
template<class K, class V>
class map {
struct MapKeyOfT {
const K& operator()(const pair<const K, V>& kv) { return kv.first; }
};
public:
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::iterator iterator;
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::const_iterator const_iterator;
iterator begin() { return _t.begin(); }
iterator end() { return _t.end(); }
const_iterator begin() const { return _t.begin(); }
const_iterator end() const { return _t.end(); }
V& operator[](const K& key) {
pair<iterator, bool> ret = insert(make_pair(key, V()));
return ret.first->second;
}
pair<iterator, bool> insert(const pair<K, V>& kv) {
return _t.Insert(kv);
}
private:
RBTree<K, pair<const K, V>, MapKeyOfT> _t;
};
}
test.cpp
#include <iostream>
#include <string>
using namespace std;
#include "mymap.h"
#include "myset.h"
int main() {
jyf::map<int, int> m;
m.insert(make_pair(1, 1));
m.insert(make_pair(3, 3));
m.insert(make_pair(2, 2));
auto mit = m.begin();
while (mit != m.end()) {
// 不能修改 key,可以修改 value
// mit->first = 1; mit->second = 2;
cout << mit->first << ":" << mit->second << endl;
++mit;
}
cout << endl;
for (const auto& kv : m) {
cout << kv.first << ":" << kv.second << endl;
}
cout << endl;
jyf::set<int> s;
s.insert(5); s.insert(2); s.insert(2);
s.insert(12); s.insert(22); s.insert(332); s.insert(7);
jyf::set<int>::iterator it = s.begin();
while (it != s.end()) {
// 不应该允许修改 key
// if (*it % 2 == 0) { *it += 10; }
cout << *it << " ";
++it;
}
cout << endl;
jyf::map<string, string> dict;
dict.insert(make_pair("sort", "xxx"));
dict["left"]; // 插入
for (const auto& kv : dict) {
cout << kv.first << ":" << kv.second << endl;
}
cout << endl;
dict["left"] = "左边"; // 修改
dict["sort"] = "排序"; // 修改
dict["right"] = "右边"; // 插入 + 修改
for (const auto& kv : dict) {
cout << kv.first << ":" << kv.second << endl;
}
cout << endl;
return 0;
}
以上便是红黑树封装 map 与 set 的核心逻辑与完整代码。理解这些细节对于掌握 STL 容器的底层原理非常有帮助。


