本文的 C++ 代码基于 C++ 20 标准(不包含 C++ modules),对于之前的标准,可能需要做一些适配。
CART 分类树和回归树的内容各自在一个类中,分类树为 CartClassifier 类,回归树为 CartRegression 类。
数据结构设计
二叉树设计
// 二叉树结点
struct BinTreeNode {
std::string threshold_str_;
double threshold_ = -1;
std::string feature_name_;
std::shared_ptr<BinTreeNode> left_ = nullptr;
std::shared_ptr<BinTreeNode> right_ = nullptr;
[[nodiscard]] std::shared_ptr<BinTreeNode> copy() const {
auto node = std::make_shared<BinTreeNode>();
node->threshold_ = threshold_;
node->threshold_str_ = threshold_str_;
node->feature_name_ = feature_name_;
if(left_) node->left_ = left_->copy();
if(right_) node->right_ = right_->copy();
return node;
}
};
copy 模块用于二叉树结点的深复制,包括复制本身及其所有的子结点。
结点信息设计
struct Info {
std::shared_ptr<BinTreeNode> tree_;
size_t num_leaf_ = 0;
double a = 0;
std::pair<bool, std::string> key_str_{};
std::pair<bool, double> key_{};
};
实际上结点信息可以直接存储到二叉树结点 BinTreeNode 中。分开是为了保证代码的语义清晰,易于理解。
分类树
训练
/**
* @brief
* 训练决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param feature_names 属性名
* @return 生成的决策树
*/
shared_ptr<BinTreeNode> CartClassifier::train(const vector<vector<string>>& X,
const vector<string>& y,
const vector<string>& feature_names) {
feature_names_ = feature_names;
// 创建 CART 决策树
tree_ = create_tree(X, y);
return tree_;
}
训练函数通过传递常量引用形参,防止训练集和属性集被篡改。如需要修改,可以在函数内部设置副本,针对副本进行修改。create_tree 是创建 CART 分类树的核心函数。
/**
* @brief
* 创建树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 训练好的决策树
*/
shared_ptr<BinTreeNode> CartClassifier::create_tree(const vector<vector<string>>& X,
const vector<string>& y) {
// 若 X 中样本全属于同一类别 C,则停止划分
auto tree = make_shared<BinTreeNode>();
if (unordered_set(y.begin(), y.end()).size() == 1) {
tree->threshold_str_ = y.front();
return tree;
}
// 若节点样本数小于 min_samples_split,或者属性集上的取值均相同
if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1) {
tree->threshold_str_ = majority_y(y);
return tree;
}
// 按照'基尼增益',从属性值中选择最优分裂属性的最优切分点
auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
const string best_feature_name = feature_names_[best_feature_index];
// 根据最优切分点,进行子树的划分
vector<vector<string>> sub_X1, sub_X2;
vector<string> sub_y1, sub_y2;
for (int i = 0; i < X.size(); i++)
if (X[i][best_feature_index] == best_split_point) {
sub_X1.emplace_back(X[i]);
sub_y1.emplace_back(y[i]);
} else {
sub_X2.(X[i]);
sub_y(y[i]);
}
tree->feature_name_ = best_feature_name;
tree->threshold_str_ = best_split_point;
tree->left_ = (sub_X1, sub_y1);
tree->right_ = (sub_X2, sub_y2);
tree;
}
create_tree 函数是一个递归创建决策树的过程。首先判断三种递归中止条件:
X中样本全部属于同一类别;- 当前节点样本数小于
min_samples_split_; - 属性集上的取值均相同
若满足终止条件,则选择 y 中最多的类别作为结果返回。若未满足终止条件,依次执行以下步骤:
- 根据基尼指数从属性值中选择最优分裂属性的最优切分点,具体过程如
choose_best_point_to_split函数所示; - 根据最优切分点对子树进行划分;
- 对于其子树再继续执行
create_tree函数完成划分过程。
/**
* @brief
* 统计每个类别出现的次数,返回出现次数最大的类别 ID
* @param y 目标变量集合
* @return 出现次数最大的类别
*/
string CartClassifier::majority_y(const vector<string>& y) {
// 统计 y 中的目标变量值的个数
unordered_map<string, int> y_count;
for (const string& v : y) {
if (!y_count.contains(v)) y_count[v] = 0;
++y_count[v];
}
return ranges::max_element(y_count, [](const pair<string, int>& a, const pair<string, int>& b) {
return a.second < b.second;
})->first;
}
majority_y 用于计算节点中出现次数最多的类别。包含以下步骤:
- 初始化一个空映射;
- 遍历
y并对其元素进行计数; - 从映射中查找出现次数最多的类别。
/**
* @brief
* 选择最优切分点
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 最优切分点和最优切分点所在属性的索引
*/
pair<string, int> CartClassifier::choose_best_point_to_split(const vector<vector<string>>& X,
const vector<string>& y) {
string best_split_point;
int best_feature_index = -1;
double best_gini_index = numeric_limits<double>::infinity();
const size_t num_feature = X[0].size();
// 属性的个数
for (int i = 0; i < num_feature; i++) // 遍历每个属性
{
// 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
unordered_set<string> split_points;
for (const vector<string>& x : X) split_points.emplace(x[i]);
for (const string& split_point : split_points) // 计算各个候选切分点的基尼不纯度
{
vector<string> sub_y_left, sub_y_right;
for (int j = 0; j < X.size(); j++)
if (X[j][i] == split_point) sub_y_left.emplace_back(y[j]);
else sub_y_right.emplace_back(y[j]);
// 计算左子树的基尼不纯度
const double gini_impurity_left = cal_gini_impurity(sub_y_left);
// 计算右子树的基尼不纯度
const double gini_impurity_right = (sub_y_right);
pro_left = <>(sub_y_left.()) / <>(y.()),
pro_right = <>(sub_y_right.()) / <>(y.());
( gini_index = (pro_left, pro_right, gini_impurity_left, gini_impurity_right);
best_gini_index > gini_index)
{
best_gini_index = gini_index;
best_feature_index = i;
best_split_point = split_point;
}
}
}
{best_split_point, best_feature_index};
}
choose_best_point_to_split 是 CART 分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。该函数遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度 G(left) 和 G(right),以及左右子树中数据样本在总样本中占的比例 P(left) 和 P(right),并且将 G(left), G(right), P(left), P(right) 代入 cal_gini_index 函数中计算基尼指数。最后选出具有最小基尼指数的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。
/**
* @brief
* 计算数据集的基尼不纯度
* @param y 目标变量集合
* @return 基尼不纯度 double
*/
double CartClassifier::cal_gini_impurity(const vector<string>& y) {
// 统计 y 中的目标变量值的个数
unordered_map<string, int> y_count;
for (const string& v : y) {
if (!y_count.contains(v)) y_count[v] = 0;
++y_count[v];
}
// 计算基尼不纯度
double gini_impurity = 1;
const auto num_samples = static_cast<double>(y.size());
for (const int& k : y_count | views::values) {
const double prob = k / num_samples;
gini_impurity -= prob * prob;
}
return gini_impurity;
}
cal_gini_impurity 用于计算基尼不纯度,包含以下步骤:
- 分析导入的数据集的最后一列(一般默认为数据类别),根据不同类别按出现次数统计到分类字典中;
- 遍历该字典,根据公式用 1 减去不同的类分布概率的平方和,得到最终的基尼不纯度。
/**
* @brief
* 计算基尼指数
* @param pro_left 左子树比例
* @param pro_right 右子树比例
* @param gini_impurity_left 左子树的基尼不纯度
* @param gini_impurity_right 右子树的基尼不纯度
* @return 基尼指数 double
*/
double CartClassifier::cal_gini_index(const double pro_left, const double pro_right,
const double gini_impurity_left, const double gini_impurity_right) {
return pro_left * gini_impurity_left + pro_right * gini_impurity_right;
}
cal_gini_index 通过公式计算基尼指数。
预测
/**
* @brief
* 使用决策树进行预测
* @param X 测试集属性值
* @return 预测值
*/
vector<string> CartClassifier::predict(const vector<vector<string>>& X) {
vector<string> y_preds;
for (const vector<string>& x : X) y_preds.emplace_back(classify(tree_, x));
return y_preds;
}
遍历测试集 X 的每个样本,使用 classify 函数分别对其进行预测,最终返回拼接好的预测结果。
/**
* @brief
* 分类预测
* @param tree 训练好的 CART 树
* @param x 待分类样本
* @return 预测类
*/
string CartClassifier::classify(const shared_ptr<BinTreeNode>& tree, const vector<string>& x) {
const string& first_str = tree->feature_name_; // 根节点
const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
const string& current_value = x[feature_index];
if (tree->left_ && current_value == tree->threshold_str_)
return classify(tree->left_, x);
if (tree->right_ && current_value != tree->threshold_str_)
return classify(tree->right_, x);
return tree->threshold_str_;
}
通过调用 classify 进行预测分类。参数 tree 的根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。在递归遍历过程中,从根节点开始,递归遍历 CART 分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。
剪枝
/**
* @brief
* 代价复杂度剪枝 CCP
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 剪枝后的决策树集合
*/
vector<shared_ptr<BinTreeNode>> CartClassifier::pruning(const vector<vector<string>>& X,
const vector<string>& y) {
// 递归计算对当前树的每个子树的 g(ti),挑选最小的 g(ti) 进行剪枝,得到新的 T,最终得到 n 个 T
return split_n_best_trees(X, y);
}
函数 pruning 根据不同的 α 区间生成不同剪枝程度的决策树集合。集合中越后面的决策树,剪枝程度越高。
/**
* @brief
* 根据 g(ti) 生成 n 个误差最小的树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return n 个误差最小的树
*/
vector<shared_ptr<BinTreeNode>> CartClassifier::split_n_best_trees(const vector<vector<string>>& X,
const vector<string>& y) {
vector<shared_ptr<BinTreeNode>> trees;
shared_ptr<BinTreeNode> tree = tree_->copy();
while (tree)
if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y)) {
trees.emplace_back(best_tree);
tree = best_tree->copy();
} else
tree = nullptr;
return trees;
}
split_n_best_trees 函数通过调用 split_1_best_trees 函数递归生成 n 棵预测误差最小的树,每一次递归的初始树均为上一次递归得到的最优剪枝树。为了在递归过程中不破坏上一轮得到的最优剪枝树,使用了深拷贝。
/**
* @brief
* 计算α值,选出α值最小的剪枝树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return α值最小的剪枝树
*/
shared_ptr<BinTreeNode> CartClassifier::split_1_best_trees(const shared_ptr<BinTreeNode>& tree,
const vector<vector<string>>& X,
const vector<string>& y) {
// 构建节点信息总集合
vector<Info> infoSet;
// 计算数据集长度
const size_t NT = X.size();
// 计算误差增加率,并生成信息集合
calErrorRatio(tree, X, y, NT, infoSet);
if (infoSet.empty()) return nullptr;
// a 的比较基准值
double baseValue = 1;
int bestNode = 0;
for (int i = 0; i < infoSet.size(); i++)
if (infoSet[i].a < baseValue) {
baseValue = infoSet[i].a;
bestNode = i;
} else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
bestNode = i;
return prunBranch(tree, X, y, infoSet[bestNode]);
}
函数 split_1_best_tree 负责递归计算 α 值,并且选出 α 值最小的剪枝树。当前树的深度大于 1 时,开始进行 CCP 的迭代剪枝。在每次迭代内部,对每个分支节点进行 gi(t) 的计算,并选取最小值对应的子树进行剪枝。如果求得的最小 gi(t) 对应的子树有多个,则优先选取节点数目最多的子树作为修剪的对象。
/**
* @brief
* 计算非叶节点误差增加率
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param NT 数据集总样本数目
* @param infoSet 所有节点的信息总集合
* @return 各个节点的信息集
*/
Info CartClassifier::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X,
const vector<string>& y, const size_t NT, vector<Info>& infoSet) {
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_ && (tree->left_->left_ || tree->left_->right_)) {
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] == tree->threshold_str_) {
// 取第 i 行进 subData
// 相当于把 label 特征取值剔除,将其他特征取值输出
// 将每个符合条件的特征列表,组成列表集合
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
// 在节点信息集中,增加分类前特征
info.key_str_ = {true, tree->threshold_str_};
infoSet.emplace_back(info);
}
if (tree->right_ && (tree->right_->left_ || tree->right_->right_)) {
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] != tree->threshold_str_) {
sub_X.(X[i]);
sub_y.(y[i]);
}
Info info = (tree->right_, sub_X, sub_y, NT, infoSet);
info.key_str_ = {, tree->threshold_str_};
infoSet.(info);
}
Ct = <>((y)) / <>(NT);
CTt = <>((tree, X, y)) / <>(NT);
Nt = (tree);
a = Nt == ? : (Ct - CTt) / <>(Nt - );
{tree, Nt, a};
}
每次迭代中 gi(t) 的计算,也就是 calErrorRatio 函数。该函数主要计算节点 t 的误差率 C(t)、节点 t 对应子树 Tt 的误差率 C(Tt)、子树叶子节点的数目 |Tt|。gi(t) 的计算采用递归的方法,最终将所有 info 合并成节点信息集合。
/**
* @brief
* 计算非叶节点的误差
* @param y 训练集目标变量
* @return 误差
*/
size_t CartClassifier::nodeError(const vector<string>& y) {
// 找到数量最多的类别
string majorClass = majority_y(y);
// 游历数据集每个元素,找出正确样本个数,如果不一致,错误加 1
return ranges::count_if(y, [&majorClass](const string& v) { return v != majorClass; });
}
/**
* @brief
* 计算叶节点的误差
* @param tree 生成的决策树
* @param X
* @param y 训练集目标变量
* @return 误差
*/
size_t CartClassifier::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X,
const vector<string>& y) {
size_t error = 0;
for (int i = 0; i < X.size(); i++)
if (classify(tree, X[i]) != y[i]) ++error;
return error;
}
/**
* @brief
* 获取叶节点数量
* @param tree 决策树
* @return 返回树的叶节点
*/
size_t CartClassifier::getNumLeaf(const shared_ptr<BinTreeNode>& tree) {
size_t numLeafs = 0;
if (tree->left_) numLeafs += getNumLeaf(tree->left_);
if (tree->right_) numLeafs += getNumLeaf(tree->right_);
if (!tree->left_ && !tree->right_) ++numLeafs;
return numLeafs;
}
/**
* @brief
* 根据误差增加率,剪掉子树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param infoBran 需剪掉的子树信息集
* @return 剪枝后的决策树
*/
shared_ptr<BinTreeNode> CartClassifier::prunBranch(const shared_ptr<BinTreeNode>& tree,
const vector<vector<string>>& X,
const vector<string>& y,
const Info& infoBran) {
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_) {
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] == tree->threshold_str_) {
// 取第 i 行进 subData
// 相当于把 label 特征取值剔除,将其他特征取值输出
// 将每个符合条件的特征列表,组成列表集合
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
// 找到数量最多的类别
const string majorClass = majority_y(sub_y);
// 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
if (infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ &&
tree->left_ == infoBran.tree_) {
// 剪掉子树,即返回最大类
tree->left_ = make_shared<BinTreeNode>();
tree->left_->threshold_str_ = majorClass;
return tree;
}
// 如果不相同,继续向下寻找
tree->left_ = (tree->left_, sub_X, sub_y, infoBran);
}
(tree->right_) {
vector<vector<string>> sub_X;
vector<string> sub_y;
( i = ; i < X.(); i++)
(X[i][labelIndex] != tree->threshold_str_) {
sub_X.(X[i]);
sub_y.(y[i]);
}
string majorClass = (sub_y);
(!infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ &&
tree->right_ == infoBran.tree_) {
tree->right_ = <BinTreeNode>();
tree->right_->threshold_str_ = majorClass;
tree;
}
tree->right_ = (tree->right_, sub_X, sub_y, infoBran);
}
tree;
}
应用注意事项
- 为方便理解,代码仅考虑了离散字符串的分类,并未考虑其他离散值和连续值的分类,实际生产过程可能需要补充;
- 原则上来说,代码数据集中的字符串均需要通过编码(分类算法中编码无限制),以提升效率。为方便理解,本文章使用原始字符串,不影响结果;
- 代码中的
feature_name仅作画图需要,实际生产如无该需求,可以去掉该变量; - 对于 CCP 误差的计算,scikit-learn 使用基尼不纯度进行代替,因其不用每次使用预测计算,提高了效率。但基尼不纯度与误差之间仅具有相关性,无法通过基尼不纯度推导出误差,仅用作近似计算;
- 代码未考虑缺失值的处理;
- 代码没有适配多线程场景;
- 其他可能的算法时空复杂度的优化。
回归树
训练
CartRegressor 的创建和训练过程与 CartClassifier 类似。最重要的区别在于模型训练时切分点的选取。
/**
* @brief
* 训练决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param feature_names 属性名
* @return 生成的决策树
*/
shared_ptr<BinTreeNode> CartRegressor::train(const vector<vector<double>>& X,
const vector<double>& y,
const vector<string>& feature_names) {
feature_names_ = feature_names;
tree_ = create_tree(X, y);
return tree_;
}
train 的入参类型发生了变化,这是因为回归树使用的是连续类型数据。
/**
* @brief
* 创建树
* @param X 映射后的数值属性集
* @param y 映射后的数值目标变量集
* @return 训练好的决策树
*/
shared_ptr<BinTreeNode> CartRegressor::create_tree(const vector<vector<double>>& X, const vector<double>& y) {
// 若 X 中样本全属于同一类别 C,则停止划分
auto tree = make_shared<BinTreeNode>();
if (unordered_set(y.begin(), y.end()).size() == 1) {
tree->threshold_ = y.front();
return tree;
}
// 若节点样本数小于 min_samples_split,或者属性集上的取值均相同
if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1) {
tree->threshold_ = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
return tree;
}
// 按照'平方误差最小',从 feature_names 中选择最优切分点
auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
const string_view best_feature_name = feature_names_[best_feature_index];
// 根据最优切分点,进行子树的划分
vector<vector<double>> sub_X1, sub_X2;
vector<double> sub_y1, sub_y2;
for (int i = 0; i < X.size(); i++)
(X[i][best_feature_index] <= best_split_point) {
sub_X(X[i]);
sub_y(y[i]);
} {
sub_X(X[i]);
sub_y(y[i]);
}
tree->feature_name_ = best_feature_name;
tree->threshold_ = best_split_point;
tree->left_ = (sub_X1, sub_y1);
tree->right_ = (sub_X2, sub_y2);
tree;
}
在函数 create_tree 中,主要有 3 处与分类树不同:
- 当满足递归终止条件'节点样本数小于
min_samples_split_'时,返回的预测值是该集合中所有目标变量的平均值; - 在
choose_best_point_to_split函数中,在回归树中采用'平方误差最小'的原则来选择最优切分点; - 使用最优属性和最优切分点划分数据集时相较分类树(处理匹配字符串'≤'和'>'的代码逻辑)做略微调整。
/**
* @brief
* 选择最优切分点
* @param X 映射后的数值属性集
* @param y 属性名称
* @return 最优切分点和最优切分点所在属性的索引
*/
pair<double, int> CartRegressor::choose_best_point_to_split(const vector<vector<double>>& X,
const vector<double>& y) {
double best_split_point = 0, best_loss_all = numeric_limits<double>::infinity();
int best_feature_index = -1;
const size_t num_feature = X[0].size();
// 属性的个数
for (int i = 0; i < num_feature; ++i) // 遍历每个属性
{
// 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
set<double> unique_feature_value;
vector<double> split_points;
for (const vector<double>& x : X) unique_feature_value.emplace(x[i]);
auto lit = unique_feature_value.begin(), rit = lit;
++rit;
while (rit != unique_feature_value.end()) {
split_points.emplace_back((*lit + *rit) / 2);
++lit;
++rit;
}
// 计算各个候选切分点的损失函数
for (const double split_point : split_points) {
vector<double> sub_y_left, sub_y_right;
for (int j = ; j < X.(); j++)
(X[j][i] <= split_point) sub_y_left.(y[j]);
sub_y_right.(y[j]);
sub_y_left_mean = (sub_y_left.(), sub_y_left.(), ) /
<>(sub_y_left.()),
sub_y_right_mean = (sub_y_right.(), sub_y_right.(), ) /
<>(sub_y_right.());
loss_left = , loss_right = ;
( j : sub_y_left) loss_left += (j - sub_y_left_mean, );
( j : sub_y_right) loss_right += (j - sub_y_right_mean, );
( loss_all = loss_left + loss_right; best_loss_all > loss_all) {
best_loss_all = loss_all;
best_feature_index = i;
best_split_point = split_point;
}
}
}
{best_split_point, best_feature_index};
}
choose_best_point_to_split 遍历所有属性值时,回归树中不再计算基尼不纯度和基尼增益,而是针对回归问题计算损失函数。分别计算了使用当前切分点划分的左右子树的残差平方和,再计算左右子树的总残差平方和。最后选出取得最小损失函数的切分点和属性索引,作为最优切分点和最优分裂属性。
预测
/**
* @brief
* 使用决策树进行预测
* @param X 测试集属性值
* @return 预测值
*/
vector<double> CartRegressor::predict(const vector<vector<double>>& X) {
vector<double> y_preds;
for (const vector<double>& x : X) y_preds.emplace_back(regression(tree_, x));
return y_preds;
}
/**
* @brief
* 回归预测
* @param tree 训练好的树
* @param x 待分类样本
* @return 预测类
*/
double CartRegressor::regression(const shared_ptr<BinTreeNode>& tree, const vector<double>& x) {
const string& first_str = tree->feature_name_; // 根节点
const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
const double current_value = x[feature_index];
if (tree->left_ && current_value <= tree->threshold_)
return regression(tree->left_, x);
if (tree->right_ && current_value > tree->threshold_)
return regression(tree->right_, x);
return tree->threshold_;
}
由于 CART 回归树与分类树的预测过程几乎完全相同,在此不做赘述。
剪枝
/**
* @brief
* 代价复杂度剪枝 CCP
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 剪枝后的决策树集合
*/
vector<shared_ptr<BinTreeNode>> CartRegressor::pruning(const vector<vector<double>>& X,
const vector<double>& y) {
// 递归计算对当前树的每个子树的 g(ti),挑选最小的 g(ti) 进行剪枝,得到新的 T,最终得到 n 个 T
return split_n_best_trees(X, y);
}
/**
* @brief
* 根据 g(ti) 生成 n 个误差最小的树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return n 个误差最小的树
*/
vector<shared_ptr<BinTreeNode>> CartRegressor::split_n_best_trees(const vector<vector<double>>& X,
const vector<double>& y) {
vector<shared_ptr<BinTreeNode>> trees;
shared_ptr<BinTreeNode> tree = tree_->copy();
while (tree)
if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y)) {
trees.emplace_back(best_tree);
tree = best_tree->copy();
} else
tree = nullptr;
return trees;
}
/**
* @brief
* 计算α值,选出α值最小的剪枝树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return α值最小的剪枝树
*/
shared_ptr<BinTreeNode> CartRegressor::split_1_best_trees(const shared_ptr<BinTreeNode>& tree,
const vector<vector<double>>& X,
const vector<double>& y) {
// 构建节点信息总集合
vector<Info> infoSet;
// 计算数据集长度
const size_t NT = X.();
(tree, X, y, NT, infoSet);
(infoSet.()) ;
baseValue = ;
bestNode = ;
( i = ; i < infoSet.(); i++)
(infoSet[i].a < baseValue) {
baseValue = infoSet[i].a;
bestNode = i;
} (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
bestNode = i;
(tree, X, y, infoSet[bestNode]);
}
{
string_view firstFeat = tree->feature_name_;
labelIndex = (feature_names_.(), ranges::(feature_names_, firstFeat));
(tree->left_ && (tree->left_->left_ || tree->left_->right_)) {
vector<vector<>> sub_X;
vector<> sub_y;
( i = ; i < X.(); i++)
(X[i][labelIndex] <= tree->threshold_) {
sub_X.(X[i]);
sub_y.(y[i]);
}
Info info = (tree->left_, sub_X, sub_y, NT, infoSet);
info.key_ = {, tree->threshold_};
infoSet.(info);
}
(tree->right_ && (tree->right_->left_ || tree->right_->right_)) {
vector<vector<>> sub_X;
vector<> sub_y;
( i = ; i < X.(); i++)
(X[i][labelIndex] > tree->threshold_) {
sub_X.(X[i]);
sub_y.(y[i]);
}
Info info = (tree->right_, sub_X, sub_y, NT, infoSet);
info.key_ = {, tree->threshold_};
infoSet.(info);
}
Rt = <>((y)) / <>(NT);
RTt = <>((tree, X, y)) / <>(NT);
Nt = (tree);
a = Nt == ? : (Rt - RTt) / <>(Nt - );
{tree, Nt, a};
}
{
mean_y = (y.(), y.(), ) / <>(y.());
error = ;
( & val : y) error += <>((val - mean_y, ));
error;
}
{
error = ;
( i = ; i < X.(); i++) {
pred = (tree, X[i]);
error += <>((pred - y[i], ));
}
error;
}
{
numLeafs = ;
(tree->left_) numLeafs += (tree->left_);
(tree->right_) numLeafs += (tree->right_);
(!tree->left_ && !tree->right_) ++numLeafs;
numLeafs;
}
{
string_view firstFeat = tree->feature_name_;
labelIndex = (feature_names_.(), ranges::(feature_names_, firstFeat));
(tree->left_) {
vector<vector<>> sub_X;
vector<> sub_y;
( i = ; i < X.(); i++)
(X[i][labelIndex] <= tree->threshold_) {
sub_X.(X[i]);
sub_y.(y[i]);
}
mean_val = (sub_y.(), sub_y.(), ) / <>(sub_y.());
(infoBran.key_.first && (infoBran.key_.second - tree->threshold_) < &&
tree->left_ == infoBran.tree_) {
tree->left_ = <BinTreeNode>();
tree->left_->threshold_ = mean_val;
tree;
}
tree->left_ = (tree->left_, sub_X, sub_y, infoBran);
}
(tree->right_) {
vector<vector<>> sub_X;
vector<> sub_y;
( i = ; i < X.(); i++)
(X[i][labelIndex] > tree->threshold_) {
sub_X.(X[i]);
sub_y.(y[i]);
}
mean_val = (sub_y.(), sub_y.(), ) / <>(sub_y.());
(!infoBran.key_.first && (infoBran.key_.second - tree->threshold_) < &&
tree->right_ == infoBran.tree_) {
tree->right_ = <BinTreeNode>();
tree->right_->threshold_ = mean_val;
tree;
}
tree->right_ = (tree->right_, sub_X, sub_y, infoBran);
}
tree;
}
回归树的剪枝与分类树类似,不同点在于回归树计算误差使用的是均方差。
应用注意事项
- 代码中的
feature_name仅作画图需要,实际生产如无该需求,可以去掉该变量; - 代码未考虑缺失值的处理;
- 分类树和回归树中的 CCP 算法,仅在误差计算中有区别。分类树中可以使用基尼系数或误分类率(从效率层面,推荐使用基尼系数),回归树中使用均方差;
- 代码没有适配多线程场景;
- 其他可能的算法时空复杂度的优化。

