双重机器学习之因果推断 | CATE条件平均处理效应估计:五大方法原理详解与模拟数据实战(python版)

家人们我又更新了,代码和科研绘图在论文末尾,欢迎大家评论点赞和收藏,你们的认可是我坚持的动力,祝大家科研顺利。

因果推断 | CATE条件平均处理效应估计:五大方法原理详解与模拟数据实战

本文是因果推断系列文章。本篇聚焦 CATE(Conditional Average Treatment Effect,条件平均处理效应) 的估计,从ATE的局限性讲起,深入介绍S-Learner、T-Learner、X-Learner、因果森林DML和线性DML五种主流方法的原理,并在模拟数据上进行完整的代码实操与效果对比。

1 从ATE到CATE:为什么需要异质性处理效应?

1.1 ATE只能回答"平均有没有用"

ATE(Average Treatment Effect)回答的是:干预措施对整个群体的平均效果是什么

但在实际业务中,我们更想知道的是:对于不同的个体或子群,干预效果有什么不同

举几个例子:

  • 精准营销:给所有人发满减券ATE为正,但拆开看,高消费用户根本不需要券,低消费用户反而是增量用户——CATE帮你找到真正的增量人群。
  • 个性化医疗:某新药对年轻患者效果显著,对老年患者副作用大于疗效——只看ATE可能得出"有效"的结论,但对老年患者施加干预反而有害。
  • 政策评估:教育补贴对低收入家庭的效果远大于高收入家庭——CATE可以指导财政资源的精细化分配。

1.2 CATE的数学定义

CATE(条件平均处理效应)定义为:

其中:

  • Y(1):个体接受处理时的潜在结果
  • Y(0):个体未接受处理时的潜在结果
  • X:协变量/特征向量
  • τ(x):给定特征x条件下的处理效应

CATE是协变量X的函数,刻画了处理效应的异质性(Heterogeneity)。ATE只是CATE的期望:

1.3 估计CATE的核心难点

根本问题(Fundamental Problem of Causal Inference):对同一个个体,我们只能观察到一种潜在结果。要么看到Y(1),要么看到Y(0),不可能同时观察到两者。

因此,τ(x) = E[Y(1) - Y(0) | X=x] 不能直接通过数据计算,需要借助统计方法来估计。


2 五大CATE估计方法原理详解

下面介绍五种最常用的CATE估计方法,按照"从简单到复杂"的顺序排列。

2.1 S-Learner(Single Model Learner)

核心思想:把处理变量T当作一个普通特征,训练一个统一的模型

步骤

  1. 将T拼接到协变量X中,训练模型 μ̂(X, T) 拟合结果Y
  2. CATE估计:

优点:实现最简单,只需训练一个模型。

缺点:当使用正则化模型(如Lasso、随机森林)时,模型可能认为T的贡献不重要而将其效应缩小甚至忽略,导致CATE估计偏向于零(regularization bias)。

def s_learner(X, T, Y): XT = np.column_stack([X, T]) model = GradientBoostingRegressor(n_estimators=200, max_depth=4) model.fit(XT, Y) tau_hat = model.predict(np.c_[X, np.ones(len(X))]) \ - model.predict(np.c_[X, np.zeros(len(X))]) return tau_hat

2.2 T-Learner(Two Model Learner)

核心思想:处理组和控制组各自训练一个模型,然后做差。

步骤

  1. 用处理组数据训练 μ̂₁(x) 拟合 E[Y|X=x, T=1]
  2. 用控制组数据训练 μ̂₀(x) 拟合 E[Y|X=x, T=0]
  3. CATE估计:

优点:允许处理组和控制组的结果模型完全不同,灵活性高。

缺点:当处理组或控制组样本很少时,对应的模型估计不准确。两个独立模型的误差会叠加。

