Online Softmax 算法原理与 Flash Attention 应用解析
什么是 Softmax?
在机器学习中,Softmax 是一个非常常见的操作,尤其在分类任务中,比如神经网络的输出层。它的作用是将一组数字(通常是模型的原始分数,称为 logits)转化为概率分布。简单来说,Softmax 能让每个输出的值都在 [0, 1] 之间,且所有值的和为 1。
Softmax 的数学公式是:
$$ \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}} $$
这里的 $x_i$ 是输入的第 i 个值,$N$ 是输入的总数。看起来很简单,但实际使用时有几个问题:
- 数值不稳定性:如果 $x_i$ 很大,$e^{x_i}$ 会变得非常大,可能导致溢出。
- 计算效率:需要两次遍历数据——一次找最大值(为了数值稳定),一次计算概率。
- 内存需求:对于大数据集,需要存储所有中间结果。
为了解决这些问题,Online Softmax 算法应运而生!它通过单次遍历和增量更新来高效计算 Softmax,尤其适合流式数据处理场景,比如在 Transformer 的注意力机制(如 FlashAttention)中。
为什么需要 Online Softmax?
想象你正在处理一个超大的数据集,比如实时处理视频帧的特征,或者在 Transformer 中计算注意力分数。如果用传统 Softmax,每次都要把所有数据读一遍来找最大值,再读一遍来算概率,这效率太低了!而且,如果数据是流式到达的(比如在线推理),你根本没法等所有数据都到齐再处理。
Online Softmax 的目标是:
- 单次遍历:只看一遍数据,边看边算。
- 流式处理:数据一块一块来,随时更新结果。
- 内存高效:不用把所有数据都存下来。
接下来,我们一步步推导这个算法,尽量用大白话解释清楚!
Online Softmax 的核心思想
Online Softmax 的核心在于增量更新。我们不一次性处理所有数据,而是每次来一个新数据点,就更新两个关键统计量:
- 当前最大值:记录目前见过的最大输入值。
- 分母的累加和:Softmax 分母是所有 $e^{x_i}$ 的和,我们动态维护它。
当新数据到来时,我们只需要用已有的统计量和新数据点,更新这两个值,就能保证结果正确。这就像在流水线上加工零件,每来一个零件就更新一下生产线上的状态,不用等所有零件都到齐。
算法推导(用大白话解释)
假设我们已经处理了前 j 个数据,得到了:
- 当前最大值 $m_j$(就是前 j 个数里最大的那个)。
- 当前分母 $d_j = \sum_{i=1}^{j} e^{x_i - m_j}$(注意,为了数值稳定,分母里每个项都减去了最大值)。
现在,第 j + 1 个数据 $x_{j+1}$ 来了,我们需要更新 $m_j$ 和 $d_j$。
步骤 1:更新最大值
新来的 $x_{j+1}$ 可能比之前的最大值 $m_j$ 大,也可能小。所以,新的最大值是:
$$ m_{j+1} = \max(m_j, x_{j+1}) $$
这很好理解:新最大值要么是老的最大值,要么是新来的值,挑大的那个就行。
步骤 2:更新分母
分母是 Softmax 的核心部分,我们需要计算新的分母 $d_{j+1}$,它应该是:
$$ d_{j+1} = \sum_{i=1}^{j+1} e^{x_i - m_{j+1}} $$
直接算这个和需要重新遍历所有数据,太麻烦了!我们能不能利用之前的 $d_j$ 来计算 $d_{j+1}$ 呢?答案是可以的,关键在于处理最大值的变化。
旧的分母是:
$$ d_j = \sum_{i=1}^{j} e^{x_i - m_j} $$

