跳到主要内容
C++ 红黑树封装笔记:从泛型复用看 Map 与 Set 的实现 | 极客日志
C++ 算法
C++ 红黑树封装笔记:从泛型复用看 Map 与 Set 的实现 通过分析STL源码,展示了如何用模板和仿函数让同一份红黑树代码同时支持map和set——用不同Value类型搭配KeyOfT提取键值。重点讨论了迭代器中序遍历中operator++的回溯逻辑,以及插入平衡后如何利用返回值实现operator[]。最后给出完整可运行代码。
参考文档,用的时候翻翻就行:
set 和 multiset:set 、multiset
map 和 multimap:map 、multimap
我自己封装红黑树时发现,平衡逻辑本身并不难,难的是把结构捋清楚——尤其是一棵红黑树怎么同时支撑 set 和 map。翻一版老一点的 STL 源码(比如 SGI STL 30,太新的版本改动太多反而不易读)能省很多力气。重点看 stl_tree.h、stl_map.h、stl_set.h 这几个文件就好。
泛型设计的关键:一个 Value 类型承载两种语义
rb_tree 的模板参数通常是这样:
template <class Key , class Value , class KeyOfValue , ...>
class rb_tree ;
第一个参数 Key 单纯给 find、erase 这类接口用,第二个参数 Value 才是节点真正存的东西。看名字会有点绕:源码里管这个 Value 叫 value_type,但它的含义是'节点内部数据的类型',未必是你平时说的 value。
对于 set,value_type 就是 Key,节点直接存键。
对于 map,value_type 是 pair<const Key, T>,节点存的是键值对。
这样,同一份红黑树代码既能处理单一的键,也能处理键值对,区别只在于实例化时传的 Value 不同。后面要比较键的时候,就需要通过一个仿函数 KeyOfValue 从 Value 中把 Key 抠出来。这就是 MapKeyOfT / SetKeyOfT 存在的意义——向上层屏蔽掉 Value 的内部结构。
{
{
kv.first;
}
};
{
{
key;
}
};
struct
MapKeyOfT
const K& operator () (const pair<K, V>& kv)
return
struct
SetKeyOfT
const K& operator () (const K& key)
return
红黑树内部但凡需要比较键的地方,一律用 KeyOfValue 而不是直接操作 Value,这样两种容器就能复用同一套插入、查找逻辑了。
迭代器:用中序遍历逻辑重载 operator++ / -- 迭代器本质上就是封装一个节点指针,再重载 *、->、==、!= 以及 ++、--。难点在 ++ 和 --,因为它们要模拟中序遍历的移动顺序。
中序顺序是'左子树→根→右子树'。对一个当前节点,++ 要找到下一个访问的节点:
右子树非空 :下一个节点就是右子树中最靠左的那个节点(右子树的中序第一个)。
右子树为空 :说明以当前节点为根的子树已经访问完了。沿 parent 指针回溯,找第一个是父亲左孩子的祖先,这个祖先的父亲就是下一个节点。如果一直回到根都没有满足条件的,说明整棵树都访问完了,返回 end()。
end() 用 nullptr 表示。--end() 时需要取出整棵树的最右节点,逻辑正好和上面相反。
对于 set,迭代器不应允许修改元素,因此 Ref 和 Ptr 都用 const K 相关类型;map 的 key 不可改但 value 可改,所以用 pair<const K, V>。这些通过模板参数控制就行。
红黑树的核心操作:插入 红黑树的插入平衡逻辑大致就是分 uncle 红/黑两类情况,旋转和变色大家都熟,不再啰嗦。唯一要注意的是插入成功时要返回 pair<Iterator, bool>,方便 map 的 operator[] 使用。
operator[] 的实现有了 Insert 返回迭代器,operator[] 就很简单:
V& operator [](const K& key) {
auto [it, flag] = _t .Insert ({ key, V () });
return it->second;
}
先尝试插入一个 key 对应的默认值,如果 key 已存在则 Insert 返回已有节点,无论哪种情况都能拿到迭代器,然后返回其 second 的引用。这样既支持插入新元素,也能修改已有元素的值。
完整代码 下面是封装的完整代码。为了可读性,我把红黑树节点定义省略了,只给出迭代器和红黑树模板,以及 map、set 的部分。
红黑树及迭代器 #include <utility>
using namespace std;
template <class T , class Ref , class Ptr >
struct RBTreeIterator
{
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T, Ref, Ptr> Self;
Node* _node;
RBTreeIterator (Node* node) :_node(node) { }
Ref operator *() { return _node->_data; }
Ptr operator ->() { return &_node->_data; }
Self& operator ++()
{
if (_node->_right)
{
Node* minRight = _node->_right;
while (minRight->_left)
{
minRight = minRight->_left;
}
_node = minRight;
}
else
{
Node* cur = _node;
Node* parent = cur->_parent;
while (parent && cur == parent->_right)
{
cur = parent;
parent = parent->_parent;
}
_node = parent;
}
return *this ;
}
bool operator !=(const Self& s) { return _node != s._node; }
bool operator ==(const Self& s) { return _node == s._node; }
};
template <class K , class T , class KeyOfT >
struct RBTree
{
typedef RBTreeNode<T> Node;
public :
typedef RBTreeIterator<T, T&, T*> Iterator;
typedef RBTreeIterator<T, const T&, const T*> ConstIterator;
~RBTree ()
{
Destroy (_root);
_root = nullptr ;
}
Iterator Begin ()
{
Node* minLeft = _root;
while (minLeft && minLeft->_left)
{
minLeft = minLeft->_left;
}
return Iterator (minLeft);
}
Iterator End () { return Iterator (nullptr ); }
ConstIterator Begin () const
{
Node* minLeft = _root;
while (minLeft && minLeft->_left)
{
minLeft = minLeft->_left;
}
return ConstIterator (minLeft);
}
ConstIterator End () const { return ConstIterator (nullptr ); }
pair<Iterator, bool > Insert (const T& data)
{
if (_root == nullptr )
{
_root = new Node (data);
_root->_col = BLACK;
return { Iterator (_root), true };
}
KeyOfT kot;
Node* parent = nullptr ;
Node* cur = _root;
while (cur)
{
if (kot (cur->_data) < kot (data))
{
parent = cur;
cur = cur->_right;
}
else if (kot (data) < kot (cur->_data))
{
parent = cur;
cur = cur->_left;
}
else
{
return { Iterator (cur), false };
}
}
cur = new Node (data);
Node* newnode = cur;
cur->_col = RED;
if (kot (parent->_data) < kot (data))
{
parent->_right = cur;
}
else
{
parent->_left = cur;
}
cur->_parent = parent;
while (parent && parent->_col == RED)
{
Node* grandfather = parent->_parent;
if (grandfather->_left == parent)
{
Node* uncle = grandfather->_right;
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_left)
{
RotateR (grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateL (parent);
RotateR (grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break ;
}
}
else
{
Node* uncle = grandfather->_left;
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_right)
{
RotateL (grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateR (parent);
RotateL (grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break ;
}
}
}
_root->_col = BLACK;
return { Iterator (newnode), true };
}
void RotateR (Node* parent)
{
Node* subL = parent->_left;
Node* subLR = subL->_right;
parent->_left = subLR;
if (subLR) subLR->_parent = parent;
Node* parentParent = parent->_parent;
subL->_right = parent;
parent->_parent = subL;
if (parent == _root)
{
_root = subL;
subL->_parent = nullptr ;
}
else
{
if (parentParent->_left == parent)
{
parentParent->_left = subL;
}
else
{
parentParent->_right = subL;
}
subL->_parent = parentParent;
}
}
void RotateL (Node* parent)
{
Node* subR = parent->_right;
Node* subRL = subR->_left;
parent->_right = subRL;
if (subRL) subRL->_parent = parent;
Node* parentParent = parent->_parent;
subR->_left = parent;
parent->_parent = subR;
if (parentParent == nullptr )
{
_root = subR;
subR->_parent = nullptr ;
}
else
{
if (parent == parentParent->_left)
{
parentParent->_left = subR;
}
else
{
parentParent->_right = subR;
}
subR->_parent = parentParent;
}
}
Iterator Find (const K& key)
{
KeyOfT kot;
Node* cur = _root;
while (cur)
{
if (kot (cur->_data) < key)
{
cur = cur->_right;
}
else if (kot (cur->_data) > key)
{
cur = cur->_left;
}
else
{
return Iterator (cur);
}
}
return End ();
}
int Height () { return _Height(_root); }
int Size () { return _Size(_root); }
private :
int _Size(Node* root)
{
if (root == nullptr ) return 0 ;
return _Size(root->_left) + _Size(root->_right) + 1 ;
}
int _Height(Node* root)
{
if (root == nullptr ) return 0 ;
int leftHeight = _Height(root->_left);
int rightHeight = _Height(root->_right);
return leftHeight > rightHeight ? leftHeight + 1 : rightHeight + 1 ;
}
void Destroy (Node* root)
{
if (root == nullptr ) return ;
Destroy (root->_left);
Destroy (root->_right);
delete root;
}
private :
Node* _root = nullptr ;
};
Map #pragma once
#include "RBTree.h"
namespace jqj
{
template <class K , class V >
class map
{
struct MapKeyOfT
{
const K& operator () (const pair<K, V>& kv)
{
return kv.first;
}
};
public :
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::Iterator iterator;
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::ConstIterator const_iterator;
iterator begin () { return _t .Begin (); }
iterator end () { return _t .End (); }
const_iterator begin () const { return _t .Begin (); }
const_iterator end () const { return _t .End (); }
pair<iterator, bool > insert (const pair<K, V>& kv)
{
return _t .Insert (kv);
}
iterator find (const K& key) { return _t .Find (key); }
V& operator [](const K& key)
{
auto [it, flag] = _t .Insert ({ key, V () });
return it->second;
}
private :
RBTree<K, pair<const K, V>, MapKeyOfT> _t ;
};
}
Set #pragma once
#include "RBTree.h"
namespace jqj
{
template <class K >
class set
{
struct SetKeyOfT
{
const K& operator () (const K& key)
{
return key;
}
};
public :
typedef typename RBTree<K, const K, SetKeyOfT>::Iterator iterator;
typedef typename RBTree<K, const K, SetKeyOfT>::ConstIterator const_iterator;
iterator begin () { return _t .Begin (); }
iterator end () { return _t .End (); }
const_iterator begin () const { return _t .Begin (); }
const_iterator end () const { return _t .End (); }
pair<iterator, bool > insert (const K& key)
{
return _t .Insert (key);
}
iterator find (const K& key) { return _t .Find (key); }
private :
RBTree<K, const K, SetKeyOfT> _t ;
};
}
测试 #define _CRT_SECURE_NO_WARNINGS 1
#include <iostream>
#include <vector>
using namespace std;
#include "RBTree.h"
#include "Map.h"
#include "Set.h"
template <class T>
void func (const jqj::set<T>& s)
{
typename jqj::set<T>::const_iterator it = s.begin ();
while (it != s.end ())
{
cout << *it << " " ;
++it;
}
cout << endl;
}
void Test_set ()
{
jqj::set<int > s;
s.insert (1 ); s.insert (2 ); s.insert (1 ); s.insert (5 );
s.insert (0 ); s.insert (10 ); s.insert (8 );
jqj::set<int >::iterator it = s.begin ();
while (it != s.end ())
{
cout << *it << " " ;
++it;
}
cout << endl;
func (s);
}
void Test_map ()
{
jqj::map<string, string> dict;
dict.insert ({ "sort" , "排序" });
dict.insert ({ "left" , "左边" });
dict.insert ({ "right" , "右边" });
dict["string" ] = "字符串" ;
dict["left" ] = "左边 xxx" ;
auto it = dict.begin ();
while (it != dict.end ())
{
it->second += 'x' ;
cout << it->first << ":" << it->second << endl;
++it;
}
cout << endl;
for (auto & [k, v] : dict)
{
cout << k << ":" << v << endl;
}
cout << endl;
string arr[] = { "苹果" , "西瓜" , "苹果" , "西瓜" , "苹果" , "苹果" , "西瓜" ,
"苹果" , "香蕉" , "苹果" , "香蕉" };
jqj::map<string, int > countMap;
for (auto & e : arr)
{
countMap[e]++;
}
for (auto & [k, v] : countMap)
{
cout << k << ":" << v << endl;
}
cout << endl;
}
int main ()
{
Test_set ();
Test_map ();
return 0 ;
}
运行结果符合预期:set 输出有序且去重,map 正确统计了词频,operator[] 也能修改已有值。至此,一套基于红黑树的 map 和 set 就可以跑起来了。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
Markdown转HTML 将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML转Markdown 互为补充。 在线工具,Markdown转HTML在线工具,online
HTML转Markdown 将 HTML 片段转为 GitHub Flavored Markdown,支持标题、列表、链接、代码块与表格等;浏览器内处理,可链接预填。 在线工具,HTML转Markdown在线工具,online