def t_learner(X, T, Y): model_1 = GradientBoostingRegressor().fit(X[T==1], Y[T==1]) model_0 = GradientBoostingRegressor().fit(X[T==0], Y[T==0]) tau_hat = model_1.predict(X) - model_0.predict(X) return tau_hat

2.3 X-Learner(Cross Learner)

核心思想:利用"反事实插补"的思想,用一个组的模型去预测另一个组的反事实结果,弥补T-Learner在样本不均衡时的不足。出自 Künzel et al. (2019) 的经典论文。

步骤(三阶段)

阶段1:分别拟合结果模型(同T-Learner),训练控制组模型 μ̂₀(x) 和处理组模型 μ̂₁(x)。

阶段2:计算伪处理效应(Imputed Treatment Effects)

对处理组个体(观测结果 - 预测的反事实):

对控制组个体(预测的反事实 - 观测结果):

阶段3:拟合CATE模型并加权

  • 处理组CATE模型 τ̂₁(x) 拟合 D̃¹ ~ X
  • 控制组CATE模型 τ̂₀(x) 拟合 D̃⁰ ~ X
  • 最终CATE:

其中e(x)是倾向性得分。直觉是:哪个组样本多,就更多地依赖那个组的CATE估计

优点:在处理组/控制组样本严重不均衡时表现远优于T-Learner。

2.4 因果森林DML(CausalForestDML)

核心思想:结合双重机器学习(DML)的去偏框架因果森林来捕捉非线性异质性效应。

步骤(两阶段)

第一阶段——去偏/正交化(Orthogonalization)

使用机器学习模型去除协变量对Y和T的影响:

关键操作是交叉拟合(Cross-Fitting):将数据分为K折,每折用其余K-1折训练的模型来预测当前折,避免过拟合偏差。

第二阶段——CATE估计

用因果森林(Generalized Random Forest)在残差Ỹ和T̃上估计异质性效应τ(X)。因果森林的分裂准则:最大化子节点间处理效应差异(而非预测误差)。

优点:能捕捉复杂的非线性CATE模式,同时DML去偏保证了估计的无偏性。

2.5 线性DML(LinearDML)

核心思想:与因果森林DML共享相同的第一阶段去偏过程,但第二阶段假设CATE是协变量的线性函数

优点

  • 可解释性极强:β的每个分量直接告诉你对应特征对CATE的边际效应
  • 提供统计推断:可以输出系数的置信区间和p值
  • 适合特征影响近似线性的场景

缺点:无法捕捉非线性异质性(如我们模拟数据中X₂²的效应)。

2.6 方法对比总结

方法模型数能否捕捉非线性CATE是否去偏是否有置信区间适用场景
S-Learner1快速基线
T-Learner2样本均衡
X-Learner4+部分样本不均衡
CausalForestDML通用首选
LinearDML需要可解释性

3 模拟数据设计(DGP)

为了公平评估各方法,我们设计了如下的数据生成过程(DGP):

3.1 变量设计

符号说明
X(5维)协变量,服从标准正态分布
T(0或1)二值处理变量
Y(连续)结果变量

3.2 数据生成公式

真实CATE函数

设计意图:

  • X₁:对CATE有线性影响
  • X₂:对CATE有非线性(二次)影响
  • X₃, X₄, X₅:对CATE无影响(噪声特征)

倾向性得分(处理分配机制)

X₁和X₃同时影响处理分配和结果(混淆变量),这使得朴素估计存在偏差。

基线结果函数

观测结果

3.3 数据生成代码

def generate_data(n=5000, seed=42): np.random.seed(seed) X = np.random.randn(n, 5) X1, X2, X3 = X[:, 0], X[:, 1], X[:, 2] # 真实CATE tau_true = X1 + 0.5 * X2**2 - 1.0 # 倾向性得分(存在混淆) e_x = 1.0 / (1.0 + np.exp(-(0.5*X1 + 0.3*X3))) T = np.random.binomial(1, e_x) # 基线结果 + 处理效应 + 噪声 Y = (2.0*X1 - X2 + 0.5*X3) + T * tau_true + np.random.randn(n)*0.5 return X, T, Y, tau_true

