跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
C++AI算法

从零开始实现决策树——手撕 CART 算法(C++)

综述由AI生成基于 C++20 标准从零实现 CART 决策树算法,包含分类树与回归树两部分。详细阐述了数据结构设计、训练过程(基尼指数与平方误差最小化)、预测方法及代价复杂度剪枝(CCP)的实现细节。同时列出了离散值处理、缺失值及多线程适配等实际应用中的注意事项。

赛博朋克发布于 2026/3/21更新于 2026/6/134 浏览

本文的 C++ 代码基于 C++ 20 标准(不包含 C++ modules),对于之前的标准,可能需要做一些适配。

CART 分类树和回归树的内容各自在一个类中,分类树为 CartClassifier 类,回归树为 CartRegression 类。

数据结构设计

二叉树设计

// 二叉树结点
struct BinTreeNode {
    std::string threshold_str_;
    double threshold_ = -1;
    std::string feature_name_;
    std::shared_ptr<BinTreeNode> left_ = nullptr;
    std::shared_ptr<BinTreeNode> right_ = nullptr;
    [[nodiscard]] std::shared_ptr<BinTreeNode> copy() const {
        auto node = std::make_shared<BinTreeNode>();
        node->threshold_ = threshold_;
        node->threshold_str_ = threshold_str_;
        node->feature_name_ = feature_name_;
        if(left_) node->left_ = left_->copy();
        if(right_) node->right_ = right_->copy();
        return node;
    }
};

copy 模块用于二叉树结点的深复制,包括复制本身及其所有的子结点。

结点信息设计

struct Info {
    std::shared_ptr<BinTreeNode> tree_;
    size_t num_leaf_ = 0;
    double a = 0;
    std::pair<bool, std::string> key_str_{};
    std::pair<bool, double> key_{};
};

实际上结点信息可以直接存储到二叉树结点 BinTreeNode 中。分开是为了保证代码的语义清晰,易于理解。

分类树

训练

/**
 * @brief
 * 训练决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param feature_names 属性名
 * @return 生成的决策树
 */
shared_ptr<BinTreeNode> CartClassifier::train(const vector<vector<string>>& X,
                                              const vector<string>& y,
                                              const vector<string>& feature_names) {
    feature_names_ = feature_names;
    // 创建 CART 决策树
    tree_ = create_tree(X, y);
    return tree_;
}

训练函数通过传递常量引用形参,防止训练集和属性集被篡改。如需要修改,可以在函数内部设置副本,针对副本进行修改。create_tree 是创建 CART 分类树的核心函数。

/**
 * @brief
 * 创建树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 训练好的决策树
 */
shared_ptr<BinTreeNode> CartClassifier::create_tree(const vector<vector<string>>& X,
                                                    const vector<string>& y) {
    // 若 X 中样本全属于同一类别 C,则停止划分
    auto tree = make_shared<BinTreeNode>();
    if (unordered_set(y.begin(), y.end()).size() == 1) {
        tree->threshold_str_ = y.front();
        return tree;
    }
    // 若节点样本数小于 min_samples_split,或者属性集上的取值均相同
    if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1) {
        tree->threshold_str_ = majority_y(y);
        return tree;
    }
    // 按照'基尼增益',从属性值中选择最优分裂属性的最优切分点
    auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
    const string best_feature_name = feature_names_[best_feature_index];
    // 根据最优切分点,进行子树的划分
    vector<vector<string>> sub_X1, sub_X2;
    vector<string> sub_y1, sub_y2;
    for (int i = 0; i < X.size(); i++)
        if (X[i][best_feature_index] == best_split_point) {
            sub_X1.emplace_back(X[i]);
            sub_y1.emplace_back(y[i]);
        } else {
            sub_X2.emplace_back(X[i]);
            sub_y2.emplace_back(y[i]);
        }
    tree->feature_name_ = best_feature_name;
    tree->threshold_str_ = best_split_point;
    tree->left_ = create_tree(sub_X1, sub_y1);
    tree->right_ = create_tree(sub_X2, sub_y2);
    return tree;
}

