跳到主要内容现代大模型架构:组注意力机制(GQA)和 RMSNorm | 极客日志PythonAI算法
现代大模型架构:组注意力机制(GQA)和 RMSNorm
综述由AI生成现代大模型架构在注意力机制与归一化层上经历了显著演进。注意力方面,从 MHA 的多头独立 KV 缓存转向 MQA 的单头共享,最终折中为 GQA,通过减少 KV 头数量降低显存开销并维持推理速度。归一化方面,LayerNorm 逐渐被 RMSNorm 取代,配合 Pre-Norm 结构以增强深层网络训练稳定性。详细对比了 MHA、MQA 与 GQA 的原理及实现差异,解析了 LayerNorm 与 RMSNorm 的数学公式与适用场景,并提供 PyTorch 代码示例,帮助理解 LLaMA、Qwen 等主流模型的底层配置。
奇形怪状4 浏览 
前言
在大模型论文学习中,随着 LLaMA、Qwen、DeepSeek 等模型的演进,主流架构在 Attention、位置编码、FFN 与归一化上已逐渐形成新的默认配置。相较于最初的 Transformer,现代大模型主要变化包括:
- MQA → GQA(Grouped Query Attention)
- 绝对位置编码 → RoPE(Rotary Positional Embedding)
- ReLU / GELU → SwiGLU
- LayerNorm → RMSNorm + Pre-Norm
本文聚焦于目前的大模型默认配置,重点解析注意力机制的演化与归一化层的升级。
一、现如今的 Transformer
研究者发现模块的更替可以达到更好的效果,因此现代 baseline 架构已不再直接沿用经典 Transformer 的所有组件。以下是经典模型与现代大模型模块对比:
| 模型家族 | 注意力 | 位置编码 | MLP 激活 | 归一化 |
|---|
| 早期 GPT/BERT | MHA | 绝对 PE / learned pos | GELU | LayerNorm |
| LLaMA 1/2/3 系列 | GQA | RoPE | SwiGLU | RMSNorm |
| Qwen2 / Qwen2.5 | GQA | RoPE | SwiGLU | RMSNorm |
| Mistral 7B | GQA + sliding window | RoPE | SwiGLU | RMSNorm |
| DeepSeek-LLM | GQA/自研高效注意力 | RoPE | SwiGLU | RMSNorm |
| Granite / Gemma | GQA/MQA | RoPE | SwiGLU/GeGLU | RMSNorm/LN |
如表格所示,现代大模型在注意力机制、位置编码、MLP 激活层以及归一化方式上均有显著改变。掌握这四件套有助于理清现代 LLM 架构。
二、Attention Serious
2.1 Multi-Head Attention (MHA)
标准自注意力公式为:
$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Transformer 提出多头注意力机制 (Multi-Head Attention, MHA),将输入特征通过不同的线性投影矩阵映射到多个低维子空间:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
最后拼接再线性变换:
$$\text{MultiHead}(Q,K,V) = \text{Concat}( ext{head}_1, \dots, ext{head}_h) W^O$$
MHA 通过多个小头从不同角度捕捉语义信息,增强模型的表达能力和稳定性。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.0):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
B, L, _ = x.size()
Q = self.w_q(x)
K = self.w_k(x)
V = self.w_v(x)
def reshape_heads(t):
return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
Q = reshape_heads(Q)
K = reshape_heads(K)
V = reshape_heads(V)
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.w_o(out)
2.2 Multi-Query Attention (MQA)
在 Decoder-Only + 长上下文 + 自回归生成场景下,MHA 暴露出 KV Cache 内存开销过大的问题。对于标准 MHA,每个注意力头维护一份自己的 $K_h, V_h$,显存开销约为 $O(h \cdot L \cdot d_{\text{head}})$。
MQA 提出所有头共享同一份 $K, V$,即只保留一组 $W^K, W^V$,而 $W_i^Q$ 仍然为每个头独立:
$$Q_i = X W_i^Q, \quad K = X W^K, \quad V = X W^V$$
这使得 KV Cache 成本大幅降低,同时 Q 的多头结构仍能捕捉多种语义关系。
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.0):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, self.head_dim)
self.w_v = nn.Linear(d_model, self.head_dim)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
B, L, _ = x.size()
Q = self.w_q(x)
Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = self.w_k(x)
V = self.w_v(x)
K = K.unsqueeze(1).expand(B, self.num_heads, L, self.head_dim)
V = V.unsqueeze(1).expand(B, self.num_heads, L, self.head_dim)
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.w_o(out)
2.3 Grouped Query Attention (GQA)
GQA 是 MHA 与 MQA 的折中方案。Q 仍然是很多头,但 K/V 的头数减少为更少的组(num_kv_heads),每组 KV 服务若干个 Q 头。
核心思想:在'省 KV'和'头之间有点差异'之间找平衡。
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_q_heads, num_kv_heads, dropout=0.0):
super().__init__()
assert d_model % num_q_heads == 0
assert num_q_heads % num_kv_heads == 0
self.d_model = d_model
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = d_model // num_q_heads
self.group_size = num_q_heads // num_kv_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
B, L, _ = x.size()
Q = self.w_q(x)
K = self.w_k(x)
V = self.w_v(x)
Q = Q.view(B, L, self.num_q_heads, self.head_dim).transpose(1, 2)
K = K.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
V = V.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
K = K.repeat_interleave(self.group_size, dim=1)
V = V.repeat_interleave(self.group_size, dim=1)
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float("-inf"))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.w_o(out)
三、归一化:LayerNorm → RMSNorm + Pre-Norm
归一化主要解决深层网络训练不稳定及分布漂移问题。现代 LLM 倾向于使用 RMSNorm + Pre-Norm。
- Post-Norm:原始 Transformer 用法,残差后接归一化。
- Pre-Norm:现代 LLM 常用,残差前接归一化,配合 RMSNorm 训练更稳定。
3.1 LayerNorm
给定隐藏表示 $x \in \mathbb{R}^d$,LayerNorm 对特征维度进行归一化:
$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i, \quad \sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$
$$\text{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$
- BN 依赖 batch 统计量,推理时 batch 较小易导致分布漂移。
- LN 不依赖 batch,适合变长序列和 Transformer 架构。
3.2 RMSNorm
RMSNorm 基于'层归一化中主要起作用的是缩放因子'的发现,去除了均值减法,只控制尺度:
$$\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}$$
$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma$$
在 Decoder-only 大模型里,RMSNorm + Pre-Norm 组合在超深层网络上表现更稳定。
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-8):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
x_norm = x / rms
return self.weight * x_norm
四、总结
现代大模型架构在注意力机制与归一化层上经历了显著演进。注意力方面,从 MHA 转向 GQA,通过减少 KV 头数量降低显存开销并维持推理速度。归一化方面,LayerNorm 逐渐被 RMSNorm 取代,配合 Pre-Norm 结构以增强深层网络训练稳定性。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online