Multi-Head Attention
Multi-Head Attention (MHA) 是 Transformer 架构的核心组件。Attention 的计算复杂度与文本长度的二次方成正比,相关的计算过程如下。
MHA 的整体复杂度与上下文长度 n 的二次方成正比,与模型的规模 d(embedding size)的二次方成正比。增大 context 的长度,会带来计算复杂度的二次方增大。
Attention 实现机制优化
Multi-Query Attention (MQA)
对于 Multi-Head Attention,每个 head 对应的 k 矩阵和 v 矩阵不同,因此对于每个 token 都有 h(head 数目)个 k 矩阵和 v 矩阵。
在模型推理的过程中,为了防止重新计算,会缓存之前 token 对应的 Keys 和 Values。因此 GPU 显存占用会随着预测的 token 数目而增加。
Multi-Query Attention 通过在不同 head 中共享 K 和 V,即不同的 head 具有相同的 key 和 value,降低了存储的 k 矩阵和 v 矩阵的数目。对于每个 token 存储的 matrix 数目由 2h 个降低为两个 matrix。同时也降低了计算复杂度。
Multi-Query Attention 极大地提高了推理速度。
Group Query Attention (GQA)
Group Query Attention 是对所有 head 的 Query 分组为不同的 group,对一个 group 内的 query,共享 key 和 value。GQA 的效果与 MHA 的效果相当,训练速度与 MQA 相当,提高了训练速度的同时,效果相比 MQA 有提高。
GQA 的实现
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
内存开销计算
使用 MHA 结构的自回归模型,在推理过程中,会维护一个巨大的 k/v cache。它的内存开销公式为:
batch * max_seq_len * n_heads * head_dim * sizeof(half) * 2
而对于 GQA 来说,k/v cache 的内存开销公式变成:
batch * max_seq_len * n_kv_heads * head_dim * sizeof(half) * 2
n_heads / n_kv_heads 就是 group 的大小。可见,使用 GQA 可以把 k/v cache 降低到 MHA 的 1/group 的水平。非常利好 Attention 这种访存密集型的计算。
SWA (Sliding Window Attention)
通过优化 attention 的实现,降低 attention 与 context length 的长度依赖关系。这种对 attention 结构的优化,会同时提升训练和推理的性能。
注意力的时间复杂度是序列长度的二次方,空间复杂度是序列长度的一次方。在推理时,由于缓存的可用性降低,会造成更高的延迟和更小的吞吐量。为了减少这样的问题,提出了窗口注意力机制,在每一个注意力层每个 token 最多能注意前 W 个 token。
注意力的传递通过层数的增加而向后传递。每一层注意力层,信息可以传递 W tokens。经过两层注意力层,信息可以传递 2W tokens。比如对于 16k 序列长度和 4k 的滑动窗口,通过 4 层,信息可以实现整个序列长度的传递。因此序列越长,在滑动窗口长度固定的情况下,为了实现整个序列长度的传递,需要的注意力层数越多。
Attention 底层实现优化
FlashAttention
FlashAttention 解决 attention 计算过程中,频繁访问 HBM 的问题,将 attention 计算 block 化,直接在 SRAM 中进行。
在 GPU 中底层对算子的优化,会同时提升模型的训练和推理性能。考虑到 Attention 在 GPU 的计算过程以及 GPU 的结构,优化 Attention 在 GPU 中的实现。
GPU 中的两个核心部分,SRAM 运算速度快但是存储量小,HBM 运算速度慢但是存储量大。GPU 中的 operation 运算过程,是从 HBM 拷贝数据进行运算,完成运算后再将数据存储到 HBM。
FlashAttention 通过以下两个操作实现了 attention 的加速实现:
- 利用了 GPU 中存储的差异性。将数据从 HBM 拷贝到 SRAM 中,计算时从 SRAM 中读取数据,SRAM 相比 HBM 读取和写入速度更快。
- SRAM 相比 HBM 速度快,但是存储量小,因此采用分块 block 的形式计算 QK 的矩阵乘法。即实现了并行 block 的 softmax 计算。为了保证分块 block 计算的 softmax 值与原有的 softmax 值不变,采用了 block 的 softmax 计算。
FlashAttention 将 Q、K 和 V 切分为 block,进行 block 的计算,提高 operation 的处理速度。
PagedAttention
PagedAttention 解决 attention 计算过程中的内存分配问题,防止内存的浪费,更好的分配内存,可以实现更大的 batch size 和吞吐量。
传统 KV Cache 存在的问题主要包括:
- 显存占用大:对于大型模型如 LLaMA-13B 中的单个序列,KV Cache 可能占用高达 1.7GB 的内存。
- 动态变化:KV Cache 的大小取决于序列长度,而序列长度具有高度可变和不可预测的特点,这对有效管理 KV Cache 构成挑战。
- 内存碎片化和过度预留:由于显存碎片和过度预留,现有系统浪费了 60%-80% 的显存。
- 内部碎片化:在静态批处理策略下,一个请求结束后,其剩余的空间就被浪费掉了。
- 外部碎片化:由于 KV Cache 是一个巨大的矩阵,且必须占用连续内存,操作系统如果只分配大的连续内存,势必有很多小的内存空间被浪费掉。
PagedAttention 的优势
相当于小空间内存的动态分配,可以实现非连续的内存存储,解决了传统 KV Cache 连续动态内存分配造成的内存空间浪费。
总结
Attention 机制的优化主要集中在减少计算复杂度和显存占用上。MHA 提供了强大的表达能力但显存消耗大;MQA 和 GQA 通过共享 Key/Value 显著减少了 KV Cache 的显存需求,提升了推理吞吐量;SWA 通过限制注意力范围降低了长序列处理的延迟;FlashAttention 通过优化 IO 路径利用 SRAM 加速计算;PagedAttention 则解决了显存碎片问题,支持更大的并发 Batch Size。在实际应用中,通常结合多种技术以达到性能与精度的最佳平衡。