create_tree 函数是一个递归创建决策树的过程。首先判断三种递归中止条件:

  • X 中样本全部属于同一类别;
  • 当前节点样本数小于 min_samples_split_;
  • 属性集上的取值均相同

若满足终止条件,则选择 y 中最多的类别作为结果返回。若未满足终止条件,依次执行以下步骤:

  1. 根据基尼指数从属性值中选择最优分裂属性的最优切分点,具体过程如 choose_best_point_to_split 函数所示;
  2. 根据最优切分点对子树进行划分;
  3. 对于其子树再继续执行 create_tree 函数完成划分过程。
/**
 * @brief
 * 统计每个类别出现的次数,返回出现次数最大的类别 ID
 * @param y 目标变量集合
 * @return 出现次数最大的类别
 */
string CartClassifier::majority_y(const vector<string>& y) {
    // 统计 y 中的目标变量值的个数
    unordered_map<string, int> y_count;
    for (const string& v : y) {
        if (!y_count.contains(v)) y_count[v] = 0;
        ++y_count[v];
    }
    return ranges::max_element(y_count, [](const pair<string, int>& a, const pair<string, int>& b) {
        return a.second < b.second;
    })->first;
}

majority_y 用于计算节点中出现次数最多的类别。包含以下步骤:

  1. 初始化一个空映射;
  2. 遍历 y 并对其元素进行计数;
  3. 从映射中查找出现次数最多的类别。
/**
 * @brief
 * 选择最优切分点
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 最优切分点和最优切分点所在属性的索引
 */
pair<string, int> CartClassifier::choose_best_point_to_split(const vector<vector<string>>& X,
                                                             const vector<string>& y) {
    string best_split_point;
    int best_feature_index = -1;
    double best_gini_index = numeric_limits<double>::infinity();
    const size_t num_feature = X[0].size();
    // 属性的个数
    for (int i = 0; i < num_feature; i++) // 遍历每个属性
    {
        // 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
        unordered_set<string> split_points;
        for (const vector<string>& x : X) split_points.emplace(x[i]);
        for (const string& split_point : split_points) // 计算各个候选切分点的基尼不纯度
        {
            vector<string> sub_y_left, sub_y_right;
            for (int j = 0; j < X.size(); j++)
                if (X[j][i] == split_point) sub_y_left.emplace_back(y[j]);
                else sub_y_right.emplace_back(y[j]);
            // 计算左子树的基尼不纯度
            const double gini_impurity_left = cal_gini_impurity(sub_y_left);
            // 计算右子树的基尼不纯度
            const double gini_impurity_right = cal_gini_impurity(sub_y_right);
            // 计算该切分点的基尼指数
            const double pro_left = static_cast<double>(sub_y_left.size()) / static_cast<double>(y.size()),
                        pro_right = static_cast<double>(sub_y_right.size()) / static_cast<double>(y.size());
            if (const double gini_index = cal_gini_index(pro_left, pro_right, gini_impurity_left, gini_impurity_right);
                best_gini_index > gini_index)
                // 取基尼指数最大的属性索引和切分点
            {
                best_gini_index = gini_index;
                best_feature_index = i;
                best_split_point = split_point;
            }
        }
    }
    return {best_split_point, best_feature_index};
}

choose_best_point_to_split 是 CART 分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。该函数遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度 G(left) 和 G(right),以及左右子树中数据样本在总样本中占的比例 P(left) 和 P(right),并且将 G(left), G(right), P(left), P(right) 代入 cal_gini_index 函数中计算基尼指数。最后选出具有最小基尼指数的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。

/**
 * @brief
 * 计算数据集的基尼不纯度
 * @param y 目标变量集合
 * @return 基尼不纯度 double
 */
double CartClassifier::cal_gini_impurity(const vector<string>& y) {
    // 统计 y 中的目标变量值的个数
    unordered_map<string, int> y_count;
    for (const string& v : y) {
        if (!y_count.contains(v)) y_count[v] = 0;
        ++y_count[v];
    }
    // 计算基尼不纯度
    double gini_impurity = 1;
    const auto num_samples = static_cast<double>(y.size());
    for (const int& k : y_count | views::values) {
        const double prob = k / num_samples;
        gini_impurity -= prob * prob;
    }
    return gini_impurity;
}