4 实验结果

4.1 评估指标

MSE(均方误差)

Bias(平均偏差)

指标含义越小/大越好
MSE均方误差越小越好
MAE平均绝对误差越小越好
Bias平均偏差越接近0越好
对真实CATE变异的解释比例越大越好
Corr与真实CATE的相关系数越大越好

4.2 各方法评估结果

以n=5000的模拟数据运行结果:

方法MSEMAEBiasCorr
S-Learner0.10030.21680.04690.93620.9788
T-Learner0.13110.26520.04740.91660.9593
X-Learner0.05160.15390.02260.96720.9840
DML-CausalForest0.13290.22200.06660.91550.9648
DML-Linear0.76580.76260.50350.51320.8220

4.3 结果分析

1. X-Learner表现最优(MSE=0.0516, R²=0.9672)

X-Learner通过反事实插补和倾向性加权,充分利用了两组数据的信息互补。在本实验中MSE最低、R²最高,是五种方法中表现最好的。

2. S-Learner表现稳健(MSE=0.1003, R²=0.9362)

S-Learner只用了一个模型,但GBT能自动捕捉T与X的交互效应,表现仅次于X-Learner。在快速验证场景下是很好的基线方法。

3. DML-CausalForest受限于简化实现

本代码中的DML-CausalForest是基于"残差比值+随机森林"的简化实现。完整的因果森林(如EconML的CausalForestDML)使用专门的因果分裂准则,效果会更好。如果安装了econml,建议直接使用CausalForestDML

4. DML-Linear的局限性(MSE=0.7658, R²=0.5132)

由于真实CATE包含X₂²项(非线性),LinearDML无法捕捉这部分异质性。但它成功识别出了X1的系数约为1.0(真实值为1.0),X3/X4/X5的系数接近0(符合真实DGP)。当你需要可解释的系数和置信区间时,LinearDML仍有不可替代的价值

5. T-Learner中规中矩

T-Learner独立训练两个模型,误差会叠加。在样本不均衡或效应较弱时劣势更明显。


5 可视化解读

5.1 CATE估计值 vs 真实值散点图

每个子图展示一种方法的估计值(纵轴)与真实CATE(横轴)的关系。理想情况下所有点应落在对角线上。

  • X-Learner的点最紧密地围绕对角线
  • LinearDML由于线性假设,在CATE极端值处偏离明显

【在此插入 fig1_cate_scatter.png】

5.2 CATE随X1变化趋势

真实CATE与X1呈线性关系(τ随X1增大而增大)。可以看到:

  • 所有方法都能捕捉到这个线性趋势
  • 差异主要体现在噪声大小和边缘区域的偏差

【在此插入 fig2_cate_by_x1.png】

5.3 CATE随X2变化趋势

真实CATE与X2呈U型关系(X₂²项)。这是对各方法的关键测试:

  • CausalForestDML和X-Learner能较好地还原U型
  • LinearDML只能拟合出线性趋势,完全丢失了非线性信息

【在此插入 fig2b_cate_by_x2.png】

5.4 误差箱线图

箱线图展示各方法估计误差的分布:

  • 中位线接近0说明无偏
  • 箱体越窄说明估计越精确
  • X-Learner的箱体最窄且最居中

【在此插入 fig4_error_boxplot.png】


6 实战建议与方法选择指南

6.1 该用哪种方法?

根据你的实际需求选择:

场景1:快速验证异质性是否存在 → 先用 S-Learner 或 T-Learner 作为基线。

场景2:样本不均衡(处理组远少于控制组) → 优先选择 X-Learner

场景3:需要无偏估计 + 非线性CATE → 使用 CausalForestDML(推荐首选)。

