Attention 计算中为何需要除以根号 d
问题背景
在 Attention 的计算公式中,为什么要除以 (\sqrt{d_k})?
这个问题是 NLP 面试中的高频考点,几乎在问到 Attention 或者 Transformer 架构时都会被提及。作为面试官,这能很快考察出候选人的数学功底和对底层原理的理解。
如果你是 NLP 领域的学生或从业者,建议先思考一下。本文将从数学推导和实验验证两个角度详细解答。
原始论文的解释
在《Attention is All You Need》的原始论文中给出了初步解释:
While for small values of (d_k), the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of (d_k). We suspect that for large values of (d_k), the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by (1/\sqrt{d_k}).
作者指出,当 (d_k) 变大时,点积的数值会增大,导致 Softmax 函数进入梯度极小的饱和区。为了抵消这种影响,将点积缩放 (1/\sqrt{d_k})。
但这引出了两个更深层的问题:
- 为什么会导致梯度消失?
- 为什么恰好是 (\sqrt{d_k}),而不是其他值?
下面进行详细推导。
变大为什么会导致梯度消失?
结论: 如果 (d_k) 变大,点积的方差会变大。方差变大会导致向量元素间的差值变大,进而导致 Softmax 退化为 Argmax,最终使得反向传播的梯度变为 0。
第一点:(d_k) 变大,方差会变大
假设查询向量 (Q) 和键向量 (K) 的元素均值为 0,方差为 1,且相互独立。则它们的点积 (Q \cdot K = \sum_{i=1}^{d_k} Q_i K_i) 的方差为:
$$ \text{Var}(Q \cdot K) = \sum_{i=1}^{d_k} \text{Var}(Q_i K_i) = d_k \times 1 = d_k $$
因此,当 (d_k) 变大时,点积的方差线性增加。
第二点:方差变大会导致元素差值变大
方差越大,代表数据分布越分散。对于正态分布,最大值的期望随标准差(方差的平方根)增大而增大。这意味着向量中最大值与最小值的差距会显著拉大。
第三点:Softmax 退化为 Argmax
Softmax 函数的定义为:
$$ \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} $$
设 (x_{max}) 为向量中的最大值,令 (x_i = x_{max} - \delta_i),其中 (\delta_i \ge 0)。代入公式得:
$$ \text{softmax}(x)i = \frac{e^{x{max} - \delta_i}}{\sum_j e^{x_{max} - \delta_j}} = \frac{e^{-\delta_i}}{\sum_j e^{-\delta_j}} $$
当方差很大时,非最大元素的 (\delta_i) 会非常大(例如大于 10),此时 (e^{-\delta_i}) 趋近于 0。只有最大值对应的项((\delta=0))保留为 1,其余项均为 0。即 Softmax 退化为 Argmax。
第四点:Softmax 梯度消失
Softmax 的雅可比矩阵(Jacobian Matrix)元素为:
$$ \frac{\partial y_i}{\partial x_j} = y_i (\delta_{ij} - y_j) $$
当 Softmax 退化为 One-hot 向量(Argmax)时,输出向量中只有一个元素为 1,其余为 0。此时,对于任意输入向量,雅可比矩阵将变为全零矩阵,意味着梯度全部消失,模型无法更新参数。
实验验证
我们通过 Python 代码验证上述理论。对比方差为 1 和方差较大时的 Softmax 输出及梯度。
numpy np
np.random.seed()
n =
x1 = np.random.normal(loc=, scale=, size=n)
x2 = np.random.normal(loc=, scale=np.sqrt(), size=n)
(, (x1) - (x1))
(, (x2) - (x2))
():
exp_x = np.exp(x - np.(x))
exp_x / np.(exp_x)
():
jacobian = np.diag(y) - np.outer(y, y)
np.linalg.norm(jacobian)
y1 = softmax(x1)
y2 = softmax(x2)
(, y1)
(, y2)
(, softmax_grad(y1))
(, softmax_grad(y2))


