Multi-Head Attention
Multi-Head Attention (MHA) 是 Transformer 架构的核心组件。Attention 的计算复杂度与文本长度的二次方成正比,相关的计算过程如下。
MHA 的整体复杂度与上下文长度 n 的二次方成正比,与模型的规模 d(embedding size)的二次方成正比。增大 context 的长度,会带来计算复杂度的二次方增大。
Attention 实现机制优化
Multi-Query Attention (MQA)
对于 Multi-Head Attention,每个 head 对应的 k 矩阵和 v 矩阵不同,因此对于每个 token 都有 h(head 数目)个 k 矩阵和 v 矩阵。
在模型推理的过程中,为了防止重新计算,会缓存之前 token 对应的 Keys 和 Values。因此 GPU 显存占用会随着预测的 token 数目而增加。
Multi-Query Attention 通过在不同 head 中共享 K 和 V,即不同的 head 具有相同的 key 和 value,降低了存储的 k 矩阵和 v 矩阵的数目。对于每个 token 存储的 matrix 数目由 2h 个降低为两个 matrix。同时也降低了计算复杂度。
Multi-Query Attention 极大地提高了推理速度。
Group Query Attention (GQA)
Group Query Attention 是对所有 head 的 Query 分组为不同的 group,对一个 group 内的 query,共享 key 和 value。GQA 的效果与 MHA 的效果相当,训练速度与 MQA 相当,提高了训练速度的同时,效果相比 MQA 有提高。
GQA 的实现
# init 时 k 和 v 用 self.num_key_value_heads * self.head_dim 初始化,当 self.num_key_value_heads 小于 self.num_heads 时,参数量变少
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# forward 时,通过 repeat_kv 方法,将 hidden states 从 (batch, num_key_value_heads, seqlen, head_dim) 变成 (batch, num_attention_heads, seqlen, head_dim),相当于是复制了 self.num_key_value_groups 份
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(, )
value_states = value_states.view(bsz, q_len, .num_key_value_heads, .head_dim).transpose(, )
key_states = repeat_kv(key_states, .num_key_value_groups)
value_states = repeat_kv(value_states, .num_key_value_groups)