cal_gini_impurity 用于计算基尼不纯度,包含以下步骤:

  1. 分析导入的数据集的最后一列(一般默认为数据类别),根据不同类别按出现次数统计到分类字典中;
  2. 遍历该字典,根据公式用 1 减去不同的类分布概率的平方和,得到最终的基尼不纯度。
/**
 * @brief
 * 计算基尼指数
 * @param pro_left 左子树比例
 * @param pro_right 右子树比例
 * @param gini_impurity_left 左子树的基尼不纯度
 * @param gini_impurity_right 右子树的基尼不纯度
 * @return 基尼指数 double
 */
double CartClassifier::cal_gini_index(const double pro_left, const double pro_right,
                                      const double gini_impurity_left, const double gini_impurity_right) {
    return pro_left * gini_impurity_left + pro_right * gini_impurity_right;
}

cal_gini_index 通过公式计算基尼指数。

预测

/**
 * @brief
 * 使用决策树进行预测
 * @param X 测试集属性值
 * @return 预测值
 */
vector<string> CartClassifier::predict(const vector<vector<string>>& X) {
    vector<string> y_preds;
    for (const vector<string>& x : X) y_preds.emplace_back(classify(tree_, x));
    return y_preds;
}

遍历测试集 X 的每个样本,使用 classify 函数分别对其进行预测,最终返回拼接好的预测结果。

/**
 * @brief
 * 分类预测
 * @param tree 训练好的 CART 树
 * @param x 待分类样本
 * @return 预测类
 */
string CartClassifier::classify(const shared_ptr<BinTreeNode>& tree, const vector<string>& x) {
    const string& first_str = tree->feature_name_; // 根节点
    const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
    const string& current_value = x[feature_index];
    if (tree->left_ && current_value == tree->threshold_str_)
        return classify(tree->left_, x);
    if (tree->right_ && current_value != tree->threshold_str_)
        return classify(tree->right_, x);
    return tree->threshold_str_;
}

通过调用 classify 进行预测分类。参数 tree 的根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。在递归遍历过程中,从根节点开始,递归遍历 CART 分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。

剪枝

/**
 * @brief
 * 代价复杂度剪枝 CCP
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 剪枝后的决策树集合
 */
vector<shared_ptr<BinTreeNode>> CartClassifier::pruning(const vector<vector<string>>& X,
                                                        const vector<string>& y) {
    // 递归计算对当前树的每个子树的 g(ti),挑选最小的 g(ti) 进行剪枝,得到新的 T,最终得到 n 个 T
    return split_n_best_trees(X, y);
}

函数 pruning 根据不同的 α 区间生成不同剪枝程度的决策树集合。集合中越后面的决策树,剪枝程度越高。

/**
 * @brief
 * 根据 g(ti) 生成 n 个误差最小的树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return n 个误差最小的树
 */
vector<shared_ptr<BinTreeNode>> CartClassifier::split_n_best_trees(const vector<vector<string>>& X,
                                                                   const vector<string>& y) {
    vector<shared_ptr<BinTreeNode>> trees;
    shared_ptr<BinTreeNode> tree = tree_->copy();
    while (tree)
        if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y)) {
            trees.emplace_back(best_tree);
            tree = best_tree->copy();
        } else
            tree = nullptr;
    return trees;
}

split_n_best_trees 函数通过调用 split_1_best_trees 函数递归生成 n 棵预测误差最小的树,每一次递归的初始树均为上一次递归得到的最优剪枝树。为了在递归过程中不破坏上一轮得到的最优剪枝树,使用了深拷贝。

/**
 * @brief
 * 计算α值,选出α值最小的剪枝树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return α值最小的剪枝树
 */
