""" 因果推断 | CATE(条件平均处理效应)估计方法:模拟数据下的完整实战 ================================================================= 包含以下方法: 1. S-Learner 2. T-Learner 3. X-Learner 4. DML + 因果森林 5. DML + 线性模型 全部基于 numpy / sklearn 实现,无需 econml。 """
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rcParams
from scipy.ndimage import uniform_filter1d
import warnings
warnings.filterwarnings('ignore')
rcParams['font.sans-serif'] = ['DejaVu Sans']
rcParams['axes.unicode_minus'] = False
plt.style.use('seaborn-v0_8-whitegrid')
COLORS = ['#2196F3', '#FF5722', '#4CAF50', '#9C27B0', '#FF9800']
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import KFold
def generate_data(n=5000, seed=42):
    """Simulate covariates, treatment and outcome with a known CATE.

    The ground-truth effect is tau(x) = x1 + 0.5*x2^2 - 1, treatment is
    assigned through a logistic propensity in (x1, x3) — i.e. confounded —
    and the untreated outcome is linear in (x1, x2, x3) plus N(0, 0.5^2)
    noise.

    Returns (X, T, Y, tau_true) as numpy arrays.
    """
    np.random.seed(seed)
    covariates = np.random.randn(n, 5)
    x1, x2, x3 = (covariates[:, j] for j in range(3))
    # Heterogeneous treatment effect — the quantity every estimator targets.
    tau_true = x1 + 0.5 * x2 ** 2 - 1.0
    # Confounded assignment: propensity depends on x1 and x3.
    propensity = 1.0 / (1.0 + np.exp(-(0.5 * x1 + 0.3 * x3)))
    T = np.random.binomial(1, propensity)
    baseline = 2.0 * x1 - x2 + 0.5 * x3
    Y = baseline + T * tau_true + np.random.randn(n) * 0.5
    banner = "=" * 60
    print(banner)
    print(" Data Summary")
    print(banner)
    print(f" n={n} dim={covariates.shape[1]} treat_rate={T.mean():.3f} ATE={tau_true.mean():.3f}")
    print(f" CATE range: [{tau_true.min():.2f}, {tau_true.max():.2f}]")
    print(banner)
    return covariates, T, Y, tau_true
def s_learner(X, T, Y):
    """S-Learner: one model on [X, T]; CATE is the T=1 minus T=0 prediction."""
    model = GradientBoostingRegressor(n_estimators=200, max_depth=4,
                                      learning_rate=0.1, random_state=42)
    model.fit(np.column_stack([X, T]), Y)
    ones = np.ones(len(X))
    pred_treated = model.predict(np.column_stack([X, ones]))
    pred_control = model.predict(np.column_stack([X, np.zeros(len(X))]))
    return pred_treated - pred_control
def t_learner(X, T, Y):
    """T-Learner: separate outcome models per arm; CATE is their difference."""
    def _gbr():
        return GradientBoostingRegressor(n_estimators=200, max_depth=4,
                                         learning_rate=0.1, random_state=42)
    treated, control = T == 1, T == 0
    model_t = _gbr().fit(X[treated], Y[treated])
    model_c = _gbr().fit(X[control], Y[control])
    return model_t.predict(X) - model_c.predict(X)
def x_learner(X, T, Y):
    """X-Learner: imputed individual effects blended by the propensity score.

    Stage 1 fits arm-specific outcome models; stage 2 regresses the imputed
    pseudo-effects on X within each arm; the final estimate weights the two
    stage-2 models by the estimated propensity e(x).
    """
    def _gbr():
        return GradientBoostingRegressor(n_estimators=200, max_depth=4,
                                         learning_rate=0.1, random_state=42)
    treated, control = T == 1, T == 0
    mu1 = _gbr().fit(X[treated], Y[treated])
    mu0 = _gbr().fit(X[control], Y[control])
    # Imputed pseudo-effects: observed outcome minus counterfactual prediction.
    d_treated = Y[treated] - mu0.predict(X[treated])
    d_control = mu1.predict(X[control]) - Y[control]
    tau1 = _gbr().fit(X[treated], d_treated)
    tau0 = _gbr().fit(X[control], d_control)
    # Weight each stage-2 model by the probability of the *other* arm.
    prop = LogisticRegression(random_state=42, max_iter=1000).fit(X, T)
    e = prop.predict_proba(X)[:, 1]
    return e * tau0.predict(X) + (1 - e) * tau1.predict(X)
