PyTorch 复合函数求导:链式法则与自动微分实战
在深度学习里,我们几乎天天都在跟复合函数打交道。从简单的线性回归到复杂的 Transformer,本质上都是多层函数的嵌套。PyTorch 之所以好用,核心就在于它的自动微分引擎能帮我们搞定这些繁琐的链式法则计算。但很多初学者容易忽略背后的数学原理,导致调参时遇到梯度消失或爆炸却不知原因。今天咱们就聊聊 PyTorch 里的复合函数求导,顺便看看怎么利用它优化模型训练。
链式法则的直观理解
假设有一个复合函数 $y = f(g(x))$,根据微积分知识,$rac{dy}{dx} = rac{dy}{dg} \cdot \frac{dg}{dx}$。在神经网络中,每一层就是一个函数,整个网络就是无数层的复合。PyTorch 的 autograd 模块会自动构建这个计算图,并在反向传播时按图索骥地乘起来。
代码实战:嵌套函数求导
别光听理论,直接上代码。下面这个例子模拟了一个典型的非线性变换过程,包含两层中间变量。
import torch
# 定义输入变量,需要追踪梯度
x = torch.tensor([1.0, 2.0], requires_grad=True)
# 第一层变换:z = x^2
z = x ** 2
# 第二层变换:y = sin(z)
y = torch.sin(z)
# 执行反向传播
y.sum().backward()
# 查看结果
print(f"x 的梯度:{x.grad}")
运行这段代码,你会发现 x.grad 的值正是 $\cos(x^2) \cdot 2x$ 的计算结果。PyTorch 内部已经帮你完成了链式乘法。注意这里用了 y.sum(),因为 backward() 默认要求被求导的对象是标量,如果是向量得先聚合一下。
几个容易踩的坑
1. 梯度累加问题
如果你在一个循环里多次调用 backward() 而没有清空梯度,PyTorch 会把新算出的梯度加到旧梯度上。这在 RNN 或者多步更新时很常见。记得在每次迭代前用 optimizer.zero_grad() 或者 x.grad.zero_() 重置。
2. 中间变量不需要梯度
有时候中间计算结果很大,不需要保留梯度来节省显存。可以用 .detach() 切断计算图,或者在创建 tensor 时设置 requires_grad=False。
3. 动态图的特性
PyTorch 是动态图,这意味着计算图是在运行时构建的。如果你想在反向传播中修改某些参数结构,要注意图的重新构建逻辑。这与 TensorFlow 1.x 的静态图完全不同。
总结
掌握复合函数求导不仅仅是为了应付面试,更是为了调试模型。当你发现 Loss 不下降时,检查一下是不是梯度流断了,或者是不是某一层把梯度给截断了。理解 autograd 的工作机制,能让你写出更高效的代码。


