集成学习领域的经典算法,随机森林凭借泛化能力强、抗过拟合和易用性在数据挖掘和工业界广泛应用。本文从基础原理出发,结合 Python 实战落地,解析核心机制与调优技巧。
一、随机森林是什么?
随机森林(Random Forest)由 Leo Breiman 于 2001 年提出,核心思想是'多棵决策树协同工作'。通过对样本和特征的双重随机抽样,构建多棵独立的决策树,最终通过投票(分类任务)或平均(回归任务)得到结果。
关键定位:随机森林是"Bagging 集成 + 决策树"的经典组合,属于并行集成学习算法(各决策树独立训练,可并行计算)。
理解随机森林前,需回顾两个核心基础:
- 决策树:随机森林的基学习器,通过递归分裂特征构建树状结构,单棵树易过拟合、稳定性差;
- Bagging 集成:通过 bootstrap 抽样生成多个训练集,训练多棵基学习器,最后融合结果降低方差。
二、核心原理:双重随机性
随机森林的性能优势,根源在于其'双重随机性'设计——样本随机抽样和特征随机选择,这两个步骤从根本上降低了基学习器的相关性,提升了集成效果。

1. 样本随机(Bootstrap 抽样)
假设原始训练集有 N 个样本,构建每棵决策树时,都会从原始集中有放回地随机抽取 N 个样本作为该树的训练集。
- 袋外样本(OOB)的价值:由于是有放回抽样,约 37% 的样本不会被抽到,这部分样本称为'袋外样本'。它可作为免费的验证集,无需单独划分数据即可评估模型性能。在 sklearn 中,可通过设置
oob_score=True启用该功能,训练后通过rf_clf.oob_score_获取 OOB 准确率。 - 样本多样性保障:每棵树的训练集都是独立抽样生成的,避免了单一样本对模型的过度影响,让多棵树的预测更具差异性。
2. 特征随机选择
单棵决策树分裂时,先从全部 M 个特征中随机选择 k 个特征(k<M),再从这 k 个特征中选择最优分裂点。这是区别于普通 Bagging 集成的关键。
- k 值的科学选择:分类任务默认取√M(sklearn 中
max_features="sqrt"),回归任务默认取 M/3(max_features="auto")。实际调优时,可在 [√M, M/2] 区间测试。 - 打破强特征垄断:若数据中存在强特征,普通决策树会反复使用导致树间高度相似。特征随机迫使树探索其他特征的组合价值,提升树群的多样性。
3. 结果融合
所有决策树训练完成后,通过'少数服从多数'(分类)或'均值平均'(回归)得到最终结果。
三、Python 实战:分类与回归
下面用 sklearn 库实现随机森林的分类(鸢尾花数据集)和回归(加州房价数据集)任务。
1. 环境准备
pip install scikit-learn pandas numpy matplotlib
2. 随机森林分类(鸢尾花)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
sklearn.datasets load_iris
sklearn.ensemble RandomForestClassifier
sklearn.model_selection train_test_split
sklearn.metrics accuracy_score, confusion_matrix, classification_report
iris = load_iris()
X = iris.data
y = iris.target
(, X.shape, y.shape)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=, random_state=, stratify=y
)
rf_clf = RandomForestClassifier(
n_estimators=,
max_depth=,
min_samples_split=,
random_state=
)
rf_clf.fit(X_train, y_train)
y_pred = rf_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
()
feature_importance = pd.DataFrame({
: iris.feature_names,
: rf_clf.feature_importances_
}).sort_values(by=, ascending=)
plt.figure(figsize=(, ))
plt.barh(feature_importance[], feature_importance[], color=)
plt.xlabel()
plt.title()
plt.show()

