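I. Core Principles of Decision Trees: In Depth
1.1 Information Gain vs. Gini Index: Why Does CART Use the Gini Index?
Key question: ID3 splits on information gain and CART splits on the Gini index. Which is the better choice?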
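| Aspect | Information gain (ID3) | Gini index (CART) |
|---|---|---|
| Computational cost | Requires logarithms (somewhat more expensive) | Only squares (cheaper to compute) |
| Split behavior | Higher gain → larger purity improvement (but biased toward features with many distinct values) | Lower weighted Gini → purer child nodes (CART makes binary splits, so continuous features are handled by thresholding) |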
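| Formula | $Gain(S,A) = Ent(S) - \sum_{v} \frac{\lvert S_v \rvert}{\lvert S \rvert} Ent(S_v)$, with $Ent(S) = -\sum_k p_k \log_2 p_k$ | $Gini(S) = 1 - \sum_k p_k^2$; split quality $Gini_{split} = \sum_{v} \frac{\lvert S_v \rvert}{\lvert S \rvert} Gini(S_v)$ |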
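| Iris example | Split that isolates one class completely (computed below): Gain = 0.918 → chosen as the root | Same split: $Gini_{split} = 0.333$ (impurity decrease 0.333) → chosen as the root |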
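The table's claim that information gain favors features with many distinct values can be checked on a tiny example. The sketch below uses hypothetical toy data and hypothetical helper names (`entropy`, `info_gain`, not from any library) to compare a useless per-row ID feature with a two-valued feature under an ID3-style multiway split.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def info_gain(feature, labels):
    """ID3-style information gain of a multiway split on a categorical feature."""
    n = len(labels)
    groups = {}
    for value, label in zip(feature, labels):
        groups.setdefault(value, []).append(label)
    # Weighted entropy of the child nodes, one child per distinct feature value.
    remainder = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - remainder


# Hypothetical toy data: 6 samples, 3 balanced classes.
labels = ["a", "a", "b", "b", "c", "c"]
sample_id = [0, 1, 2, 3, 4, 5]          # unique per row: no predictive value
coarse = ["x", "x", "x", "y", "y", "y"]  # a genuine 2-valued feature

print(round(info_gain(sample_id, labels), 3))  # 1.585
print(round(info_gain(coarse, labels), 3))     # 0.667
```

The ID column reaches the maximum possible gain, $Ent(S) = \log_2 3$, because every child node holds a single sample and is trivially pure, even though the feature tells us nothing about unseen samples.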
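Why does CART choose the Gini index?
Take the iris dataset (150 samples, 50 per class) and the split that separates setosa from the other two classes, e.g. petal length ≤ 2.45 cm. The child $S_{\le 2.45}$ holds the 50 setosa samples only; the child $S_{>2.45}$ holds 50 versicolor and 50 virginica samples. Information gain:
$Ent(S) = -3 \times \tfrac{1}{3}\log_2 \tfrac{1}{3} = \log_2 3 \approx 1.585$
$Ent(S_{\le 2.45}) = 0,\quad Ent(S_{>2.45}) = 1.0$
$Gain = 1.585 - \frac{50}{150}\times 0 - \frac{100}{150}\times 1.0 \approx 0.918$
Gini index:
$Gini(S) = 1 - 3 \times \left(\tfrac{1}{3}\right)^2 \approx 0.667$
$Gini(S_{\le 2.45}) = 0,\quad Gini(S_{>2.45}) = 1 - 2 \times 0.5^2 = 0.5$
$Gini_{split} = \frac{50}{150}\times 0 + \frac{100}{150}\times 0.5 \approx 0.333$
$\Delta Gini = Gini(S) - Gini_{split} = \tfrac{2}{3} - \tfrac{1}{3} \approx 0.333$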
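Conclusion: the Gini index avoids logarithms and is therefore cheaper to compute, and it ranks candidate splits in the same direction as information gain (a high Gain corresponds to a low $Gini_{split}$), so in practice the two criteria usually select the same split.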
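The hand calculation above can be reproduced in a few lines. This is a minimal sketch that assumes only the class counts from the example (50 samples per class; one child holds a single class, the other is split 50/50); the helper functions are hypothetical and not part of scikit-learn.

```python
import math


def entropy(counts):
    """Base-2 entropy from a list of per-class sample counts."""
    n = sum(counts)
    return -sum(c / n * math.log2(c / n) for c in counts if c > 0)


def gini(counts):
    """Gini impurity 1 - sum(p_k^2) from per-class sample counts."""
    n = sum(counts)
    return 1 - sum((c / n) ** 2 for c in counts)


def weighted(impurity, children):
    """Sample-weighted impurity of the child nodes after a split."""
    n = sum(sum(c) for c in children)
    return sum(sum(c) / n * impurity(c) for c in children)


parent = [50, 50, 50]                 # setosa, versicolor, virginica
children = [[50, 0, 0], [0, 50, 50]]  # one pure child, one 50/50 child

gain = entropy(parent) - weighted(entropy, children)
gini_split = weighted(gini, children)
delta_gini = gini(parent) - gini_split

print(f"Ent(S) = {entropy(parent):.3f}, Gain = {gain:.3f}")
# Ent(S) = 1.585, Gain = 0.918
print(f"Gini(S) = {gini(parent):.3f}, Gini_split = {gini_split:.3f}, delta_Gini = {delta_gini:.3f}")
# Gini(S) = 0.667, Gini_split = 0.333, delta_Gini = 0.333
```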
II. Python Implementation: Code Walkthrough
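```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
import matplotlib.pyplot as plt

# NOTE: the numeric and string literals of this listing were lost in the source text.
# The values below (test_size, random_state, cv, print texts, figure settings,
# file name) are assumed placeholders.

# ========== 1. Data loading and preprocessing ==========
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names
target_names = iris.target_names
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y  # stratify keeps class ratios equal
)

# ========== 2. Baseline tree without pruning ==========
# Assumed values: scikit-learn's defaults, i.e. a fully grown tree.
dt_default = DecisionTreeClassifier(
    random_state=42,
    max_depth=None,
    min_samples_split=2,
    min_samples_leaf=1,
)
dt_default.fit(X_train, y_train)
print("=== Default tree ===")
print(f"Train accuracy: {accuracy_score(y_train, dt_default.predict(X_train)):.3f}")
print(f"Test accuracy:  {accuracy_score(y_test, dt_default.predict(X_test)):.3f}")
print(classification_report(y_test, dt_default.predict(X_test), target_names=target_names))

# ========== 3. Cost-complexity pruning path ==========
# The path is computed with dt_default's settings; keep them unconstrained so the
# alphas correspond to the unconstrained trees trained below.
path = dt_default.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
clfs = []
for alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

# ========== 4. Choose alpha by cross-validation on the training set ==========
cv_scores = []
for clf in clfs:
    scores = cross_val_score(clf, X_train, y_train, cv=5)
    cv_scores.append(np.mean(scores))
best_alpha = ccp_alphas[np.argmax(cv_scores)]
print(f"Best ccp_alpha: {best_alpha:.4f}")

# ========== 5. Retrain with the best alpha and compare ==========
dt_pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha)
dt_pruned.fit(X_train, y_train)
print("=== Pruned tree ===")
print(f"Train accuracy: {accuracy_score(y_train, dt_pruned.predict(X_train)):.3f}")
print(f"Test accuracy:  {accuracy_score(y_test, dt_pruned.predict(X_test)):.3f}")
print(f"Depth:  {dt_default.get_depth()} -> {dt_pruned.get_depth()}")
print(f"Leaves: {dt_default.get_n_leaves()} -> {dt_pruned.get_n_leaves()}")

# ========== 6. Plot both trees side by side ==========
plt.figure(figsize=(20, 8))
plt.subplot(1, 2, 1)
plot_tree(dt_default, feature_names=feature_names, class_names=target_names, filled=True)
plt.title("Default tree")
plt.subplot(1, 2, 2)
plot_tree(dt_pruned, feature_names=feature_names, class_names=target_names, filled=True)
plt.title(f"Pruned tree (ccp_alpha={best_alpha:.4f})")
plt.tight_layout()
plt.savefig("decision_tree_comparison.png", dpi=150)  # hypothetical file name
plt.show()
```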
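The manual loop over `ccp_alphas` can also be written with scikit-learn's `GridSearchCV`, which runs the same k-fold cross-validation over the candidate alphas and, by default, refits the best tree on the whole training set. For context, scikit-learn's minimal cost-complexity pruning picks the subtree $T$ that minimizes $R_\alpha(T) = R(T) + \alpha \lvert \widetilde{T} \rvert$, where $R(T)$ is the total sample-weighted impurity of the leaves and $\lvert \widetilde{T} \rvert$ is the number of leaves; `ccp_alpha` sets $\alpha$. This sketch assumes the same (themselves assumed) settings as the listing, `test_size=0.3`, `random_state=42` and `cv=5`, and drops the last alpha because it prunes the tree down to a single node.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# Assumed split settings, mirroring the placeholders in the main listing.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Candidate alphas from the pruning path of a fully grown tree on the training set.
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas[:-1]  # the last alpha leaves only the root node

search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={"ccp_alpha": list(alphas)},
    cv=5,
)
search.fit(X_train, y_train)

best_tree = search.best_estimator_
print("best ccp_alpha:", search.best_params_["ccp_alpha"])
print("test accuracy:", best_tree.score(X_test, y_test))
print("leaves:", best_tree.get_n_leaves())
```

With `refit=True` (the default), `best_estimator_` is already trained on all of `X_train`, so no separate retraining step is needed.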