def dml_cross_fit(X, T, Y, n_splits=5):
    """Cross-fitted residualization for DML.

    For each of ``n_splits`` folds, nuisance models for E[Y|X] and E[T|X]
    are trained out-of-fold and used to residualize the held-out fold.

    Returns (Y_res, T_res): the residual arrays Y - E[Y|X] and T - E[T|X].
    """
    folds = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    y_resid = np.zeros(len(Y))
    t_resid = np.zeros(len(T), dtype=float)
    for train_idx, test_idx in folds.split(X):
        outcome_model = GradientBoostingRegressor(n_estimators=100, max_depth=3,
                                                  learning_rate=0.1, random_state=42)
        treat_model = GradientBoostingRegressor(n_estimators=100, max_depth=3,
                                                learning_rate=0.1, random_state=42)
        outcome_model.fit(X[train_idx], Y[train_idx])
        treat_model.fit(X[train_idx], T[train_idx].astype(float))
        y_resid[test_idx] = Y[test_idx] - outcome_model.predict(X[test_idx])
        t_resid[test_idx] = T[test_idx] - treat_model.predict(X[test_idx])
    return y_resid, t_resid
def dml_causal_forest(X, T, Y):
    """DML pseudo-outcome regression with a random forest ("causal forest"-like).

    Forms pseudo-effects Y_res / T_res — with the treatment residual clipped
    away from zero and the pseudo-outcome winsorized at the 2nd/98th
    percentiles to tame outliers — then fits a regression forest to them.
    """
    y_resid, t_resid = dml_cross_fit(X, T, Y)
    # Keep |T residual| >= 0.01 while preserving its sign; exact zeros
    # (sign == 0) would survive the clip, so patch them to +0.01.
    denom = np.clip(np.abs(t_resid), 0.01, None) * np.sign(t_resid)
    denom[denom == 0] = 0.01
    pseudo = y_resid / denom
    lo, hi = np.percentile(pseudo, [2, 98])
    pseudo = np.clip(pseudo, lo, hi)
    forest = RandomForestRegressor(n_estimators=300, max_depth=6,
                                   min_samples_leaf=20, random_state=42)
    forest.fit(X, pseudo)
    return forest.predict(X)
def dml_linear(X, T, Y):
    """DML with a linear CATE model tau(x) = theta0 + theta' x.

    After cross-fitted residualization the partially linear model implies
        Y_res = tau(X) * T_res + eps,
    so we regress Y_res on [T_res, T_res * X] *without* an intercept: the
    coefficient on the plain T_res column is the constant effect theta0 and
    the remaining coefficients are the feature loadings theta.

    BUGFIX: the previous version regressed on T_res * X alone with a free
    intercept. That intercept estimates E[Y_res] (which is ~0 after
    residualization), NOT theta0, so any constant component of the CATE
    (here the true -1.0) was silently dropped, biasing tau_hat.

    Returns:
        tau_hat: (n,) estimated CATE per unit.
        theta:   (d,) feature coefficients (one per column of X).
        se:      (d,) homoskedastic OLS standard errors for theta.
    """
    Y_res, T_res = dml_cross_fit(X, T, Y)
    # Design matrix: constant-effect column first, then effect modifiers.
    Z = np.column_stack([T_res, T_res.reshape(-1, 1) * X])
    reg = LinearRegression(fit_intercept=False)
    reg.fit(Z, Y_res)
    theta0, theta = reg.coef_[0], reg.coef_[1:]
    tau_hat = theta0 + X @ theta
    # Homoskedastic OLS covariance; the 1e-8 ridge keeps Z'Z invertible.
    resid = Y_res - reg.predict(Z)
    n, p = Z.shape
    sigma2 = np.sum(resid ** 2) / (n - p)
    cov = sigma2 * np.linalg.inv(Z.T @ Z + 1e-8 * np.eye(p))
    se_all = np.sqrt(np.diag(cov))
    # Interface preserved: callers expect one coefficient/SE per feature.
    return tau_hat, theta, se_all[1:]