shared_ptr<BinTreeNode> CartClassifier::split_1_best_trees(const shared_ptr<BinTreeNode>& tree,
                                                           const vector<vector<string>>& X,
                                                           const vector<string>& y) {
    // 构建节点信息总集合
    vector<Info> infoSet;
    // 计算数据集长度
    const size_t NT = X.size();
    // 计算误差增加率,并生成信息集合
    calErrorRatio(tree, X, y, NT, infoSet);
    if (infoSet.empty()) return nullptr;
    // a 的比较基准值
    double baseValue = 1;
    int bestNode = 0;
    for (int i = 0; i < infoSet.size(); i++)
        if (infoSet[i].a < baseValue) {
            baseValue = infoSet[i].a;
            bestNode = i;
        } else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
            bestNode = i;
    return prunBranch(tree, X, y, infoSet[bestNode]);
}

函数 split_1_best_tree 负责递归计算 α 值,并且选出 α 值最小的剪枝树。当前树的深度大于 1 时,开始进行 CCP 的迭代剪枝。在每次迭代内部,对每个分支节点进行 gi(t) 的计算,并选取最小值对应的子树进行剪枝。如果求得的最小 gi(t) 对应的子树有多个,则优先选取节点数目最多的子树作为修剪的对象。

/**
 * @brief
 * 计算非叶节点误差增加率
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param NT 数据集总样本数目
 * @param infoSet 所有节点的信息总集合
 * @return 各个节点的信息集
 */
Info CartClassifier::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X,
                                   const vector<string>& y, const size_t NT, vector<Info>& infoSet) {
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_ && (tree->left_->left_ || tree->left_->right_)) {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] == tree->threshold_str_) {
                // 取第 i 行进 subData
                // 相当于把 label 特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_str_ = {true, tree->threshold_str_};
        infoSet.emplace_back(info);
    }
    if (tree->right_ && (tree->right_->left_ || tree->right_->right_)) {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] != tree->threshold_str_) {
                // 取第 i 行进 subData
                // 相当于把 label 特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_str_ = {false, tree->threshold_str_};
        infoSet.emplace_back(info);
    }
    // 计算节点误差率
    const double Ct = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
    // 计算子树误差率
    const double CTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
    // 计算叶节点数目
    const size_t Nt = getNumLeaf(tree);
    const double a = Nt == 1 ? 2 : (Ct - CTt) / static_cast<double>(Nt - 1);
    return {tree, Nt, a};
}

每次迭代中 gi(t) 的计算,也就是 calErrorRatio 函数。该函数主要计算节点 t 的误差率 C(t)、节点 t 对应子树 Tt 的误差率 C(Tt)、子树叶子节点的数目 |Tt|。gi(t) 的计算采用递归的方法,最终将所有 info 合并成节点信息集合。

/**
 * @brief
 * 计算非叶节点的误差
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartClassifier::nodeError(const vector<string>& y) {
    // 找到数量最多的类别
    string majorClass = majority_y(y);
    // 游历数据集每个元素,找出正确样本个数,如果不一致,错误加 1
    return ranges::count_if(y, [&majorClass](const string& v) { return v != majorClass; });
}

/**
 * @brief
 * 计算叶节点的误差
 * @param tree 生成的决策树
 * @param X
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartClassifier::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X,
                                 const vector<string>& y) {
    size_t error = 0;
    for (int i = 0; i < X.size(); i++)
        if (classify(tree, X[i]) != y[i]) ++error;
    return error;
}

/**
 * @brief
 * 获取叶节点数量
 * @param tree 决策树
 * @return 返回树的叶节点
 */
size_t CartClassifier::getNumLeaf(const shared_ptr<BinTreeNode>& tree) {
    size_t numLeafs = 0;
    if (tree->left_) numLeafs += getNumLeaf(tree->left_);
    if (tree->right_) numLeafs += getNumLeaf(tree->right_);
    if (!tree->left_ && !tree->right_) ++numLeafs;
    return numLeafs;
}
/**
 * @brief
 * 根据误差增加率,剪掉子树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param infoBran 需剪掉的子树信息集
 * @return 剪枝后的决策树
 */
