跳到主要内容CS336 从零构建语言模型:Transformer 架构实现详解 | 极客日志PythonAI算法
CS336 从零构建语言模型:Transformer 架构实现详解
综述由AI生成斯坦福 CS336 课程作业实战记录,完整实现了 Transformer 语言模型架构。涵盖线性层、嵌入层、RMSNorm、SwiGLU 前馈网络、RoPE 位置编码及因果多头自注意力等核心组件的 PyTorch 代码编写。重点解析了数值稳定性处理、张量形状变换及资源消耗核算方法,最终整合为完整的 Transformer LM 架构,并分析了 GPT-2 XL 规模的参数量与 FLOPs 分布。
安卓系统17 浏览 CS336 从零构建语言模型:Transformer 架构实现详解
本次作业旨在从零开始实现 Transformer 语言模型的核心组件。我们将基于 PyTorch,逐步搭建线性层、嵌入层、归一化模块、前馈网络、位置编码及注意力机制,最终组合成一个完整的 Transformer LM。整个过程注重数值稳定性与接口规范。
基础模块实现
1. 线性层 (Linear)
首先实现一个无偏置的 Linear 类。它继承自 torch.nn.Module,需严格遵循 nn.Linear 的接口设计,但移除 bias 参数。权重矩阵 W 的形状应为 (out_features, in_features),这有助于内存排列优化。
初始化时,使用截断正态分布 trunc_normal_,标准差设为 sqrt(2 / (d_in + d_out))。在 forward 传播中,为了明确维度意图,推荐使用 einsum 进行矩阵乘法。
import math
import torch
from torch import nn
class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
super().__init__()
self.in_features = int(in_features)
self.out_features = int(out_features)
self.weight = nn.Parameter(
torch.empty((self.out_features, self.in_features), device=device, dtype=dtype)
)
sigma = math.sqrt(2.0 / (self.in_features + self.out_features))
nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3.0 * sigma, b=3.0 * sigma)
def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.einsum(, x, .weight)
return
"... i, o i -> ... o"
self
测试适配器加载权重后运行 pytest -k test_linear,验证通过即可确认模块功能正常。
2. 嵌入层 (Embedding)
接下来是词嵌入查找。实现 Embedding 类,将 token ID 映射为连续向量。权重矩阵形状为 (num_embeddings, embedding_dim)。
初始化同样采用截断正态分布,均值为 0,标准差为 1。forward 函数直接通过索引操作获取向量。
class Embedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None):
super().__init__()
self.num_embeddings = int(num_embeddings)
self.embedding_dim = int(embedding_dim)
self.weight = nn.Parameter(
torch.empty((self.num_embeddings, self.embedding_dim), device=device, dtype=dtype)
)
nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3.0, b=3.0)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
return self.weight[token_ids]
3. RMSNorm 归一化
RMSNorm 相比 LayerNorm 移除了均值计算,仅保留均方根。注意数值稳定性:输入需先提升为 float32 计算,结果再转回原始数据类型。
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.d_model = int(d_model)
self.eps = float(eps)
self.weight = nn.Parameter(torch.ones((self.d_model,), device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
in_dtype = x.dtype
x_fp32 = x.to(torch.float32)
rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
y = (x_fp32 / rms) * self.weight.to(torch.float32)
return y.to(in_dtype)
核心组件构建
4. SwiGLU 前馈网络
这是 Transformer 中的关键激活结构。中间维度 d_ff 通常设为 8/3 * d_model,并向上取整到 64 的倍数以适配硬件。
SwiGLU 公式为 FFN(x) = W2(SiLU(W1x) ⊙ W3x)。这里直接使用 torch.sigmoid 实现 SiLU 以保证数值稳定。
def default_d_ff(d_model: int, multiple_of: int = 64) -> int:
raw = int(math.ceil((8.0 * d_model) / 3.0))
return int(((raw + multiple_of - 1) // multiple_of) * multiple_of)
class SwiGLU(nn.Module):
def __init__(self, d_model: int, d_ff: int | None = None, *, multiple_of: int = 64, device=None, dtype=None):
super().__init__()
self.d_model = int(d_model)
self.d_ff = int(d_ff) if d_ff is not None else default_d_ff(self.d_model, multiple_of)
self.w1 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)
self.w2 = Linear(self.d_ff, self.d_model, device=device, dtype=dtype)
self.w3 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)
@staticmethod
def silu(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
a = self.w1(x)
b = self.w3(x)
gated = self.silu(a) * b
return self.w2(gated)
5. RoPE 位置编码
旋转位置编码 (RoPE) 不依赖绝对位置,而是通过旋转矩阵注入相对位置信息。需要预计算 cos/sin 表作为 buffer。
关键点在于支持任意 batch 维度,并在 forward 时根据 token_positions 切片缓存表。注意处理偶数/奇数维度的配对旋转。
class RoPE(nn.Module):
def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
super().__init__()
if d_k % 2 != 0:
raise ValueError(f"d_k must be even for RoPE, got d_k={d_k}")
self.theta = float(theta)
self.d_k = int(d_k)
self.max_seq_len = int(max_seq_len)
pair_idx = torch.arange(0, self.d_k, 2, device=device, dtype=torch.float32)
inv_freq = self.theta ** (-pair_idx / self.d_k)
positions = torch.arange(self.max_seq_len, device=device, dtype=torch.float32)
angles = positions[:, None] * inv_freq[None, :]
self.register_buffer("cos", torch.cos(angles), persistent=False)
self.register_buffer("sin", torch.sin(angles), persistent=False)
def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
if x.size(-1) != self.d_k:
raise ValueError(f"Expected x.size(-1)==d_k=={self.d_k}, got {x.size(-1)}")
pos = token_positions.to(device=x.device)
cos = self.cos.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
sin = self.sin.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
x_fp32 = x.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
x_even = x_fp32[..., ::2]
x_odd = x_fp32[..., 1::2]
while cos.dim() < x_even.dim():
cos = cos.unsqueeze(cos.dim() - 2)
sin = sin.unsqueeze(sin.dim() - 2)
out_even = x_even * cos - x_odd * sin
out_odd = x_even * sin + x_odd * cos
out = torch.stack((out_even, out_odd), dim=-1).flatten(-2)
return out.to(dtype=x.dtype)
6. Softmax 与 注意力机制
Softmax 实现需注意数值稳定性,先减去最大值再指数化。
缩放点积注意力 (SDPA) 支持可选掩码 (mask)。True 表示允许,False 表示屏蔽(对应负无穷)。计算 logits 时使用 float32 以避免溢出。
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
x_max = torch.amax(x, dim=dim, keepdim=True)
z = x - x_max
exp_z = torch.exp(z)
sum_exp = torch.sum(exp_z, dim=dim, keepdim=True)
return exp_z / sum_exp
def scaled_dot_product_attention(query, key, value, mask=None):
q = query.to(torch.float32)
k = key.to(torch.float32)
v = value.to(torch.float32)
scale = 1.0 / math.sqrt(q.shape[-1])
logits = torch.einsum("... s d, ... t d -> ... s t", q, k) * scale
if mask is not None:
neg_inf = torch.finfo(torch.float32).min
logits = torch.where(mask.to(device=logits.device), logits, neg_inf)
probs = softmax(logits, dim=-1)
if mask is not None:
probs = probs * mask.to(device=probs.device, dtype=probs.dtype)
out = torch.einsum("... s t, ... t d -> ... s d", probs, v)
return out.to(dtype=value.dtype)
7. 多头自注意力 (Multi-head Self-Attention)
将 Q/K/V 投影后 reshape 为多头形式,应用 SDPA 后再合并。因果掩码 (Causal Mask) 确保当前 token 只能关注之前的 token。
class CausalMultiHeadSelfAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, device=None, dtype=None):
super().__init__()
self.d_model = int(d_model)
self.num_heads = int(num_heads)
if self.d_model % self.num_heads != 0:
raise ValueError("d_model must be divisible by num_heads")
self.head_dim = self.d_model // self.num_heads
self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
self.o_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
@staticmethod
def _causal_mask(seq_len: int, device) -> torch.Tensor:
return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.size(-2)
device = x.device
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
q = q.view(new_shape).transpose(-3, -2)
k = k.view(new_shape).transpose(-3, -2)
v = v.view(new_shape).transpose(-3, -2)
mask = self._causal_mask(seq_len, device=device)
out = scaled_dot_product_attention(q, k, v, mask=mask)
out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
return self.o_proj(out)
完整模型整合
8. Transformer Block
采用 Pre-Norm 结构,即 y = x + Attn(RMSNorm(x)),z = y + FFN(RMSNorm(y))。每个块内部集成 RoPE。
9. Transformer Language Model
将所有层堆叠,加上 Token Embedding 和 LM Head。输入序列长度不能超过 context_length。
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, context_length: int, d_model: int, num_layers: int,
num_heads: int, d_ff: int, rope_theta: float, max_seq_len: int = None,
eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.vocab_size = int(vocab_size)
self.context_length = int(context_length)
self.d_model = int(d_model)
self.num_layers = int(num_layers)
self.max_seq_len = int(max_seq_len if max_seq_len is not None else context_length)
self.token_embeddings = Embedding(self.vocab_size, self.d_model, device=device, dtype=dtype)
self.layers = nn.ModuleList([
TransformerBlock(d_model=self.d_model, num_heads=num_heads, d_ff=d_ff,
max_seq_len=self.max_seq_len, theta=rope_theta, eps=eps,
device=device, dtype=dtype)
for _ in range(self.num_layers)
])
self.ln_final = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
self.lm_head = Linear(self.d_model, self.vocab_size, device=device, dtype=dtype)
def forward(self, in_indices: torch.Tensor) -> torch.Tensor:
batch, seq_len = in_indices.shape
if seq_len > self.context_length:
raise ValueError(f"seq_len={seq_len} exceeds context_length={self.context_length}")
token_positions = torch.arange(seq_len, device=in_indices.device, dtype=torch.long).view(1, seq_len)
token_positions = token_positions.expand(batch, seq_len)
x = self.token_embeddings(in_indices)
for block in self.layers:
x = block(x, token_positions)
x = self.ln_final(x)
logits = self.lm_head(x)
return logits
资源核算分析
理解 Transformer 的计算量和内存消耗至关重要。绝大多数 FLOPs 来自矩阵乘法,规则为 2 * m * n * p。
以 GPT-2 XL 为例(48 层,d_model=1600):
- 参数量:约 2.13B,float32 加载需 8.51GB。
- FLOPs 分布:
- FFN (SwiGLU):占比最高,每层约 62.9B FLOPs。
- 投影层 (Q/K/V/O):次之,每层约 21.0B FLOPs。
- 注意力 Matmul:约 6.7B FLOPs。
- LM Head:相对较小。
随着模型规模增大(层数/宽度增加),与 d^2 或 d*d_ff 成正比的 FFN 与投影占比上升;而与 S^2*d 成正比的注意力 Matmul 占比下降。若上下文长度 context_length 大幅增加,注意力计算的平方级增长将使其成为主导项。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online