场景4:需要系数可解释性和置信区间 → 使用 LinearDML,可输出每个特征对CATE的边际效应。

场景5:高维稀疏特征 → 使用 SparseLinearDML(EconML提供),自动做特征选择。

6.2 注意事项

  1. SUTVA假设:个体间无干扰(Stable Unit Treatment Value Assumption)
  2. 无未观测混淆:所有影响T和Y的变量都被纳入X(Unconfoundedness)
  3. 正值假设:0 < e(X) < 1,每个个体都有可能被分配到任一组(Overlap / Positivity)
  4. 交叉拟合很重要:DML方法中的cross-fitting可以有效降低过拟合偏差
  5. 基学习器选择:第一阶段的ML模型不宜过于复杂,避免引入额外偏差

7 完整代码

完整代码文件 cate_simulation.py 包含:

  • 数据生成函数(可调节样本量、效应强度、混淆程度)
  • 5种CATE估计方法的实现
  • 评估指标计算
  • 6张可视化图表的生成

环境安装

pip install numpy pandas matplotlib scikit-learn scipy

注:本代码纯sklearn实现,无需econml。如需使用完整的CausalForestDML,额外安装 pip install econml

运行方式

python cate_simulation.py

运行后会在当前目录生成以下图表文件:

文件名内容
fig1_cate_scatter.png各方法 CATE 估计 vs 真实值散点图
fig2_cate_by_x1.pngCATE 随 X1 变化趋势对比
fig2b_cate_by_x2.pngCATE 随 X2 变化趋势对比
fig3_cate_dist.pngCATE 估计值分布直方图
fig4_error_boxplot.png估计误差箱线图
fig5_dml_linear_coef.pngLinearDML 系数(含置信区间)
fig6_true_cate_heatmap.png真实 CATE 在 X1-X2 平面的热力图

8 总结

要点说明
CATE是什么条件平均处理效应,刻画处理效应的异质性
为什么需要CATEATE是"一刀切",CATE支持个性化决策
核心难点反事实不可观测,需要统计方法估计
推荐方法CausalForestDML(通用首选)+ LinearDML(需要可解释性时)
关键假设无未观测混淆 + SUTVA + Positivity

下一篇我们将介绍DML在真实数据集上的应用实战,敬请期待!


参考文献

  1. Künzel S R, Sekhon J S, Bickel P J, et al. Metalearners for estimating heterogeneous treatment effects using machine learning[J]. PNAS, 2019.
  2. Chernozhukov V, Chetverikov D, Demirer M, et al. Double/debiased machine learning for treatment and structural parameters[J]. The Econometrics Journal, 2018.
  3. Athey S, Tibshirani J, Wager S. Generalized random forests[J]. The Annals of Statistics, 2019.
  4. EconML Documentation: https://econml.azurewebsites.net/