shared_ptr<BinTreeNode> CartClassifier::prunBranch(const shared_ptr<BinTreeNode>& tree,
                                                   const vector<vector<string>>& X,
                                                   const vector<string>& y,
                                                   const Info& infoBran) {
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_) {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] == tree->threshold_str_) {
                // 取第 i 行进 subData
                // 相当于把 label 特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 找到数量最多的类别
        const string majorClass = majority_y(sub_y);
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ &&
            tree->left_ == infoBran.tree_) {
            // 剪掉子树,即返回最大类
            tree->left_ = make_shared<BinTreeNode>();
            tree->left_->threshold_str_ = majorClass;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
    }
    if (tree->right_) {
        // 划分数据集
        vector<vector<string>> sub_X;
        vector<string> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] != tree->threshold_str_) {
                // 取第 i 行进 subData
                // 相当于把 label 特征取值剔除,将其他特征取值输出
                // 将每个符合条件的特征列表,组成列表集合
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 找到数量最多的类别
        const string majorClass = majority_y(sub_y);
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (!infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ &&
            tree->right_ == infoBran.tree_) {
            // 剪掉子树,即返回最大类
            tree->right_ = make_shared<BinTreeNode>();
            tree->right_->threshold_str_ = majorClass;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
    }
    return tree;
}

应用注意事项

  1. 为方便理解,代码仅考虑了离散字符串的分类,并未考虑其他离散值和连续值的分类,实际生产过程可能需要补充;
  2. 原则上来说,代码数据集中的字符串均需要通过编码(分类算法中编码无限制),以提升效率。为方便理解,本文章使用原始字符串,不影响结果;
  3. 代码中的 feature_name 仅作画图需要,实际生产如无该需求,可以去掉该变量;
  4. 对于 CCP 误差的计算,scikit-learn 使用基尼不纯度进行代替,因其不用每次使用预测计算,提高了效率。但基尼不纯度与误差之间仅具有相关性,无法通过基尼不纯度推导出误差,仅用作近似计算;
  5. 代码未考虑缺失值的处理;
  6. 代码没有适配多线程场景;
  7. 其他可能的算法时空复杂度的优化。

回归树

训练

CartRegressor 的创建和训练过程与 CartClassifier 类似。最重要的区别在于模型训练时切分点的选取。

/**
 * @brief
 * 训练决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param feature_names 属性名
 * @return 生成的决策树
 */
shared_ptr<BinTreeNode> CartRegressor::train(const vector<vector<double>>& X,
                                             const vector<double>& y,
                                             const vector<string>& feature_names) {
    feature_names_ = feature_names;
    tree_ = create_tree(X, y);
    return tree_;
}

train 的入参类型发生了变化,这是因为回归树使用的是连续类型数据。

/**
 * @brief
 * 创建树
 * @param X 映射后的数值属性集
 * @param y 映射后的数值目标变量集
 * @return 训练好的决策树
 */
shared_ptr<BinTreeNode> CartRegressor::create_tree(const vector<vector<double>>& X, const vector<double>& y) {
    // 若 X 中样本全属于同一类别 C,则停止划分
    auto tree = make_shared<BinTreeNode>();
    if (unordered_set(y.begin(), y.end()).size() == 1) {
        tree->threshold_ = y.front();
        return tree;
    }
    // 若节点样本数小于 min_samples_split,或者属性集上的取值均相同
    if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1) {
        tree->threshold_ = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
        return tree;
    }
    // 按照'平方误差最小',从 feature_names 中选择最优切分点
    auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
    const string_view best_feature_name = feature_names_[best_feature_index];
    // 根据最优切分点,进行子树的划分
    vector<vector<double>> sub_X1, sub_X2;
    vector<double> sub_y1, sub_y2;
    for (int i = 0; i < X.size(); i++)
        if (X[i][best_feature_index] <= best_split_point) {
            sub_X1.emplace_back(X[i]);
            sub_y1.emplace_back(y[i]);
        } else {
            sub_X2.emplace_back(X[i]);
            sub_y2.emplace_back(y[i]);
        }
    tree->feature_name_ = best_feature_name;
    tree->threshold_ = best_split_point;
    tree->left_ = create_tree(sub_X1, sub_y1);
    tree->right_ = create_tree(sub_X2, sub_y2);
    return tree;
}

