Stanford CS336 Assignment in Practice: Implementing a Transformer Language Model Architecture from Scratch | 极客日志
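Summary (AI-generated): A walkthrough of the Stanford CS336 assignment, implementing the core components of a Transformer language model from scratch: the linear layer, Embedding, RMSNorm, the SwiGLU feed-forward network, RoPE positional encoding, and causal multi-head self-attention. It focuses on the code logic and numerical-stability handling of each module, works through parameter counts and FLOPs for a GPT-2 XL-sized model, and analyzes the share of compute taken by each component. The pieces are then assembled into a complete Transformer LM, giving a practical basis for understanding how large models work under the hood.
月光旅人 · Published 2026/4/11 · Updated 2026/5/23
In the previous discussion we went over the requirements of the Transformer Language Model assignment. This post breaks down the implementation details of Assignment 1. These notes record the path from the basic modules to the full architecture, with the aim of helping readers understand the code underlying large models.
1. Linear Layer (Linear)
We need to implement a Linear class that inherits from torch.nn.Module. It performs a linear transformation but has no bias parameter. The interface should otherwise match PyTorch's built-in module.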
import math

import torch
from torch import nn


class Linear(nn.Module):
    """A bias-free linear layer that matches torch.nn.Linear's interface (except it has no bias).

    Stores the weight as W with shape (out_features, in_features).
    """

    def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.weight = nn.Parameter(
            torch.empty((self.out_features, self.in_features), device=device, dtype=dtype)
        )
        # Truncated normal init with variance 2 / (in + out), clipped at +/- 3 sigma.
        sigma = math.sqrt(2.0 / (self.in_features + self.out_features))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3.0 * sigma, b=3.0 * sigma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Contract the last dim of x with the in_features dim of the weight.
        return torch.einsum("... i, o i -> ... o", x, self.weight)
The test adapter loads the weights and checks the output; run uv run pytest -k test_linear to confirm the implementation is correct. (Figure: test run output)
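As a quick sanity check on the einsum pattern used in forward, the sketch below compares it with the equivalent x @ W.T on random tensors. The shapes and variable names are illustrative assumptions, not part of the assignment.

```python
import torch

torch.manual_seed(0)

# Illustrative shapes (assumed): two leading batch dims, in_features=4, out_features=3.
x = torch.randn(2, 5, 4)
W = torch.randn(3, 4)  # stored as (out_features, in_features), like Linear.weight

# "... i, o i -> ... o" contracts the last dim of x with the last dim of W.
y_einsum = torch.einsum("... i, o i -> ... o", x, W)
y_matmul = x @ W.T

print(tuple(y_einsum.shape))
print(torch.allclose(y_einsum, y_matmul, atol=1e-6))
```

Storing W as (out_features, in_features) is what makes the transpose necessary in the matmul form; the einsum form expresses the same contraction without it.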
2. Embedding Layer (Embedding)
The Embedding module maps integer token IDs to continuous vectors. Note that the weight matrix has shape (num_embeddings, embedding_dim) and is also initialized from a truncated normal distribution.
import torch
from torch import nn


class Embedding(nn.Module):
    """A learnable embedding lookup table, equivalent to torch.nn.Embedding.

    This module maps integer token IDs to continuous vectors of fixed dimensionality.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None):
        super().__init__()
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.weight = nn.Parameter(
            torch.empty((self.num_embeddings, self.embedding_dim), device=device, dtype=dtype)
        )
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3.0, b=3.0)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Advanced indexing gathers one row of the table per token ID.
        return self.weight[token_ids]
Run uv run pytest -k test_embedding to verify. (Figure: test run output)
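To see that plain indexing really is an embedding lookup, the following sketch compares weight[token_ids] with torch.nn.functional.embedding. The table size and token IDs are made-up values for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative table (assumed sizes): 10 tokens, 4-dimensional embeddings.
weight = torch.randn(10, 4)
token_ids = torch.tensor([[1, 2, 3], [7, 7, 0]])  # shape (batch=2, seq_len=3)

# Indexing with an integer tensor returns one row of `weight` per ID,
# so the output shape is token_ids.shape + (embedding_dim,).
out = weight[token_ids]

print(tuple(out.shape))
print(torch.equal(out, F.embedding(token_ids, weight)))
```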
3. RMSNorm Normalization
Compared with LayerNorm, RMSNorm drops the mean computation (there is no mean subtraction) and normalizes by the root mean square only. For numerical stability, upcast the input to float32 for the computation and cast the result back to the original dtype at the end.
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    For an input vector a in R^{d_model}:
        RMS(a) = sqrt(mean(a^2) + eps)
        RMSNorm(a) = (a / RMS(a)) * g
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = int(d_model)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones((self.d_model,), device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # Upcast so that squaring cannot overflow in low-precision dtypes.
        x_fp32 = x.to(torch.float32)
        rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        y = (x_fp32 / rms) * self.weight.to(torch.float32)
        return y.to(in_dtype)
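The reason for the float32 upcast can be illustrated with a small example: float16 cannot represent the square of even moderately large activations. The input value 300.0 is an arbitrary assumption chosen to trigger the problem.

```python
import torch

# Illustrative input (assumed): moderately large activations stored in float16.
x = torch.full((4,), 300.0, dtype=torch.float16)

# float16 tops out at 65504, so the squared value 300^2 = 90000 is not representable.
fp16_max = torch.finfo(torch.float16).max
squared_as_fp16 = torch.tensor(300.0 ** 2).to(torch.float16)

# Upcasting first keeps every intermediate finite.
x_fp32 = x.to(torch.float32)
rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + 1e-5)
y = (x_fp32 / rms).to(x.dtype)

print(fp16_max)
print(torch.isinf(squared_as_fp16).item())
print(y)
```

Because every element equals 300, the RMS is also about 300 and the normalized output is close to all ones, which is safely representable once cast back to float16.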
4. SwiGLU Feed-Forward Network
The position-wise feed-forward network uses the SwiGLU structure, which combines the SiLU activation with a gated linear unit (GLU). The inner dimension d_ff is usually set to about 8/3 * d_model, rounded up to a multiple of 64 for better hardware efficiency.
import math

import torch
from torch import nn

# Linear is the bias-free layer defined in section 1.


def round_up_to_multiple(x: int, multiple: int) -> int:
    if multiple <= 0:
        raise ValueError("multiple must be a positive integer")
    return int(((x + multiple - 1) // multiple) * multiple)


def default_d_ff(d_model: int, multiple_of: int = 64) -> int:
    raw = int(math.ceil((8.0 * d_model) / 3.0))
    return round_up_to_multiple(raw, multiple_of)


class SwiGLU(nn.Module):
    """Position-wise feed-forward network using the SwiGLU nonlinearity.

    The transformation is: FFN(x) = W2( SiLU(W1 x) ⊙ (W3 x) )
    """

    def __init__(self, d_model: int, d_ff: int | None = None, *, multiple_of: int = 64,
                 device=None, dtype=None):
        super().__init__()
        self.d_model = int(d_model)
        self.d_ff = int(d_ff) if d_ff is not None else default_d_ff(self.d_model, multiple_of)
        self.w1 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)
        self.w2 = Linear(self.d_ff, self.d_model, device=device, dtype=dtype)
        self.w3 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)

    @staticmethod
    def silu(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.w1(x)  # gate branch
        b = self.w3(x)  # value branch
        gated = self.silu(a) * b
        return self.w2(gated)
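As a worked example of the 8/3 rule, the snippet below restates the two helpers in compact form and evaluates them for a few widths. The d_model values are chosen only for illustration.

```python
import math


def round_up_to_multiple(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple


def default_d_ff(d_model: int, multiple_of: int = 64) -> int:
    raw = math.ceil(8 * d_model / 3)
    return round_up_to_multiple(raw, multiple_of)


# d_model=512:  8/3 * 512  = 1365.33 -> ceil 1366 -> next multiple of 64 = 1408
# d_model=768:  8/3 * 768  = 2048    -> already a multiple of 64
# d_model=1600: 8/3 * 1600 = 4266.67 -> ceil 4267 -> next multiple of 64 = 4288
for d_model in (512, 768, 1600):
    print(d_model, default_d_ff(d_model))
```

Note that the GPT-2 XL configuration used in the accounting section passes d_ff=6400 explicitly, so this default is not used there.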
5. RoPE Rotary Positional Embeddings
RoPE makes the query and key vectors position-dependent by applying rotation matrices to them. The implementation precomputes cos/sin tables and looks up the rows for the current token positions. Any number of batch dimensions is supported.
import torch
from torch import nn


class RoPE(nn.Module):
    """Rotary Positional Embeddings (RoPE).

    Applies a position-dependent rotation to the last dimension.
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super().__init__()
        if d_k % 2 != 0:
            raise ValueError(f"d_k must be even for RoPE, got d_k={d_k}")
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.theta = float(theta)
        self.d_k = int(d_k)
        self.max_seq_len = int(max_seq_len)
        # One frequency per (even, odd) pair: theta^(-2k / d_k) for k = 0 .. d_k/2 - 1.
        pair_idx = torch.arange(0, self.d_k, 2, device=device, dtype=torch.float32)
        inv_freq = self.theta ** (-pair_idx / self.d_k)
        positions = torch.arange(self.max_seq_len, device=device, dtype=torch.float32)
        angles = positions[:, None] * inv_freq[None, :]  # (max_seq_len, d_k / 2)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        # Non-persistent buffers: moved with the module but not saved in state_dict.
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        if x.size(-1) != self.d_k:
            raise ValueError(f"Expected x.size(-1)==d_k=={self.d_k}, got {x.size(-1)}")
        pos = token_positions.to(device=x.device)
        cos = self.cos.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
        sin = self.sin.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
        x_fp32 = x.to(torch.float32)
        cos = cos.to(torch.float32)
        sin = sin.to(torch.float32)
        x_even = x_fp32[..., ::2]
        x_odd = x_fp32[..., 1::2]
        # Insert singleton dims (e.g. the head dim) so cos/sin broadcast against x.
        while cos.dim() < x_even.dim():
            cos = cos.unsqueeze(cos.dim() - 2)
            sin = sin.unsqueeze(sin.dim() - 2)
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Interleave the rotated even/odd halves back into the original layout.
        out = torch.stack((out_even, out_odd), dim=-1).flatten(-2)
        return out.to(dtype=x.dtype)
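A property worth checking is that the dot product of a rotated query and a rotated key depends only on their relative offset. The sketch below is a pure-Python restatement of the same pairwise rotation (not the torch module above); the vectors, positions, and the helper names rope_rotate and dot are assumptions made for the illustration.

```python
import math


def rope_rotate(vec, pos, theta=10000.0):
    """Rotate consecutive (even, odd) pairs of `vec` by the angle pos * theta^(-2k/d)."""
    d = len(vec)
    out = []
    for k in range(d // 2):
        angle = pos * theta ** (-2 * k / d)
        c, s = math.cos(angle), math.sin(angle)
        x_even, x_odd = vec[2 * k], vec[2 * k + 1]
        out += [x_even * c - x_odd * s, x_even * s + x_odd * c]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same relative offset (3) at two different absolute positions.
score_a = dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_b = dot(rope_rotate(q, 105), rope_rotate(k, 102))

print(abs(score_a - score_b) < 1e-9)
```

This is why RoPE is applied to Q and K but not V: only the attention scores need to carry positional information.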
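6. Numerically Stable Softmax
To avoid overflow in the exponential, the standard approach is to subtract the maximum value along the chosen dimension before exponentiating. The result is mathematically unchanged, because the common factor cancels between numerator and denominator.
import torch


def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Numerically stable softmax over a given dimension."""
    x_max = torch.amax(x, dim=dim, keepdim=True)
    z = x - x_max  # the largest entry becomes 0, so exp(z) <= 1
    exp_z = torch.exp(z)
    sum_exp = torch.sum(exp_z, dim=dim, keepdim=True)
    return exp_z / sum_exp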
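The effect of the max subtraction is easiest to see on large logits. The following pure-Python sketch (using math.exp rather than torch, purely for illustration) contrasts a naive softmax with the stable one; the function names and logit values are assumptions.

```python
import math


def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]  # math.exp raises OverflowError for large x
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # largest term is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 1001.0, 1002.0]

try:
    naive_softmax(logits)
    naive_ok = True
except OverflowError:
    naive_ok = False

probs = stable_softmax(logits)
print(naive_ok)
print(probs)
```

In PyTorch the same failure shows up as inf and nan values rather than an exception, but the fix is identical. The stable version returns the same distribution as softmax([0, 1, 2]), since shifting all logits by a constant does not change the result.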
7. Scaled Dot-Product Attention
This is the core of the Transformer. It must support arbitrary batch dimensions and an optional mask. Positions where the mask is True take part in attention; positions where it is False are blocked.
import math

import torch

# softmax is the numerically stable implementation from section 6.


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                 mask: torch.Tensor | None = None) -> torch.Tensor:
    """Scaled dot-product attention."""
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError("query/key/value must have shape (..., seq_len, d_*)")
    if query.shape[:-2] != key.shape[:-2] or query.shape[:-2] != value.shape[:-2]:
        raise ValueError("batch dimensions of query, key, value must match")
    d_k = query.shape[-1]
    if d_k != key.shape[-1]:
        raise ValueError("query and key must have the same d_k")
    q = query.to(torch.float32)
    k = key.to(torch.float32)
    v = value.to(torch.float32)
    scale = 1.0 / math.sqrt(d_k)
    logits = torch.einsum("... s d, ... t d -> ... s t", q, k) * scale
    if mask is not None:
        if mask.dtype != torch.bool:
            raise TypeError("mask must be a boolean tensor")
        # Use the most negative finite float32 instead of -inf for blocked positions.
        neg_inf = torch.finfo(torch.float32).min
        logits = torch.where(mask.to(device=logits.device), logits, neg_inf)
    probs = softmax(logits, dim=-1)
    if mask is not None:
        # Force blocked positions to exactly zero probability.
        probs = probs * mask.to(device=probs.device, dtype=probs.dtype)
    out = torch.einsum("... s t, ... t d -> ... s d", probs, v)
    return out.to(dtype=value.dtype)
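The mask convention (True = may attend) can be checked on a tiny example. This is a minimal sketch that uses masked_fill with -inf and torch.softmax rather than the function above, with made-up sizes: under a causal mask the first query can only see the first key, so its output must equal the first value row.

```python
import math

import torch

torch.manual_seed(0)

seq_len, d_k = 3, 4  # illustrative sizes (assumed)
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)

# Boolean causal mask: True = may attend, False = blocked.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

logits = (q @ k.T) / math.sqrt(d_k)
logits = logits.masked_fill(~mask, float("-inf"))  # blocked entries get zero weight
probs = torch.softmax(logits, dim=-1)
out = probs @ v

print(probs)
print(torch.allclose(out[0], v[0]))
```

Each row of probs still sums to 1, and everything above the diagonal is exactly zero.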
8. Causal Multi-Head Self-Attention
Q, K, and V are projected, split into heads, passed through attention per head, and then merged back. The causal mask ensures each position can only see itself and earlier positions.
import math

import torch
from torch import nn

# Linear, RoPE and scaled_dot_product_attention are defined in the sections above.


class CausalMultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention (no RoPE)."""

    def __init__(self, d_model: int, num_heads: int, device=None, dtype=None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.head_dim = self.d_model // self.num_heads
        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.o_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        # Lower-triangular True: position s may attend to positions t <= s.
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.size(-1) != self.d_model:
            raise ValueError(f"Expected last dim {self.d_model}, got {x.size(-1)}")
        seq_len = x.size(-2)
        device = x.device
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # (..., seq, d_model) -> (..., heads, seq, head_dim)
        new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
        q = q.view(new_shape).transpose(-3, -2)
        k = k.view(new_shape).transpose(-3, -2)
        v = v.view(new_shape).transpose(-3, -2)
        mask = self._causal_mask(seq_len, device=device)
        out = scaled_dot_product_attention(q, k, v, mask=mask)
        # (..., heads, seq, head_dim) -> (..., seq, d_model)
        out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
        return self.o_proj(out)


class CausalMultiHeadSelfAttentionWithRoPE(nn.Module):
    """Causal multi-head self-attention with RoPE applied to Q and K (not V)."""

    def __init__(self, d_model: int, num_heads: int, theta: float, max_seq_len: int,
                 device=None, dtype=None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.head_dim = self.d_model // self.num_heads
        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.rope = RoPE(theta=theta, d_k=self.head_dim, max_seq_len=max_seq_len, device=device)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        if x.size(-1) != self.d_model:
            raise ValueError(f"Expected last dim {self.d_model}, got {x.size(-1)}")
        seq_len = x.size(-2)
        device = x.device
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
        q = q.view(new_shape).transpose(-3, -2)
        k = k.view(new_shape).transpose(-3, -2)
        v = v.view(new_shape).transpose(-3, -2)
        # Rotate queries and keys per head; values are left untouched.
        q = self.rope(q, token_positions)
        k = self.rope(k, token_positions)
        mask = self._causal_mask(seq_len, device=device)
        out = scaled_dot_product_attention(q, k, v, mask=mask)
        out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
        return self.output_proj(out)
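The head split and merge are pure reshapes, which the following sketch verifies as a round trip. The tensor sizes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8  # illustrative sizes (assumed)
d_model = num_heads * head_dim
x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
heads = x.view(batch, seq_len, num_heads, head_dim).transpose(-3, -2)

# Merge: undo the transpose, make memory contiguous, then flatten heads back into d_model.
merged = heads.transpose(-3, -2).contiguous().view(batch, seq_len, d_model)

print(tuple(heads.shape))
print(torch.equal(merged, x))
```

Head h simply owns channels h * head_dim to (h + 1) * head_dim of the projected tensor. The contiguous() call is needed because view cannot flatten the transposed, non-contiguous layout directly.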
9. Transformer Block
The block uses the pre-norm structure: normalize first, then apply attention or the feed-forward network, and finally add the residual connection. This layout helps training stability in deep networks.
import torch
from torch import nn


class TransformerBlock(nn.Module):
    """Pre-norm Transformer block.

    Structure (pre-norm):
        y = x + Attn(RMSNorm(x))
        z = y + FFN(RMSNorm(y))
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, *, max_seq_len: int, theta: float,
                 eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        self.d_ff = int(d_ff)
        self.ln1 = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.attn = CausalMultiHeadSelfAttentionWithRoPE(
            d_model=self.d_model, num_heads=self.num_heads, theta=theta,
            max_seq_len=max_seq_len, device=device, dtype=dtype
        )
        self.ln2 = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.ffn = SwiGLU(self.d_model, d_ff=self.d_ff, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        x = x + self.attn(h, token_positions)  # residual around attention
        h = self.ln2(x)
        x = x + self.ffn(h)  # residual around the feed-forward network
        return x
10. Assembling the Transformer LM
Stack all the components: token embedding -> N Transformer blocks -> final RMSNorm -> LM head. Pay attention to the context-length limit and to how token positions are generated.
import torch
from torch import nn

from cs336_basics.modules import Embedding, Linear, RMSNorm, TransformerBlock


class TransformerLM(nn.Module):
    """A Transformer language model composed of:
    token embedding -> N pre-norm Transformer blocks -> final RMSNorm -> LM head.
    """

    def __init__(self, vocab_size: int, context_length: int, d_model: int, num_layers: int,
                 num_heads: int, d_ff: int, *, rope_theta: float, max_seq_len: int | None = None,
                 eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.vocab_size = int(vocab_size)
        self.context_length = int(context_length)
        self.d_model = int(d_model)
        self.num_layers = int(num_layers)
        self.max_seq_len = int(max_seq_len if max_seq_len is not None else context_length)
        self.token_embeddings = Embedding(self.vocab_size, self.d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model, num_heads=num_heads, d_ff=d_ff,
                max_seq_len=self.max_seq_len, theta=rope_theta, eps=eps, device=device, dtype=dtype
            ) for _ in range(self.num_layers)
        ])
        self.ln_final = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.lm_head = Linear(self.d_model, self.vocab_size, device=device, dtype=dtype)

    def forward(self, in_indices: torch.Tensor) -> torch.Tensor:
        if in_indices.dim() != 2:
            raise ValueError(f"in_indices must have shape (batch, seq_len), got {tuple(in_indices.shape)}")
        batch, seq_len = in_indices.shape
        if seq_len > self.context_length:
            raise ValueError(f"seq_len={seq_len} exceeds context_length={self.context_length}")
        # Positions 0 .. seq_len-1, shared by every sequence in the batch.
        token_positions = torch.arange(seq_len, device=in_indices.device, dtype=torch.long).view(1, seq_len)
        token_positions = token_positions.expand(batch, seq_len)
        x = self.token_embeddings(in_indices)
        for block in self.layers:
            x = block(x, token_positions)
        x = self.ln_final(x)
        logits = self.lm_head(x)  # (batch, seq_len, vocab_size)
        return logits
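11. Resource Accounting (FLOPs Accounting)
Understanding compute and memory cost is essential. The FLOPs of a Transformer come mainly from matrix multiplications. The rule: multiplying A (m×n) by B (n×p) takes 2mnp FLOPs.
(a) GPT-2 XL parameter count
Configuration: vocab=50257, layers=48, d=1600, heads=25, d_ff=6400.
The total is about 2.13B parameters; stored as float32 (4 bytes each), loading the model takes about 8.51 GB.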
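The 2.13B figure can be reproduced with a few lines of arithmetic. This sketch assumes the architecture implemented above: bias-free linear layers, three SwiGLU matrices per block, two RMSNorm gains per block, no learned RoPE parameters, and an LM head that is not tied to the token embedding. The helper name count_params is hypothetical.

```python
def count_params(vocab_size: int, num_layers: int, d_model: int, d_ff: int) -> int:
    embedding = vocab_size * d_model
    attn = 4 * d_model * d_model   # Q, K, V, O projections
    ffn = 3 * d_model * d_ff       # W1, W2, W3 of SwiGLU
    norms = 2 * d_model            # two RMSNorm gain vectors per block
    per_layer = attn + ffn + norms
    final_norm = d_model
    lm_head = d_model * vocab_size  # assumed untied from the embedding
    return embedding + num_layers * per_layer + final_norm + lm_head


total = count_params(vocab_size=50257, num_layers=48, d_model=1600, d_ff=6400)
gigabytes = total * 4 / 1e9  # float32 = 4 bytes per parameter

print(total)                # 2127057600
print(round(gigabytes, 2))  # 8.51
```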
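(b) Forward-pass FLOPs
The main operations in each layer are the Q/K/V/O projections, the attention score computation, the attention-weighted sum of values, and the SwiGLU FFN.
Plugging in the numbers at a sequence length of 1024, a single forward pass of GPT-2 XL totals about 4.51 × 10^12 FLOPs (4.51 TFLOPs).
(c) Bottleneck analysis
Per layer, the FFN (SwiGLU) uses the most FLOPs (about 62.9B), followed by the projections (about 21.0B); the attention matmuls are comparatively small (about 6.7B).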
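These figures follow directly from the 2mnp rule. The sketch below assumes a sequence length S = 1024 (which is what reproduces the numbers quoted here) and counts only matrix multiplications, ignoring norms, softmax, and elementwise operations; matmul_flops is a hypothetical helper.

```python
def matmul_flops(m: int, n: int, p: int) -> int:
    """FLOPs for multiplying an (m x n) matrix by an (n x p) matrix."""
    return 2 * m * n * p


S, d, d_ff, V, L = 1024, 1600, 6400, 50257, 48  # S = 1024 is an assumption

# Per-layer costs.
proj = 4 * matmul_flops(S, d, d)                              # Q, K, V, O projections
attn = matmul_flops(S, d, S) + matmul_flops(S, S, d)          # QK^T and probs @ V, summed over heads
ffn = 2 * matmul_flops(S, d, d_ff) + matmul_flops(S, d_ff, d)  # W1, W3, then W2

lm_head = matmul_flops(S, d, V)
total = L * (proj + attn + ffn) + lm_head

print(proj)   # 20971520000   (~21.0B)
print(attn)   # 6710886400    (~6.7B)
print(ffn)    # 62914560000   (~62.9B)
print(total)  # 4513336524800 (~4.51e12)
```

Splitting d into heads does not change the attention cost: each of the heads multiplies (S × head_dim) by (head_dim × S), and summing over heads gives 2·S²·d again.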
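(d) Effect of model scale
As the model grows (more layers, greater width) at a fixed sequence length S, the FFN and projections, whose cost is proportional to d² or d·d_ff, take a larger share; the attention matmuls, whose cost is proportional to S²·d, take a smaller share.
(e) Effect of sequence length
When context_length increases to 16384, the terms linear in sequence length grow 16x, while the attention matmuls grow quadratically, by 256x. Attention then becomes the dominant term (about 55%).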
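The shift in the attention share can be checked numerically. This is a minimal sketch under the same matmul-only accounting, with the GPT-2 XL configuration as defaults and a baseline sequence length of 1024 assumed; forward_flops is a hypothetical helper.

```python
def forward_flops(S: int, d: int = 1600, d_ff: int = 6400, V: int = 50257, L: int = 48):
    """Return (linear_in_S, quadratic_in_S) matmul FLOPs for one forward pass."""
    # Projections (8*S*d^2) and SwiGLU (6*S*d*d_ff) per layer, plus the LM head.
    linear = L * (8 * S * d * d + 6 * S * d * d_ff) + 2 * S * d * V
    # QK^T and probs @ V per layer: 4 * S^2 * d.
    attention = L * 4 * S * S * d
    return linear, attention


shares = {}
for S in (1024, 16384):
    linear, attention = forward_flops(S)
    shares[S] = attention / (linear + attention)
    print(S, f"{shares[S]:.1%}")
```

At S = 1024 the attention matmuls account for roughly 7% of the total; at S = 16384 they account for roughly 55%.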
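This implementation covers the core components of a Transformer language model, from low-level operators to the high-level architecture. The emphasis is on understanding how the modules work together, along with numerical-stability and compute-efficiency considerations. With this section complete, we have a structurally correct Transformer LM with clean interfaces, which lays the groundwork for studying the training pipeline next.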