def evaluate(tau_true, tau_hat, name):
    """Score a CATE estimate against the ground truth.

    Returns a dict with the method name plus MSE, MAE, mean bias, R^2 and
    Pearson correlation, each rounded to 4 decimals.
    """
    err = tau_hat - tau_true
    mse = np.mean(err ** 2)
    mae = np.mean(np.abs(err))
    bias = np.mean(err)
    ss_tot = np.sum((tau_true - tau_true.mean()) ** 2)
    r2 = 1 - np.sum(err ** 2) / ss_tot
    corr = np.corrcoef(tau_true, tau_hat)[0, 1]
    return {
        'Method': name,
        'MSE': round(mse, 4),
        'MAE': round(mae, 4),
        'Bias': round(bias, 4),
        'R2': round(r2, 4),
        'Corr': round(corr, 4),
    }
def plot_scatter(tau_true, res, path='fig1_cate_scatter.png'):
    """One true-vs-estimated scatter panel per method, with a y=x reference line."""
    n_methods = len(res)
    fig, axes = plt.subplots(1, n_methods, figsize=(3.8 * n_methods, 4), dpi=130)
    if n_methods == 1:
        axes = [axes]
    for idx, (name, est) in enumerate(res.items()):
        ax = axes[idx]
        ax.scatter(tau_true, est, alpha=0.12, s=6, c=COLORS[idx % 5])
        lo = min(tau_true.min(), est.min()) - 0.5
        hi = max(tau_true.max(), est.max()) + 0.5
        ax.plot([lo, hi], [lo, hi], 'k--', lw=1, alpha=0.5)
        corr = np.corrcoef(tau_true, est)[0, 1]
        mse = np.mean((tau_true - est) ** 2)
        ax.set_title(f'{name}\nCorr={corr:.3f} MSE={mse:.3f}', fontsize=10)
        ax.set_xlabel('True CATE', fontsize=9)
        if idx == 0:
            ax.set_ylabel('Estimated CATE', fontsize=9)
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight')
    plt.close()
    print(f" [saved] {path}")
def plot_by_feature(X, tau_true, res, fi=0, fn='X1', path='fig2_cate_by_x1.png'):
    """CATE vs one covariate: raw truth plus 100-point moving averages per method."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), dpi=130)
    order = np.argsort(X[:, fi])
    xs = X[order, fi]
    window = 100
    ax.scatter(xs, tau_true[order], alpha=0.06, s=4, c='gray', label='True CATE')
    ax.plot(xs, uniform_filter1d(tau_true[order], window), 'k-', lw=2.5,
            label='True (smoothed)')
    for idx, (name, est) in enumerate(res.items()):
        ax.plot(xs, uniform_filter1d(est[order], window), lw=2, alpha=0.85,
                color=COLORS[idx % 5], label=name)
    ax.set_xlabel(fn, fontsize=13)
    ax.set_ylabel('CATE', fontsize=13)
    ax.set_title(f'CATE vs {fn}', fontsize=14)
    ax.legend(fontsize=9, loc='upper left')
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight')
    plt.close()
    print(f" [saved] {path}")
def plot_dist(tau_true, res, path='fig3_cate_dist.png'):
    """Overlay the true CATE histogram with each method's estimate distribution."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), dpi=130)
    ax.hist(tau_true, bins=50, alpha=0.3, color='gray', density=True,
            label='True CATE')
    for idx, (name, est) in enumerate(res.items()):
        ax.hist(est, bins=50, alpha=0.4, color=COLORS[idx % 5], density=True,
                label=name, histtype='step', lw=2)
    ax.set_xlabel('CATE', fontsize=13)
    ax.set_ylabel('Density', fontsize=13)
    ax.set_title('CATE Distribution', fontsize=14)
    ax.legend(fontsize=9)
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight')
    plt.close()
    print(f" [saved] {path}")
