跳到主要内容现代大模型架构核心:GQA 与 RMSNorm | 极客日志PythonAI算法
现代大模型架构核心:GQA 与 RMSNorm
现代大模型架构核心聚焦于注意力机制优化与归一化改进。Grouped Query Attention (GQA) 通过减少 KV 头数量平衡了显存开销与表达能力,而 RMSNorm 结合 Pre-Norm 结构提升了深层网络训练稳定性。文章详细对比了 MHA、MQA 与 GQA 的差异,并提供了基于 PyTorch 的代码实现,帮助开发者理解 LLaMA、DeepSeek 等主流模型的底层设计逻辑。
GitMaster1 浏览 
在大模型论文学习中,很多人最初会感觉架构大同小异,主要是数据和算力在堆积。但随着对 LLaMA、Qwen、DeepSeek 等主流模型架构的深入总结,会发现 Attention、位置编码、FFN 与归一化模块已经悄然从经典 Transformer 演进到了新的默认配置。
相较于最初的 Transformer,现在的主流大模型在架构上逐渐发生了以下变化:
- 注意力机制:MQA → GQA(Grouped Query Attention)
- 位置编码:绝对位置编码 → RoPE(Rotary Positional Embedding)
- MLP 激活层:ReLU / GELU 前馈网络 → SwiGLU 前馈网络
- 归一化:LayerNorm → RMSNorm + Pre-Norm
掌握这四件套,基本就能理清现代 LLM 架构的核心逻辑。
一、现如今的 Transformer
早期的 Transformer 架构通常作为 baseline 被直接沿用,如 BERT、GPT 等。但研究者发现,通过更换特定模块可以达到更好的效果。因此,现代大模型不再直接使用原始 Transformer 架构,而是采用了经过模块替换的新 baseline。
下表统计了经典模型所采用的关键组件对比:
| 模型家族 | 注意力 | 位置编码 | MLP 激活 | 归一化 |
|---|
| 早期 GPT/BERT | MHA | 绝对 PE / learned pos | GELU | LayerNorm |
| LLaMA 1/2/3 系列 | GQA(大模型) | RoPE | SwiGLU | RMSNorm |
| Qwen2 / Qwen2.5 | GQA | RoPE | SwiGLU | RMSNorm |
| Mistral 7B | GQA + sliding window | RoPE | SwiGLU | RMSNorm |
| DeepSeek-LLM 等 | GQA/自研高效注意力 | RoPE | SwiGLU | RMSNorm |
| Granite / Gemma 等 | GQA/MQA | RoPE | SwiGLU/GeGLU | RMSNorm/LN |
二、注意力机制演进
2.1 Multi-Head Attention (MHA)
我们先回顾一下经典的注意力机制公式:
$$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$
在标准自注意力中,我们通过 $QK^T / \sqrt{d_k}$ 计算不同 token 之间的注意力权重。但作者发现,仅用一个注意力头往往难以同时捕捉多种语义关系(如词法、语义、句法等)。因此,Transformer 提出了多头注意力机制 (Multi-Head Attention, MHA)。
将输入特征通过不同的线性投影矩阵,映射到多个低维子空间中:
$$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
然后将所有头拼接(concatenate)再线性变换:
$$ \text{MultiHead}(Q,K,V) = \text{Concat}( ext{head}_1, \dots, \text{head}_h) W^O $$
MHA 通过多个小头可以从不同角度捕捉语义信息,增强模型的表达能力和稳定性。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.0):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
"""
x: [B, L, d_model]
"""
B, L, _ = x.size()
Q = self.w_q(x)
K = self.w_k(x)
V = self.w_v(x)
def reshape_heads(t):
return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
Q = reshape_heads(Q)
K = reshape_heads(K)
V = reshape_heads(V)
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.w_o(out)
2.2 Multi-Query Attention (MQA)
有了 MHA 之后,大家第一反应是头越多越好。但在大模型、尤其是 Decoder-Only + 长上下文 + 自回归生成的场景下,MHA 暴露出了一个非常现实的问题:KV Cache 太贵了。
在自回归生成过程中,每生成一个新 token,都需要用到历史所有位置的 K, V。对于标准 MHA,每个注意力头都维护一份自己的 $K_h, V_h$。如果有 h 个头,那么 KV Cache 的内存开销大致是 $\mathcal{O}(h \cdot L \cdot d_{\text{head}})$。当我们把头数堆到 32、64 甚至更多,再把上下文长度拉到 32K、64K 时,这个开销就会变成显存吞噬怪。
为了在几乎不损失模型效果的前提下压缩 KV Cache 和带宽成本,就提出了 Multi-Query Attention(MQA)。MQA 提出所有的头共享同一份 K, V,也就是说,只保留一组 $W^K, W^V$,而 $W_i^Q$ 仍然为每个头独立:
$$ Q_i = X W_i^Q, \quad K = X W^K, \quad V = X W^V $$
$$ \text{head}_i = \text{Attention}(Q_i, K, V) = \text{softmax}\left(\frac{Q_i K^\top}{\sqrt{d_k}}\right) V $$
经验发现'多 KV'并没有带来线性收益,Q 仍然是多头的,多头仍能捕捉多种语义关系。
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.0):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, self.head_dim)
self.w_v = nn.Linear(d_model, self.head_dim)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
""" x: [B, L, d_model] """
B, L, _ = x.size()
Q = self.w_q(x)
Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = self.w_k(x)
V = self.w_v(x)
K = K.unsqueeze(1)
V = V.unsqueeze(1)
K = K.expand(B, self.num_heads, L, self.head_dim)
V = V.expand(B, self.num_heads, L, self.head_dim)
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.w_o(out)
2.3 Grouped Query Attention (GQA)
- MHA:每个头都有独立的 $K_h, V_h$,表达能力强,但 KV Cache 成本最高;
- MQA:所有头共享同一份 K, V,KV Cache 成本最低,但多头之间视角差异弱,表达能力稍打折。
于是就自然出现了一个折中思路:能不能在'省 KV'和'头之间有点差异'之间找个平衡?这就是 Grouped-Query Attention(GQA)。GQA 的核心思想:Q 仍然是很多头,但 K/V 的头数减少为更少的组(num_kv_heads),每组 KV 服务若干个 Q 头。
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_q_heads, num_kv_heads, dropout=0.0):
super().__init__()
assert d_model % num_q_heads == 0
assert num_q_heads % num_kv_heads == 0
self.d_model = d_model
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = d_model // num_q_heads
self.group_size = num_q_heads // num_kv_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
""" x: [B, L, d_model] """
B, L, _ = x.size()
Q = self.w_q(x)
K = self.w_k(x)
V = self.w_v(x)
Q = Q.view(B, L, self.num_q_heads, self.head_dim).transpose(1, 2)
K = K.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
V = V.view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
K = K.repeat_interleave(self.group_size, dim=1)
V = V.repeat_interleave(self.group_size, dim=1)
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float("-inf"))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.w_o(out)
三、归一化:LayerNorm → RMSNorm + Pre-Norm
在 Transformer 里,归一化(Normalization)主要解决两个问题:
- 深层网络训练不稳定:梯度可能爆炸或消失;
- 不同层输出分布漂移,导致学习变慢。
最早的 Transformer 使用的是 LayerNorm + Post-Norm 残差结构(指在全连接层后跟上一个归一化层)。
但到了 LLaMA、DeepSeek 等大模型时,大家开始逐渐转向:RMSNorm + Pre-Norm(指在全连接层前跟上一个归一化层)。
🔹 Post-Norm(原始 Transformer 用法)
最早的 Transformer 论文(Attention Is All You Need)使用的是 Post-Norm,代码结构类似:
out = x + sublayer(x)
out = layer_norm(out)
🔹 Pre-Norm(现代 LLM 常用)
大多数现代 LLM(如 LLaMA、DeepSeek 系列)改成了 Pre-Norm,代码结构类似:
h = layer_norm(x)
out = x + sublayer(h)
实践上,Pre-Norm 再配合 RMSNorm,只调节尺度不改均值,在 Decoder-only 结构里训练更稳定、实现也更简单。
3.1 LayerNorm
Layer Normalization(LN)是在 Transformer 中使用最广的归一化方式之一。给定一个 token 的隐藏表示 $x \in \mathbb{R}^{d}$,LayerNorm 对其 特征维度 进行归一化:
$$ \mu = \frac{1}{d} \sum_{i=1}^{d} x_i, \quad \sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 $$
$$ \text{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta $$
- $\gamma, \beta \in \mathbb{R}^{d}$ 是可学习的缩放和平移参数;
- 归一化是在单个样本、单个 token 的通道维度上完成的。
对每个 token 的特征做一遍'标准化 + 线性变换',让每一层看到的分布更平滑,避免某些维度过大/过小导致训练不稳。
在 PyTorch 中,你平时看到的 nn.LayerNorm 就是这个东西:
import torch
import torch.nn as nn
x = torch.randn(2, 4, 8)
ln = nn.LayerNorm(8)
y = ln(x)
为什么不用 BatchNorm,而用 LayerNorm / RMSNorm?
这一问是面试官很喜欢的一个考点,尤其是 Transformer / LLM 岗位。核心区别在于:归一化时用哪些维度来统计均值与方差。
- BatchNorm(BN):
- 在 CV 里常用,对 batch 维度 + 空间维度 做统计;
- 对每个通道 c,使用整批数据的统计量。
- LayerNorm(LN):
- 对单个样本、单个 token 的所有特征求均值和方差,不依赖 batch 大小。
在 Transformer / LLM 场景中,BN 存在几个问题:
- 序列长度不固定:BN 在变长序列上不自然,统计维度不好选;
- 推理阶段 batch 很小甚至为 1:BN 的 running mean/var 与训练时差异大,容易分布漂移;
- 自注意力中不同 token 之间差异大:BN 混合不同 token 的统计量,会引入额外噪声。
因此,大模型里更偏向用 LayerNorm / RMSNorm 这种'不依赖 batch、只看自己'的归一化方式。
3.2 RMSNorm
RMSNorm 是基于'层归一化中主要起作用的是缩放因子,而非平移因子'这个发现而提出的归一化方法。在层归一化中需要减去均值,而模型在训练过程中已经学会通过投影矩阵自动调节均值;而 $\gamma$ 的作用是调整每一维的相对 scale,是表达力的核心。
给定 $x \in \mathbb{R}^d$,RMSNorm 的公式为:
$$ \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon} $$
$$ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma $$
RMSNorm 更像是'把这个向量整体缩放到一个合适能量水平',不去把它'拉回 0 均值',只控制它的尺度。
实践上,在 Decoder-only 大模型里:RMSNorm + Pre-Norm 组合在超深层网络(几十层)上表现更稳定,这也是 LLaMA / DeepSeek / Qwen 等系列广泛采用它的原因之一。
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-8):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
""" x: [B, L, d_model] """
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
x_norm = x / rms
return self.weight * x_norm
四、总结
本章我们先把现代大模型里的两块'基础设施'打牢:一块是从 MHA → MQA → GQA 的注意力演化,用更少的 KV 头(甚至共享 KV)在不明显掉点的前提下,大幅降低 KV Cache 与长上下文显存开销;另一块是从 LayerNorm → RMSNorm + Pre-Norm 的归一化升级,用'只归一化能量'的 RMSNorm 配合 Pre-Norm 结构,让超深的 Decoder-only 模型在训练和推理中都更加稳定。后面的章节,我们再把 RoPE / SwiGLU / MoE / MLA 这些'进阶武器'一个个拆开,拼成一整套现代 LLM 的'架构面经图谱'。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online