在函数 create_tree 中,主要有 3 处与分类树不同:

  1. 当满足递归终止条件'节点样本数小于 min_samples_split_'时,返回的预测值是该集合中所有目标变量的平均值;
  2. 在 choose_best_point_to_split 函数中,在回归树中采用'平方误差最小'的原则来选择最优切分点;
  3. 使用最优属性和最优切分点划分数据集时相较分类树(处理匹配字符串'≤'和'>'的代码逻辑)做略微调整。
/**
 * @brief
 * 选择最优切分点
 * @param X 映射后的数值属性集
 * @param y 属性名称
 * @return 最优切分点和最优切分点所在属性的索引
 */
pair<double, int> CartRegressor::choose_best_point_to_split(const vector<vector<double>>& X,
                                                            const vector<double>& y) {
    double best_split_point = 0, best_loss_all = numeric_limits<double>::infinity();
    int best_feature_index = -1;
    const size_t num_feature = X[0].size();
    // 属性的个数
    for (int i = 0; i < num_feature; ++i) // 遍历每个属性
    {
        // 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
        set<double> unique_feature_value;
        vector<double> split_points;
        for (const vector<double>& x : X) unique_feature_value.emplace(x[i]);
        auto lit = unique_feature_value.begin(), rit = lit;
        ++rit;
        while (rit != unique_feature_value.end()) {
            split_points.emplace_back((*lit + *rit) / 2);
            ++lit;
            ++rit;
        }
        // 计算各个候选切分点的损失函数
        for (const double split_point : split_points) {
            vector<double> sub_y_left, sub_y_right;
            for (int j = 0; j < X.size(); j++)
                if (X[j][i] <= split_point) sub_y_left.emplace_back(y[j]);
                else sub_y_right.emplace_back(y[j]);
            const double sub_y_left_mean = accumulate(sub_y_left.begin(), sub_y_left.end(), 0.) /
                                           static_cast<double>(sub_y_left.size()),
                         sub_y_right_mean = accumulate(sub_y_right.begin(), sub_y_right.end(), 0.) /
                                            static_cast<double>(sub_y_right.size());
            double loss_left = 0, loss_right = 0;
            // 计算左子树的损失函数
            for (const double j : sub_y_left) loss_left += pow(j - sub_y_left_mean, 2);
            // 计算右子树的损失函数
            for (const double j : sub_y_right) loss_right += pow(j - sub_y_right_mean, 2);
            // 计算该切分点的总损失函数
            // 取损失函数最小时的属性索引和切分点
            if (const double loss_all = loss_left + loss_right; best_loss_all > loss_all) {
                best_loss_all = loss_all;
                best_feature_index = i;
                best_split_point = split_point;
            }
        }
    }
    return {best_split_point, best_feature_index};
}

choose_best_point_to_split 遍历所有属性值时,回归树中不再计算基尼不纯度和基尼增益,而是针对回归问题计算损失函数。分别计算了使用当前切分点划分的左右子树的残差平方和,再计算左右子树的总残差平方和。最后选出取得最小损失函数的切分点和属性索引,作为最优切分点和最优分裂属性。

预测

/**
 * @brief
 * 使用决策树进行预测
 * @param X 测试集属性值
 * @return 预测值
 */
vector<double> CartRegressor::predict(const vector<vector<double>>& X) {
    vector<double> y_preds;
    for (const vector<double>& x : X) y_preds.emplace_back(regression(tree_, x));
    return y_preds;
}

/**
 * @brief
 * 回归预测
 * @param tree 训练好的树
 * @param x 待分类样本
 * @return 预测类
 */
double CartRegressor::regression(const shared_ptr<BinTreeNode>& tree, const vector<double>& x) {
    const string& first_str = tree->feature_name_; // 根节点
    const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
    const double current_value = x[feature_index];
    if (tree->left_ && current_value <= tree->threshold_)
        return regression(tree->left_, x);
    if (tree->right_ && current_value > tree->threshold_)
        return regression(tree->right_, x);
    return tree->threshold_;
}

由于 CART 回归树与分类树的预测过程几乎完全相同,在此不做赘述。

剪枝