""" 因果推断 | CATE(条件平均处理效应)估计方法:模拟数据下的完整实战 ================================================================= 包含以下方法: 1. S-Learner 2. T-Learner 3. X-Learner 4. DML + 因果森林 5. DML + 线性模型 全部基于 numpy / sklearn 实现,无需 econml。 """ import numpy as np import pandas as pd import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib import rcParams from scipy.ndimage import uniform_filter1d import warnings warnings.filterwarnings('ignore') rcParams['font.sans-serif'] = ['DejaVu Sans'] rcParams['axes.unicode_minus'] = False plt.style.use('seaborn-v0_8-whitegrid') COLORS = ['#2196F3', '#FF5722', '#4CAF50', '#9C27B0', '#FF9800'] from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor from sklearn.linear_model import LogisticRegression, LinearRegression from sklearn.model_selection import KFold # ==================== DGP ==================== def generate_data(n=5000, seed=42): np.random.seed(seed) X = np.random.randn(n, 5) X1, X2, X3 = X[:,0], X[:,1], X[:,2] tau_true = X1 + 0.5*X2**2 - 1.0 logit_e = 0.5*X1 + 0.3*X3 e_x = 1.0/(1.0+np.exp(-logit_e)) T = np.random.binomial(1, e_x) mu_0 = 2.0*X1 - X2 + 0.5*X3 Y = mu_0 + T*tau_true + np.random.randn(n)*0.5 print("="*60) print(" Data Summary") print("="*60) print(f" n={n} dim={X.shape[1]} treat_rate={T.mean():.3f} ATE={tau_true.mean():.3f}") print(f" CATE range: [{tau_true.min():.2f}, {tau_true.max():.2f}]") print("="*60) return X, T, Y, tau_true # ==================== Methods ==================== def s_learner(X, T, Y): XT = np.column_stack([X, T]) m = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) m.fit(XT, Y) return m.predict(np.c_[X, np.ones(len(X))]) - m.predict(np.c_[X, np.zeros(len(X))]) def t_learner(X, T, Y): m1 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) m0 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) m1.fit(X[T==1], Y[T==1]); m0.fit(X[T==0], Y[T==0]) return m1.predict(X) - m0.predict(X) def x_learner(X, T, Y): m1 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) m0 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) m1.fit(X[T==1], Y[T==1]); m0.fit(X[T==0], Y[T==0]) D1 = Y[T==1] - m0.predict(X[T==1]) D0 = m1.predict(X[T==0]) - Y[T==0] tm1 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) tm0 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42) tm1.fit(X[T==1], D1); tm0.fit(X[T==0], D0) ps = LogisticRegression(random_state=42, max_iter=1000); ps.fit(X, T) e = ps.predict_proba(X)[:,1] return e*tm0.predict(X) + (1-e)*tm1.predict(X) def dml_cross_fit(X, T, Y, n_splits=5): kf = KFold(n_splits=n_splits, shuffle=True, random_state=42) Y_res, T_res = np.zeros(len(Y)), np.zeros(len(T), dtype=float) for tr, te in kf.split(X): my = GradientBoostingRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, random_state=42) mt = GradientBoostingRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, random_state=42) my.fit(X[tr], Y[tr]); mt.fit(X[tr], T[tr].astype(float)) Y_res[te] = Y[te] - my.predict(X[te]) T_res[te] = T[te] - mt.predict(X[te]) return Y_res, T_res def dml_causal_forest(X, T, Y): Y_res, T_res = dml_cross_fit(X, T, Y) T_clip = np.clip(np.abs(T_res), 0.01, None)*np.sign(T_res) T_clip[T_clip==0] = 0.01 pseudo = Y_res / T_clip q_lo, q_hi = np.percentile(pseudo, [2, 98]) pseudo = np.clip(pseudo, q_lo, q_hi) rf = RandomForestRegressor(n_estimators=300, max_depth=6, min_samples_leaf=20, random_state=42) rf.fit(X, pseudo) return rf.predict(X) def dml_linear(X, T, Y): Y_res, T_res = dml_cross_fit(X, T, Y) Z = T_res.reshape(-1,1) * X reg = LinearRegression(fit_intercept=True); reg.fit(Z, Y_res) tau_hat = X @ reg.coef_ + reg.intercept_ resid = Y_res - reg.predict(Z) n, p = Z.shape sigma2 = np.sum(resid**2)/(n-p-1) se = np.sqrt(np.diag(sigma2 * np.linalg.inv(Z.T@Z + 1e-8*np.eye(p)))) return tau_hat, reg.coef_, se # ==================== Evaluation ==================== def evaluate(tau_true, tau_hat, name): mse = np.mean((tau_true-tau_hat)**2) mae = np.mean(np.abs(tau_true-tau_hat)) bias = np.mean(tau_hat-tau_true) ss_res = np.sum((tau_true-tau_hat)**2) ss_tot = np.sum((tau_true-tau_true.mean())**2) r2 = 1 - ss_res/ss_tot corr = np.corrcoef(tau_true, tau_hat)[0,1] return {'Method':name, 'MSE':round(mse,4), 'MAE':round(mae,4), 'Bias':round(bias,4), 'R2':round(r2,4), 'Corr':round(corr,4)} # ==================== Plots ==================== def plot_scatter(tau_true, res, path='fig1_cate_scatter.png'): n = len(res) fig, axes = plt.subplots(1, n, figsize=(3.8*n, 4), dpi=130) if n==1: axes=[axes] for i,(name,th) in enumerate(res.items()): ax=axes[i] ax.scatter(tau_true, th, alpha=0.12, s=6, c=COLORS[i%5]) lims=[min(tau_true.min(),th.min())-0.5, max(tau_true.max(),th.max())+0.5] ax.plot(lims,lims,'k--',lw=1,alpha=0.5) c=np.corrcoef(tau_true,th)[0,1]; m=np.mean((tau_true-th)**2) ax.set_title(f'{name}\nCorr={c:.3f} MSE={m:.3f}',fontsize=10) ax.set_xlabel('True CATE',fontsize=9) if i==0: ax.set_ylabel('Estimated CATE',fontsize=9) plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close() print(f" [saved] {path}") def plot_by_feature(X, tau_true, res, fi=0, fn='X1', path='fig2_cate_by_x1.png'): fig,ax=plt.subplots(1,1,figsize=(10,6),dpi=130) si=np.argsort(X[:,fi]); xs=X[si,fi]; w=100 ax.scatter(xs,tau_true[si],alpha=0.06,s=4,c='gray',label='True CATE') ax.plot(xs,uniform_filter1d(tau_true[si],w),'k-',lw=2.5,label='True (smoothed)') for i,(name,th) in enumerate(res.items()): ax.plot(xs,uniform_filter1d(th[si],w),lw=2,alpha=0.85,color=COLORS[i%5],label=name) ax.set_xlabel(fn,fontsize=13); ax.set_ylabel('CATE',fontsize=13) ax.set_title(f'CATE vs {fn}',fontsize=14); ax.legend(fontsize=9,loc='upper left') plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close() print(f" [saved] {path}") def plot_dist(tau_true, res, path='fig3_cate_dist.png'): fig,ax=plt.subplots(1,1,figsize=(10,6),dpi=130) ax.hist(tau_true,bins=50,alpha=0.3,color='gray',density=True,label='True CATE') for i,(name,th) in enumerate(res.items()): ax.hist(th,bins=50,alpha=0.4,color=COLORS[i%5],density=True,label=name,histtype='step',lw=2) ax.set_xlabel('CATE',fontsize=13); ax.set_ylabel('Density',fontsize=13) ax.set_title('CATE Distribution',fontsize=14); ax.legend(fontsize=9) plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close() print(f" [saved] {path}") def plot_error_box(tau_true, res, path='fig4_error_boxplot.png'): fig,ax=plt.subplots(1,1,figsize=(10,6),dpi=130) errs=[th-tau_true for th in res.values()]; labs=list(res.keys()) bp=ax.boxplot(errs,labels=labs,patch_artist=True,showfliers=False) for p,c in zip(bp['boxes'],COLORS[:len(labs)]): p.set_facecolor(c); p.set_alpha(0.5) ax.axhline(0,color='red',ls='--',lw=1,alpha=0.7) ax.set_ylabel('Error',fontsize=12); ax.set_title('CATE Estimation Error',fontsize=14) plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close() print(f" [saved] {path}") def plot_coef(coef, se, fnames, path='fig5_dml_linear_coef.png'): fig,ax=plt.subplots(1,1,figsize=(8,5),dpi=130) y=np.arange(len(fnames)); ci=1.96*se ax.barh(y,coef,xerr=ci,color=COLORS[0],alpha=0.7,capsize=5) ax.set_yticks(y); ax.set_yticklabels(fnames,fontsize=12) ax.axvline(0,color='red',ls='--',lw=1) ax.set_xlabel('Coefficient',fontsize=12); ax.set_title('DML-Linear Coefficients (95% CI)',fontsize=13) for i,(c,s) in enumerate(zip(coef,se)): ax.text(c+ci[i]+0.02, i, f'{c:.3f}', va='center', fontsize=10) plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close() print(f" [saved] {path}") def plot_heatmap(tau_true, X, path='fig6_true_cate_heatmap.png'): fig,ax=plt.subplots(1,1,figsize=(8,6),dpi=130) sc=ax.scatter(X[:,0],X[:,1],c=tau_true,cmap='RdYlBu_r',alpha=0.4,s=5,vmin=-3,vmax=5) plt.colorbar(sc,ax=ax,label='True CATE') ax.set_xlabel('X1',fontsize=13); ax.set_ylabel('X2',fontsize=13) ax.set_title(r'True CATE: $\tau(X) = X_1 + 0.5 X_2^2 - 1$',fontsize=14) plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close() print(f" [saved] {path}") # ==================== Main ==================== def main(): print("\n"+"="*60) print(" CATE Estimation - Simulation Study") print("="*60+"\n") X, T, Y, tau_true = generate_data(n=5000, seed=42) fnames = ['X1','X2','X3','X4','X5'] res = {}; metrics = [] print("\n>>> [1/5] S-Learner...") res['S-Learner'] = s_learner(X,T,Y) metrics.append(evaluate(tau_true, res['S-Learner'], 'S-Learner')) print(">>> [2/5] T-Learner...") res['T-Learner'] = t_learner(X,T,Y) metrics.append(evaluate(tau_true, res['T-Learner'], 'T-Learner')) print(">>> [3/5] X-Learner...") res['X-Learner'] = x_learner(X,T,Y) metrics.append(evaluate(tau_true, res['X-Learner'], 'X-Learner')) print(">>> [4/5] DML-CausalForest...") res['DML-CausalForest'] = dml_causal_forest(X,T,Y) metrics.append(evaluate(tau_true, res['DML-CausalForest'], 'DML-CausalForest')) print(">>> [5/5] DML-Linear...") tau_l, coef_l, se_l = dml_linear(X,T,Y) res['DML-Linear'] = tau_l metrics.append(evaluate(tau_true, tau_l, 'DML-Linear')) # Results table print("\n"+"="*60) print(" Evaluation Results") print("="*60) df = pd.DataFrame(metrics) print(df.to_string(index=False)) # DML-Linear coefficients print("\n"+"="*60) print(" DML-Linear Coefficients (true: X1=1.0, X2=nonlinear, others=0)") print("="*60) for i,f in enumerate(fnames): lo=coef_l[i]-1.96*se_l[i]; hi=coef_l[i]+1.96*se_l[i] sig = "*" if lo>0 or hi<0 else "" print(f" {f:>4}: {coef_l[i]:>7.4f} SE={se_l[i]:.4f} 95%CI=[{lo:.4f}, {hi:.4f}] {sig}") # Plots print("\n>>> Generating plots...") plot_scatter(tau_true, res) plot_by_feature(X, tau_true, res, 0, 'X1', 'fig2_cate_by_x1.png') plot_by_feature(X, tau_true, res, 1, 'X2', 'fig2b_cate_by_x2.png') plot_dist(tau_true, res) plot_error_box(tau_true, res) plot_coef(coef_l, se_l, fnames) plot_heatmap(tau_true, X) print("\n"+"="*60) print(" All done!") print("="*60) return res, df if __name__ == '__main__': results, df_metrics = main()

