网格搜索参数优化的 XGBoost+SHAP 回归预测及可视化分析
网格搜索结合 XGBoost 模型进行回归预测,利用 SHAP 库实现模型可解释性分析。文章涵盖特征相关性热图、散点密度图等可视化方法,展示超参数优化过程及 RMSE 评估指标。提供完整 Python 代码流程,包括数据读取、归一化、训练、预测、评估及结果保存(MAT 文件),适用于科研论文中的模型物理解释需求。

网格搜索结合 XGBoost 模型进行回归预测,利用 SHAP 库实现模型可解释性分析。文章涵盖特征相关性热图、散点密度图等可视化方法,展示超参数优化过程及 RMSE 评估指标。提供完整 Python 代码流程,包括数据读取、归一化、训练、预测、评估及结果保存(MAT 文件),适用于科研论文中的模型物理解释需求。

一句话概括:
'网格搜索就是把所有设定好的超参数组合排成一个'网格',逐个尝试,通过评估结果找到表现最佳的那一组参数。' 就像在一个二维或多维坐标空间里,把所有候选参数都排列出来,然后把每个点都跑一遍,最终选出模型表现最优的位置。
网格搜索的理念非常直观:
网格搜索的价值主要体现在几个方面:
你可以完全决定参数候选集,调参过程完全透明。
特别适合探索学习率、树深、正则项等关键参数的小步长变化。
与 Cross-Validation 结合后,能够获得稳定、可靠的参数评估结果。
每个组合都被尝试过,调参过程完整记录,适合科研工作。
网格搜索广泛应用于:
在你的任务里,网格搜索非常适合用于关键参数的局部精调,确保模型在最佳点附近充分探索。
该图展示 GridSearchCV 调参过程中各超参数与 RMSE 的相关性重要性,其中 learning_rate、reg_alpha 和 n_estimators 影响最明显,可用于识别关键参数并指导后续调参方向。

上述三条目录的基本原理已在前置推文中做过详细介绍,需要学习了解的请查阅相关文档。
本程序 SHAP 带的图包括:

这些图都是发论文神器。论文价值:可解释性直接提升一档。 SCI 论文里 reviewer 最爱问:
特征值相关性热图用于展示各特征之间的相关强弱,通过颜色深浅体现正负相关关系,帮助快速识别冗余特征、强相关特征及可能影响模型稳定性的变量,为后续特征选择和建模提供参考。

散点密度图通过颜色或亮度反映点的聚集程度,用于展示大量样本的分布特征。相比普通散点图,它能更直观地呈现高密度区域、异常点及整体趋势,常用于回归分析与模型评估。以下为训练集和测试集出图效果。

我的代码程序中将参数最优值输出到当前目录的 best_params.txt 文本中,
并将训练集和测试集的精度评估指标保存到 metrics.mat 矩阵中。共两行,第一行代表训练集的,第二行代表测试集的;共 7 个精度评估指标,分别代表 R, R2, ME, MAE, MAPE, RMSE 以及样本数量。
保存的 regression_result.mat 数据中分别保存了名字为 Y_train、y_pred_train、y_test、y_pred_test 的矩阵向量。
同样的针对大家各自的数据训练出的模型结构也保存在 model.json 中,方便再一次调用。
调用的程序我在程序中注释了,如下
# 加载模型
# model.load_model("model.json")
主程序如下,其中从 1-10,每一步都有详细的注释,要获取完整程序,请转下文代码获取
# ========================================================= # 主程序 # ========================================================= def main(): print("=== 1. 读取数据 ===") data = pd.read_excel("data.xlsx") X = data.iloc[:, :10].values y = data.iloc[:, 10].values feature_names = list(data.columns[:10]) print("=== 2. 划分训练与测试 ===") X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42 ) print("=== 3. 归一化 ===") scaler_X = MinMaxScaler() scaler_y = MinMaxScaler() X_train_norm = scaler_X.fit_transform(X_train) X_test_norm = scaler_X.transform(X_test) y_train_norm = scaler_y.fit_transform(y_train.reshape(-1, 1)).ravel() print("=== 4. 模型训练 ===") model = train_model(X_train_norm, y_train_norm) print("=== 5. 预测(反归一化到原始尺度) ===") y_pred_train_norm = model.predict(X_train_norm) y_pred_test_norm = model.predict(X_test_norm) y_pred_train = scaler_y.inverse_transform( y_pred_train_norm.reshape(-1, 1) ).ravel() y_pred_test = scaler_y.inverse_transform( y_pred_test_norm.reshape(-1, 1) ).ravel() print("=== 6. 模型评估 ===") metrics_train = evaluate_model(y_train, y_pred_train) metrics_test = evaluate_model(y_test, y_pred_test) print("\n训练集评估指标:") for k, v in metrics_train.items(): print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}") print("\n测试集评估指标:") for k, v in metrics_test.items(): print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}") print("=== 7. 保存结果到 MAT 文件 ===") result_dict = { "y_train": y_train.astype(float), "y_pred_train": y_pred_train.astype(float), "y_test": y_test.astype(float), "y_pred_test": y_pred_test.astype(float), } savemat("regression_result.mat", result_dict) print("已保存 regression_result.mat") # 按指标顺序排列 metrics_matrix = np.array([ [metrics_train['R'], metrics_test['R']], [metrics_train['R2'], metrics_test['R2']], [metrics_train['ME'], metrics_test['ME']], [metrics_train['MAE'], metrics_test['MAE']], [metrics_train['MAPE'], metrics_test['MAPE']], [metrics_train['RMSE'], metrics_test['RMSE']], [metrics_train['样本数'], metrics_test['样本数']] ], dtype=float) savemat("metrics.mat", {"metrics": metrics_matrix}) print("已保存 metrics.mat(矩阵大小 7×2)") print("=== 8. SHAP 分析 ===") X_combined = np.vstack([X_train_norm, X_test_norm]) X_df = pd.DataFrame(X_combined, columns=feature_names) # shap_results = shap_analysis(model, X_combined, feature_names) plot_shap_dependence(model, X_combined, feature_names, X_df) print("=== 9. 密度散点图 ===") plot_density_scatter( y_test, y_pred_test, save_path="scatter_density_test.png" ) plot_density_scatter( y_train, y_pred_train, save_path="scatter_density_train.png" ) print("=== 10. 相关性热图 ===") correlation_heatmap(data, feature_names) print("=== 完成!===") if __name__ == "__main__": main()

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online