LLM 核心技术:Attention 机制的实现与优化
背景介绍
在大型语言模型(LLM)的架构中,Attention 机制是核心组件,决定了模型处理序列数据的能力。随着上下文窗口(Context Window)的不断扩展,传统的 Attention 实现方式面临着计算复杂度高、显存占用大等挑战。本文深入探讨 Multi-Head Attention (MHA) 的原理及其多种优化方案,包括 MQA、GQA、SWA、FlashAttention 和 PagedAttention,旨在提升模型训练与推理的性能。
Multi-Head Attention (MHA)
原理与计算流程
MHA 的目标在于重构文本中的 Token Embedding 表示,使其能够捕捉上下文语义相关性和位置相关性。其计算过程主要包含以下步骤:
- Embedding Lookup:输入文本长度为 n(n 个 token),经过 Embedding Table 后,每个 token 返回一个大小为 (1, d) 的向量。对于长度为 n 的文本,生成 Embedding Matrix,大小为 (n, d),其中 d 为 Embedding 维度。
- 线性映射:Embedding Matrix X 进入 MHA 层后,通过线性变换生成 Query (Q)、Key (K)、Value (V)。假设 Head 数量为 h,每个 Head 的维度为 k 或 v(通常 k=v=d/h)。
- Q 维度:(h, n, k)
- K 维度:(h, n, k)
- V 维度:(h, n, v)
- Attention 计算:对每个 Head 执行 Softmax(QK^T / sqrt(d_k)) * V。由于 n 个 Token 两两交互,时间复杂度为 O(n^2)。
复杂度分析
MHA 的整体计算复杂度与上下文长度 n 的二次方成正比,与模型规模 d 的二次方成正比。公式如下:
$$ \text{Complexity} = O(n^2 \cdot d) $$
增大 Context 长度会带来计算复杂度的二次方增长,这限制了长文本的处理能力。同时,在自回归推理过程中,为了加速解码,需要缓存之前生成的 Key 和 Value (KV Cache),导致显存占用随序列长度线性增加。
推理优化:MQA 与 GQA
Multi-Query Attention (MQA)
在标准 MHA 中,每个 Head 都有独立的 K 和 V 矩阵。在推理时,GPU 显存占用会随着预测 Token 数目增加而累积。
MQA 通过在不同 Head 间共享 K 和 V 矩阵来优化。即所有 Head 使用同一组 Key 和 Value,仅 Query 独立。这使得存储的 K/V 矩阵数量从 2h 降低为 2 个。虽然显著降低了显存占用并提高了推理速度,但可能因信息压缩导致精度略有下降。
Group Query Attention (GQA)
GQA 是对 MQA 的改进,它在 Head 之间进行分组。一个 Group 内的多个 Head 共享一组 K 和 V,不同 Group 之间则独立。这种方式在保持接近 MHA 效果的同时,大幅减少了 KV Cache 的大小。
GQA 实现逻辑
# 初始化投影层
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
.v_proj = nn.Linear(.hidden_size, .num_key_value_heads * .head_dim, bias=)
key_states = repeat_kv(key_states, .num_key_value_groups)
value_states = repeat_kv(value_states, .num_key_value_groups)