/**
 * @brief
 * 代价复杂度剪枝 CCP
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 剪枝后的决策树集合
 */
vector<shared_ptr<BinTreeNode>> CartRegressor::pruning(const vector<vector<double>>& X,
                                                       const vector<double>& y) {
    // 递归计算对当前树的每个子树的 g(ti),挑选最小的 g(ti) 进行剪枝,得到新的 T,最终得到 n 个 T
    return split_n_best_trees(X, y);
}

/**
 * @brief
 * 根据 g(ti) 生成 n 个误差最小的树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return n 个误差最小的树
 */
vector<shared_ptr<BinTreeNode>> CartRegressor::split_n_best_trees(const vector<vector<double>>& X,
                                                                  const vector<double>& y) {
    vector<shared_ptr<BinTreeNode>> trees;
    shared_ptr<BinTreeNode> tree = tree_->copy();
    while (tree)
        if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y)) {
            trees.emplace_back(best_tree);
            tree = best_tree->copy();
        } else
            tree = nullptr;
    return trees;
}

/**
 * @brief
 * 计算α值,选出α值最小的剪枝树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return α值最小的剪枝树
 */
shared_ptr<BinTreeNode> CartRegressor::split_1_best_trees(const shared_ptr<BinTreeNode>& tree,
                                                          const vector<vector<double>>& X,
                                                          const vector<double>& y) {
    // 构建节点信息总集合
    vector<Info> infoSet;
    // 计算数据集长度
    const size_t NT = X.size();
    // 计算误差增加率,并生成信息集合
    calErrorRatio(tree, X, y, NT, infoSet);
    if (infoSet.empty()) return nullptr;
    // a 的比较基准值
    double baseValue = 1;
    int bestNode = 0;
    for (int i = 0; i < infoSet.size(); i++)
        if (infoSet[i].a < baseValue) {
            baseValue = infoSet[i].a;
            bestNode = i;
        } else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
            bestNode = i;
    return prunBranch(tree, X, y, infoSet[bestNode]);
}

/**
 * @brief
 * 计算非叶节点误差增加率
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param NT 数据集总样本数目
 * @param infoSet 所有节点的信息总集合
 * @return 各个节点的信息集
 */
Info CartRegressor::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X,
                                  const vector<double>& y, size_t NT, vector<Info>& infoSet) {
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_ && (tree->left_->left_ || tree->left_->right_)) {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] <= tree->threshold_) {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_ = {true, tree->threshold_};
        infoSet.emplace_back(info);
    }
    if (tree->right_ && (tree->right_->left_ || tree->right_->right_)) {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] > tree->threshold_) {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
        // 在节点信息集中,增加分类前特征
        info.key_ = {false, tree->threshold_};
        infoSet.emplace_back(info);
    }
    // 计算节点误差率
    const double Rt = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
    // 计算子树误差率
    const double RTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
    // 计算叶节点数目
    const size_t Nt = getNumLeaf(tree);
    const double a = Nt == 1 ? 2 : (Rt - RTt) / static_cast<double>(Nt - 1);
    return {tree, Nt, a};
}

/**
 * @brief
 * 计算非叶节点的误差
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartRegressor::nodeError(const vector<double>& y) {
    // 计算节点的平方误差
    const double mean_y = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
    size_t error = 0;
    for (const double& val : y) error += static_cast<size_t>(pow(val - mean_y, 2));
    return error;
}

/**
 * @brief
 * 计算叶节点的误差
 * @param tree 生成的决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @return 误差
 */
size_t CartRegressor::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X,
                                const vector<double>& y) {
    size_t error = 0;
    for (int i = 0; i < X.size(); i++) {
        const double pred = regression(tree, X[i]);
        error += static_cast<size_t>(pow(pred - y[i], 2));
    }
    return error;
}

/**
 * @brief
 * 获取叶节点数量
 * @param tree 决策树
 * @return 返回树的叶节点
 */
size_t CartRegressor::getNumLeaf(const shared_ptr<BinTreeNode>& tree) {
    size_t numLeafs = 0;
    if (tree->left_) numLeafs += getNumLeaf(tree->left_);
    if (tree->right_) numLeafs += getNumLeaf(tree->right_);
    if (!tree->left_ && !tree->right_) ++numLeafs;
    return numLeafs;
}