def plot_error_box(tau_true, res, path='fig4_error_boxplot.png'):
    """Box plot of the estimation errors (estimate - truth) for each method."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), dpi=130)
    labels = list(res.keys())
    errors = [est - tau_true for est in res.values()]
    boxes = ax.boxplot(errors, labels=labels, patch_artist=True, showfliers=False)
    for patch, color in zip(boxes['boxes'], COLORS[:len(labels)]):
        patch.set_facecolor(color)
        patch.set_alpha(0.5)
    ax.axhline(0, color='red', ls='--', lw=1, alpha=0.7)
    ax.set_ylabel('Error', fontsize=12)
    ax.set_title('CATE Estimation Error', fontsize=14)
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight')
    plt.close()
    print(f" [saved] {path}")
def plot_coef(coef, se, fnames, path='fig5_dml_linear_coef.png'):
    """Horizontal bar chart of DML-linear coefficients with 95% confidence bars."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 5), dpi=130)
    pos = np.arange(len(fnames))
    half_width = 1.96 * se
    ax.barh(pos, coef, xerr=half_width, color=COLORS[0], alpha=0.7, capsize=5)
    ax.set_yticks(pos)
    ax.set_yticklabels(fnames, fontsize=12)
    ax.axvline(0, color='red', ls='--', lw=1)
    ax.set_xlabel('Coefficient', fontsize=12)
    ax.set_title('DML-Linear Coefficients (95% CI)', fontsize=13)
    # Annotate each bar just beyond the right end of its error bar.
    for idx, value in enumerate(coef):
        ax.text(value + half_width[idx] + 0.02, idx, f'{value:.3f}',
                va='center', fontsize=10)
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight')
    plt.close()
    print(f" [saved] {path}")
def plot_heatmap(tau_true, X, path='fig6_true_cate_heatmap.png'):
    """Scatter of (X1, X2) colored by the true CATE surface."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=130)
    points = ax.scatter(X[:, 0], X[:, 1], c=tau_true, cmap='RdYlBu_r',
                        alpha=0.4, s=5, vmin=-3, vmax=5)
    plt.colorbar(points, ax=ax, label='True CATE')
    ax.set_xlabel('X1', fontsize=13)
    ax.set_ylabel('X2', fontsize=13)
    ax.set_title(r'True CATE: $\tau(X) = X_1 + 0.5 X_2^2 - 1$', fontsize=14)
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight')
    plt.close()
    print(f" [saved] {path}")
def main():
    """Run every estimator on one simulated dataset, print metrics, save figures.

    Returns (res, df): the per-method CATE estimates and the metrics table.
    """
    sep = "=" * 60
    print("\n" + sep)
    print(" CATE Estimation - Simulation Study")
    print(sep + "\n")
    X, T, Y, tau_true = generate_data(n=5000, seed=42)
    fnames = ['X1', 'X2', 'X3', 'X4', 'X5']
    res, metrics = {}, []
    # The four "plain" estimators share the same call/evaluate pattern.
    estimators = [
        ('S-Learner', s_learner, "\n>>> [1/5] S-Learner..."),
        ('T-Learner', t_learner, ">>> [2/5] T-Learner..."),
        ('X-Learner', x_learner, ">>> [3/5] X-Learner..."),
        ('DML-CausalForest', dml_causal_forest, ">>> [4/5] DML-CausalForest..."),
    ]
    for label, estimator, banner in estimators:
        print(banner)
        res[label] = estimator(X, T, Y)
        metrics.append(evaluate(tau_true, res[label], label))
    # DML-Linear also yields coefficients and standard errors.
    print(">>> [5/5] DML-Linear...")
    tau_l, coef_l, se_l = dml_linear(X, T, Y)
    res['DML-Linear'] = tau_l
    metrics.append(evaluate(tau_true, tau_l, 'DML-Linear'))
    print("\n" + sep)
    print(" Evaluation Results")
    print(sep)
    df = pd.DataFrame(metrics)
    print(df.to_string(index=False))
    print("\n" + sep)
    print(" DML-Linear Coefficients (true: X1=1.0, X2=nonlinear, others=0)")
    print(sep)
    for i, feat in enumerate(fnames):
        lo = coef_l[i] - 1.96 * se_l[i]
        hi = coef_l[i] + 1.96 * se_l[i]
        # Star the coefficient when its 95% CI excludes zero.
        flag = "*" if lo > 0 or hi < 0 else ""
        print(f" {feat:>4}: {coef_l[i]:>7.4f} SE={se_l[i]:.4f} 95%CI=[{lo:.4f}, {hi:.4f}] {flag}")
    print("\n>>> Generating plots...")
    plot_scatter(tau_true, res)
    plot_by_feature(X, tau_true, res, 0, 'X1', 'fig2_cate_by_x1.png')
    plot_by_feature(X, tau_true, res, 1, 'X2', 'fig2b_cate_by_x2.png')
    plot_dist(tau_true, res)
    plot_error_box(tau_true, res)
    plot_coef(coef_l, se_l, fnames)
    plot_heatmap(tau_true, X)
    print("\n" + sep)
    print(" All done!")
    print(sep)
    return res, df
if __name__ == '__main__':
results, df_metrics = main()