如果觉得有帮助,请点赞+收藏+关注,你的支持是我持续输出的动力!

Read more

OpenClaw dashboard命令后,无法登录web控制面板(在systemd服务无法启动的一些虚拟机里会碰到)

OpenClaw dashboard命令后,无法登录web控制面板(在systemd服务无法启动的一些虚拟机里会碰到)

先上结论 执行OpenClaw dashboard命令后,无法登录web控制面板,是因为OpenClaw的gateway服务没有起来。原来小龙虾OpenClaw 的命令没有学明白,先弄清楚命令: openclaw onboard 是配置 openclaw dashboard是显示web控制面板登录信息 openclaw gateway --verbose 是启动网关 openclaw gateway start是启动网关服务 问题就是因为这台系统的systemd没有起作用,导致openclaw的gateway服务没有起来,所以控制面板无法登录。 OpenClaw status Overview ┌─────────────────┬───────────────────────────────────────────────────────────────────────────────────────────────────┐ │ Item │ Value │ ├─────────────────┼────────────────────────────────────

By Ne0inhk
SpringBoot源码解析(十):应用上下文AnnotationConfigServletWebServerApplicationContext构造方法

SpringBoot源码解析(十):应用上下文AnnotationConfigServletWebServerApplicationContext构造方法

SpringBoot源码系列文章 SpringBoot源码解析(一):SpringApplication构造方法 SpringBoot源码解析(二):引导上下文DefaultBootstrapContext SpringBoot源码解析(三):启动开始阶段 SpringBoot源码解析(四):解析应用参数args SpringBoot源码解析(五):准备应用环境 SpringBoot源码解析(六):打印Banner SpringBoot源码解析(七):应用上下文结构体系 SpringBoot源码解析(八):Bean工厂接口体系 SpringBoot源码解析(九):Bean定义接口体系 SpringBoot源码解析(十):应用上下文AnnotationConfigServletWebServerApplicationContext构造方法 目录 * 前言 * 源码入口 * 一、初始化注解Bean定义读取器 * 1、BeanDefinitionRegistry(Bean定义注册接口) * 2、获取环境对象Environment * 3、注