/**
 * @brief
 * 根据误差增加率,剪掉子树
 * @param tree 决策树
 * @param X 训练集属性值
 * @param y 训练集目标变量
 * @param infoBran 需剪掉的子树信息集
 * @return 剪枝后的决策树
 */
shared_ptr<BinTreeNode> CartRegressor::prunBranch(const shared_ptr<BinTreeNode>& tree,
                                                  const vector<vector<double>>& X,
                                                  const vector<double>& y,
                                                  const Info& infoBran) {
    const string_view firstFeat = tree->feature_name_;
    const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
    if (tree->left_) {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] <= tree->threshold_) {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 计算该分支的平均值
        const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 &&
            tree->left_ == infoBran.tree_) {
            // 剪掉子树,即返回平均值
            tree->left_ = make_shared<BinTreeNode>();
            tree->left_->threshold_ = mean_val;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
    }
    if (tree->right_) {
        // 划分数据集
        vector<vector<double>> sub_X;
        vector<double> sub_y;
        for (int i = 0; i < X.size(); i++)
            if (X[i][labelIndex] > tree->threshold_) {
                sub_X.emplace_back(X[i]);
                sub_y.emplace_back(y[i]);
            }
        // 计算该分支的平均值
        const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
        // 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
        if (!infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 &&
            tree->right_ == infoBran.tree_) {
            // 剪掉子树,即返回平均值
            tree->right_ = make_shared<BinTreeNode>();
            tree->right_->threshold_ = mean_val;
            return tree;
        }
        // 如果不相同,继续向下寻找
        tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
    }
    return tree;
}

回归树的剪枝与分类树类似,不同点在于回归树计算误差使用的是均方差。

应用注意事项

  1. 代码中的 feature_name 仅作画图需要,实际生产如无该需求,可以去掉该变量;
  2. 代码未考虑缺失值的处理;
  3. 分类树和回归树中的 CCP 算法,仅在误差计算中有区别。分类树中可以使用基尼系数或误分类率(从效率层面,推荐使用基尼系数),回归树中使用均方差;
  4. 代码没有适配多线程场景;
  5. 其他可能的算法时空复杂度的优化。

目录

  1. 数据结构设计
  2. 二叉树设计
  3. 结点信息设计
  4. 分类树
  5. 训练
  6. 预测
  7. 剪枝
  8. 应用注意事项
  9. 回归树
  10. 训练
  11. 预测
  12. 剪枝
  13. 应用注意事项
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • 大语言模型(LLM)学习路径:从入门到实战指南
  • AIGC 视频生成成本优化实战:文字 + 图片输入下 20 秒与 30 秒模型选型与价格对比
  • 鸿蒙金融理财全栈项目:架构设计、数据安全与体验优化
  • STL 转 STEP 格式转换工具 stltostp 安装与使用
  • Python 中的 with 语句与 try 语句:资源管理对比
  • 基于 Miloco 的全屋智能家居 AI 自动化部署方案
  • OpenClaw 中文发行版部署指南:npm/Docker 多模式安装与配置
  • whisper-large-v3-turbo 模型一键部署指南
  • 修复 Anaconda 开始菜单快捷方式丢失及 mkmenus 报错
  • Python-100-Days:Python 百天从新手到大师学习路径
  • SLAM Toolbox 机器人智能建图技术详解
  • 无人机航测内业处理:iTwin Capture Modeler 建模与土方算量
  • Z 字形变换与外观数列算法实战解析
  • 高级 RAG 技术全解析:优化检索增强生成的最佳实践
  • Z-Image Turbo 本地部署与使用指南
  • VSCode Copilot 配置文件提示未知工具警告解析
  • 2026 年 2 月 AIGC 行业模型发布及前沿资讯
  • Claude Code 本地环境配置与 API 接入指南
  • 前端高频面试题:TypeScript 核心考点与实战
  • OpenClaw 配合 cpolar 实现公网远程访问教程

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • Base64 字符串编码/解码

    将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online