LLM 核心技术:Attention 机制的实现与优化
背景介绍
在大型语言模型(LLM)的架构中,Attention 机制是核心组件,决定了模型处理序列数据的能力。随着上下文窗口(Context Window)的不断扩展,传统的 Attention 实现方式面临着计算复杂度高、显存占用大等挑战。本文深入探讨 Multi-Head Attention (MHA) 的原理及其多种优化方案,包括 MQA、GQA、SWA、FlashAttention 和 PagedAttention,旨在提升模型训练与推理的性能。
Multi-Head Attention (MHA)
原理与计算流程
MHA 的目标在于重构文本中的 Token Embedding 表示,使其能够捕捉上下文语义相关性和位置相关性。其计算过程主要包含以下步骤:
- Embedding Lookup:输入文本长度为 n(n 个 token),经过 Embedding Table 后,每个 token 返回一个大小为 (1, d) 的向量。对于长度为 n 的文本,生成 Embedding Matrix,大小为 (n, d),其中 d 为 Embedding 维度。
- 线性映射:Embedding Matrix X 进入 MHA 层后,通过线性变换生成 Query (Q)、Key (K)、Value (V)。假设 Head 数量为 h,每个 Head 的维度为 k 或 v(通常 k=v=d/h)。
- Q 维度:(h, n, k)
- K 维度:(h, n, k)
- V 维度:(h, n, v)
- Attention 计算:对每个 Head 执行 Softmax(QK^T / sqrt(d_k)) * V。由于 n 个 Token 两两交互,时间复杂度为 O(n^2)。
复杂度分析
MHA 的整体计算复杂度与上下文长度 n 的二次方成正比,与模型规模 d 的二次方成正比。公式如下:
$$ \text{Complexity} = O(n^2 \cdot d) $$
增大 Context 长度会带来计算复杂度的二次方增长,这限制了长文本的处理能力。同时,在自回归推理过程中,为了加速解码,需要缓存之前生成的 Key 和 Value (KV Cache),导致显存占用随序列长度线性增加。
推理优化:MQA 与 GQA
Multi-Query Attention (MQA)
在标准 MHA 中,每个 Head 都有独立的 K 和 V 矩阵。在推理时,GPU 显存占用会随着预测 Token 数目增加而累积。
MQA 通过在不同 Head 间共享 K 和 V 矩阵来优化。即所有 Head 使用同一组 Key 和 Value,仅 Query 独立。这使得存储的 K/V 矩阵数量从 2h 降低为 2 个。虽然显著降低了显存占用并提高了推理速度,但可能因信息压缩导致精度略有下降。
Group Query Attention (GQA)
GQA 是对 MQA 的改进,它在 Head 之间进行分组。一个 Group 内的多个 Head 共享一组 K 和 V,不同 Group 之间则独立。这种方式在保持接近 MHA 效果的同时,大幅减少了 KV Cache 的大小。
GQA 实现逻辑
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
内存开销对比
- MHA 显存开销:
batch * max_seq_len * n_heads * head_dim * sizeof(half) * 2
- GQA 显存开销:
batch * max_seq_len * n_kv_heads * head_dim * sizeof(half) * 2
其中 n_heads / n_kv_heads 即为 Group 大小。使用 GQA 可将 KV Cache 显存降低到 MHA 的 1/group 水平,极大缓解了访存密集型计算的瓶颈。
窗口注意力:Sliding Window Attention (SWA)
为了进一步降低 Attention 与 Context Length 的依赖关系,SWA 限制了每个 Token 只能关注前 W 个 Token。这将时间复杂度从 $O(n^2)$ 降低至 $O(n \cdot w)$。
在 SWA 中,注意力的传递通过层数增加向后延伸。每一层注意力层允许信息传递 W tokens。例如,对于 16k 序列长度和 4k 滑动窗口,通过 4 层即可实现整个序列信息的传递。这种机制特别适合长文本场景,但需注意长距离依赖可能丢失的问题。
底层算子优化:FlashAttention
FlashAttention 旨在解决 Attention 计算过程中频繁访问 HBM(High Bandwidth Memory)的问题。它利用 GPU 中 SRAM(片上高速缓存)速度快但容量小的特点,将 Attention 计算 Block 化,直接在 SRAM 中进行。
核心优化点
- IO 感知计算:减少 HBM 读写次数。传统方法需要将中间结果写入 HBM,FlashAttention 通过分块计算避免此步骤。
- 分块 Softmax:为了保证分块计算的 Softmax 值与原值一致,采用了在线 Softmax 算法,结合重计算(Recomputation)策略。
- 并行计算:将 Q、K、V 切分为 Block,提高 Operation 的处理效率。
FlashAttention 不仅提升了训练速度,也显著改善了推理延迟,是目前主流的大模型框架(如 PyTorch FSDP)默认启用的优化方案。
内存管理优化:PagedAttention
PagedAttention 解决了 Attention 计算过程中的内存分配问题,特别是针对 KV Cache 的动态变化特性。
传统 KV Cache 问题
- 显存占用大:大型模型单个序列的 KV Cache 可能占用高达 GB 级显存。
- 动态变化:序列长度不可预测,难以预分配。
- 内存碎片化:静态批处理策略下,请求结束后剩余空间浪费严重(内部碎片);连续内存分配要求导致外部碎片。
PagedAttention 优势
PagedAttention 引入了虚拟内存的概念,允许 KV Cache 在非连续的物理内存块中存储。系统维护一个页表来映射虚拟地址到物理地址。
- 非连续存储:解决了连续内存分配造成的空间浪费。
- 动态分配:支持小空间内存的动态分配,适应不同长度的序列。
- 高吞吐量:通过减少碎片和优化内存利用率,可以实现更大的 Batch Size 和更高的吞吐量。
总结
Attention 机制的优化是大模型性能提升的关键路径。从算法层面的 MQA/GQA 减少参数量和显存,到系统层面的 FlashAttention 优化 IO,再到内存管理的 PagedAttention 解决碎片问题,这些技术共同推动了 LLM 向更长上下文、更低成本的方向发展。在实际工程落地中,应根据硬件资源和业务需求选择合适的优化组合。
参考文献
- [1] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv preprint.
- [2] PagedAttention: Virtual Memory for Efficient LLM Inference. arXiv preprint.