By Ne0inhk
YOLO可视化界面,目标检测前端QT页面。

YOLO可视化界面,目标检测前端QT页面。

使用PySide6/QT实现YOLOv8可视化GUI页面 在人工智能和计算机视觉领域,YOLO(You Only Look Once)是一种广泛使用的实时目标检测算法。为了直观地展示YOLO算法的检测效果,我们可以使用Python中的PySide6库来创建一个简单的GUI应用程序,将检测结果实时可视化。 本文将指导你如何使用PySide6实现这一功能。 1. 原视频/图片区:上半部分左边区域为原视频/图片展示区; 2. 检测区:上半部分右边区域为检测结果输出展示区; 3. 日志文本框:打印输出操作日志; 4. 加载模型:从本地选择模型pt文件进行加载; 5. 置信度阈值:自定义检测区的置信度阈值; 6. 文件上传:选择目标文件; 7. 开始检测:执行检测程序; 8. 停止:终止检测程序; 一、工具介绍 1、PySide6 PySide6是一款功能强大的GUI(图形用户界面)开发框架,它允许Python开发者使用Qt库的功能来构建跨平台的桌面应用程序。PySide6作为Qt的Python绑定版本,继承了Qt的跨平台特性,支持在Windows、

By Ne0inhk

Lottie-Web 完整技术指南:让动画开发更简单高效

📚 目录 * 一、什么是 Lottie-Web * 二、为什么选择 Lottie-Web * 三、安装与引入 * 四、基础使用 * 五、API 详解 * 六、Vue 集成实战 * 七、高级特性 * 八、性能优化 * 九、常见问题与解决方案 * 十、最佳实践 * 十一、实际应用场景 * 十二、总结 一、什么是 Lottie-Web 1.1 Lottie 简介 Lottie 是 Airbnb 开源的一个动画库,它可以将 After Effects 动画导出为 JSON 格式,然后在 Web、iOS、Android

By Ne0inhk