CS336 Building Language Models from Scratch: Implementing the Transformer LM Architecture | 极客日志
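Based on the Stanford CS336 course assignment, this article implements the core components of a Transformer language model from scratch: the linear layer, embedding layer, RMSNorm, SwiGLU feed-forward network, RoPE positional encoding, scaled dot-product attention, and multi-head self-attention. It gives a complete PyTorch implementation and the test command for each component, then accounts for the parameter count and FLOPs of a GPT-2 XL-scale model to show how a Transformer's compute and memory costs are distributed.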
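ByteDance, published 2026/4/6, updated 2026/5/20
1. Problem (linear): Implementing the linear module (1 point)
Deliverable: Implement a Linear class that inherits from torch.nn.Module and performs a linear transformation. Your implementation should follow the interface of PyTorch's built-in nn.Linear module, except that it has no bias argument or bias term.
def __init__(self, in_features, out_features, device=None, dtype=None):
    ...
Constructs a linear transformation module. The function should accept the following parameters:
in_features: int: final dimension of the input
out_features: int: final dimension of the output
device: torch.device | None = None: device on which to store the parameters
dtype: torch.dtype | None = None: data type of the parameters
def forward(self, x: torch.Tensor) -> torch.Tensor:
    ...
Applies the linear transformation to the input tensor.
Keep the following points in mind when implementing it:
Subclass nn.Module
Call the superclass constructor (super().__init__())
Construct and store the parameter matrix as W (not W^T) for memory-ordering reasons, and wrap it in an nn.Parameter
Do not use nn.Linear or nn.functional.linear
For parameter initialization, use the initialization settings given earlier together with torch.nn.init.trunc_normal_ to initialize the weights.
The implementation is as follows: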
import math
import torch
from torch import nn
class Linear(nn.Module):
    """A bias-free linear layer that matches torch.nn.Linear's interface (except that it has no bias).
    Stores the weight as W with shape (out_features, in_features).
    """
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.weight = nn.Parameter(torch.empty((self.out_features, self.in_features), device=device, dtype=dtype))
        # Truncated normal init: std = sqrt(2 / (in_features + out_features)), clipped to [-3*sigma, 3*sigma]
        sigma = math.sqrt(2.0 / (self.in_features + self.out_features))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3.0 * sigma, b=3.0 * sigma)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T over the last dimension; any leading batch dimensions pass through
        return torch.einsum("... i, o i -> ... o", x, self.weight)
uv run pytest -k test_linear
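As a quick check on the two ideas in this problem, the initialization scale and the y = x W^T convention can be worked by hand. The following is a minimal sketch in plain Python (no PyTorch); the helper names init_bounds and linear_forward and the sample values are hypothetical, chosen only for illustration.

```python
import math

def init_bounds(in_features, out_features):
    """Return (sigma, lower, upper) for the truncated-normal init used by Linear."""
    sigma = math.sqrt(2.0 / (in_features + out_features))
    return sigma, -3.0 * sigma, 3.0 * sigma

def linear_forward(x, W):
    """Apply y = W x for one input vector x, with W stored as (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# A square 1600 x 1600 layer: sigma = sqrt(2 / 3200) = 0.025
sigma, lo, hi = init_bounds(1600, 1600)

# W has 2 output rows and 3 input columns
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
y = linear_forward([1.0, 2.0, 3.0], W)
print(sigma, lo, hi, y)
```

Each row of W produces one output feature, which is why the (out_features, in_features) layout pairs naturally with the "o i" subscript in the einsum.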
2. Problem (embedding): Implement the embedding module (1 point)
Deliverable: Implement an Embedding class that inherits from torch.nn.Module and performs an embedding lookup.
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
    ...
Subclass nn.Module
Call the superclass constructor (super().__init__())
Initialize the embedding matrix and store it as an nn.Parameter
The last dimension of the embedding matrix must be d_model
Do not use nn.Embedding or nn.functional.embedding
import torch
from torch import nn
class Embedding(nn.Module):
    """A learnable embedding lookup table, equivalent to torch.nn.Embedding.
    This module maps integer token IDs to continuous vectors of fixed dimensionality (embedding_dim).
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.weight = nn.Parameter(torch.empty((self.num_embeddings, self.embedding_dim), device=device, dtype=dtype))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3.0, b=3.0)
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Integer indexing selects one row per token ID: (...,) -> (..., embedding_dim)
        return self.weight[token_ids]
uv run pytest -k test_embedding
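The lookup itself is only row selection, and the output gains one trailing dimension of size embedding_dim. A minimal plain-Python sketch of that behavior, assuming nested lists stand in for tensors (embedding_lookup is a hypothetical helper, not part of the assignment):

```python
def embedding_lookup(weight, token_ids):
    """Replace every integer ID in a nested list with the matching row of weight."""
    if isinstance(token_ids, int):
        return weight[token_ids]
    return [embedding_lookup(weight, t) for t in token_ids]

# 3 embeddings of dimension 2
weight = [[0.0, 0.1],
          [1.0, 1.1],
          [2.0, 2.1]]
token_ids = [[0, 2],
             [1, 1]]          # shape (batch=2, seq_len=2)
out = embedding_lookup(weight, token_ids)   # shape (2, 2, 2)
print(out)
```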
3. Problem (rmsnorm): Root Mean Square Layer Normalization (1 point)
Deliverable: Implement RMSNorm as a torch.nn.Module.
def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
    ...
Note: Remember to upcast the input to torch.float32 before normalizing, and to cast the result back to the original dtype afterwards.
import torch
from torch import nn
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm). For an input vector a in R^{d_model}:
    RMS(a) = sqrt(mean(a^2) + eps)
    RMSNorm(a) = (a / RMS(a)) * g
    """
    def __init__(self, d_model: int, eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones((self.d_model,), device=device, dtype=dtype))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # Upcast so that squaring does not overflow in low-precision dtypes
        x_fp32 = x.to(torch.float32)
        rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        y = (x_fp32 / rms) * self.weight.to(torch.float32)
        return y.to(in_dtype)
uv run pytest -k test_rmsnorm
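A small worked example makes the formula concrete. This sketch applies the same RMS(a) = sqrt(mean(a^2) + eps) definition to a two-element vector in plain Python; the function name rmsnorm and the input values are illustrative assumptions.

```python
import math

def rmsnorm(a, gain, eps=1e-5):
    """Normalize a by its root mean square, then scale elementwise by gain."""
    mean_sq = sum(v * v for v in a) / len(a)
    rms = math.sqrt(mean_sq + eps)
    return [v / rms * g for v, g in zip(a, gain)]

# mean(a^2) = (9 + 16) / 2 = 12.5, so RMS(a) is about 3.5355
out = rmsnorm([3.0, 4.0], [1.0, 1.0])
# Scaling the input by 10 leaves the output almost unchanged (up to eps)
out_scaled = rmsnorm([30.0, 40.0], [1.0, 1.0])
print(out, out_scaled)
```

After normalization the mean of the squared outputs is close to 1, so the learned gain alone sets the output scale.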
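4. Problem (positionwise_feedforward): Implement the position-wise feed-forward network (2 points)
Deliverable: Implement the SwiGLU feed-forward network, which combines a SiLU activation with a GLU (gated linear unit).
Note: In this particular case, you may use torch.sigmoid directly in your implementation for numerical stability.
Set the inner dimension d_ff to approximately 8/3 × d_model, while making sure that the inner feed-forward dimension is a multiple of 64.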
import math
import torch
from torch import nn
def round_up_to_multiple(x: int, multiple: int) -> int:
    if multiple <= 0:
        raise ValueError("multiple must be a positive integer")
    return int(((x + multiple - 1) // multiple) * multiple)
def default_d_ff(d_model: int, multiple_of: int = 64) -> int:
    # d_ff is about 8/3 * d_model, rounded up to the next multiple of `multiple_of`
    raw = int(math.ceil((8.0 * d_model) / 3.0))
    return round_up_to_multiple(raw, multiple_of)
class SwiGLU(nn.Module):
    """Position-wise feed-forward network using the SwiGLU nonlinearity.
    The transformation is: FFN(x) = W2( SiLU(W1 x) ⊙ (W3 x) )
    """
    def __init__(self, d_model: int, d_ff: int | None = None, *, multiple_of: int = 64, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.d_ff = int(d_ff) if d_ff is not None else default_d_ff(self.d_model, multiple_of)
        # Linear is the bias-free layer from Problem 1
        self.w1 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)
        self.w2 = Linear(self.d_ff, self.d_model, device=device, dtype=dtype)
        self.w3 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)
    @staticmethod
    def silu(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.w1(x)
        b = self.w3(x)
        gated = self.silu(a) * b
        return self.w2(gated)
uv run pytest -k test_swiglu
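To see what the gate does numerically, here is a minimal scalar sketch in plain Python. The names silu and swiglu_gate and the sample values are assumptions for illustration; the vectors a and b stand in for the W1 x and W3 x activations.

```python
import math

def silu(x):
    """SiLU(x) = x * sigmoid(x), written as x / (1 + exp(-x)); fine for moderate x."""
    return x / (1.0 + math.exp(-x))

def swiglu_gate(a, b):
    """Elementwise SiLU(a) * b, the gated product that W2 is then applied to."""
    return [silu(ai) * bi for ai, bi in zip(a, b)]

a = [0.0, 1.0, -1.0]   # stands in for W1 x
b = [5.0, 2.0, 2.0]    # stands in for W3 x
gated = swiglu_gate(a, b)
print(gated)
```

A zero pre-activation closes the gate entirely, while negative pre-activations let a small negative signal through. On the sizing rule: for d_model = 512, 8/3 × 512 ≈ 1365.3, which rounds up to 1408 = 22 × 64.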
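5. Problem (rope): Implement RoPE (2 points)
Deliverable: Implement a RotaryPositionalEmbedding class that applies RoPE (rotary positional embeddings) to the input tensor.
def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
    ...
Your implementation should support an arbitrary number of batch dimensions
You may assume that token_positions is a tensor of shape (..., seq_len)
Use token_positions to slice your (possibly precomputed) cos and sin tensors along the sequence dimension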
import torch
from torch import nn
class RoPE(nn.Module):
    """Rotary Positional Embeddings (RoPE). Applies a position-dependent rotation to the last dimension (d_k) of an input tensor.
    """
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        super().__init__()
        if d_k % 2 != 0:
            raise ValueError(f"d_k must be even for RoPE, got d_k={d_k}")
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.theta = float(theta)
        self.d_k = int(d_k)
        self.max_seq_len = int(max_seq_len)
        # One frequency per (even, odd) pair: theta ** (-2k / d_k) for k = 0 .. d_k/2 - 1
        pair_idx = torch.arange(0, self.d_k, 2, device=device, dtype=torch.float32)
        inv_freq = self.theta ** (-pair_idx / self.d_k)
        positions = torch.arange(self.max_seq_len, device=device, dtype=torch.float32)
        angles = positions[:, None] * inv_freq[None, :]  # (max_seq_len, d_k / 2)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        # Non-persistent buffers: they follow the module across devices but stay out of state_dict
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        if x.size(-1) != self.d_k:
            raise ValueError(f"Expected x.size(-1)==d_k=={self.d_k}, got {x.size(-1)}")
        pos = token_positions.to(device=x.device)
        # Gather the precomputed rows for each position: (..., seq_len, d_k / 2)
        cos = self.cos.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
        sin = self.sin.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
        x_fp32 = x.to(torch.float32)
        cos = cos.to(torch.float32)
        sin = sin.to(torch.float32)
        x_even = x_fp32[..., ::2]
        x_odd = x_fp32[..., 1::2]
        # Insert singleton dims (e.g. a heads axis) just before seq_len so cos/sin broadcast against x
        while cos.dim() < x_even.dim():
            cos = cos.unsqueeze(cos.dim() - 2)
            sin = sin.unsqueeze(sin.dim() - 2)
        # 2D rotation of each (even, odd) pair
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Re-interleave the pairs back into the original layout
        out = torch.stack((out_even, out_odd), dim=-1).flatten(-2)
        return out.to(dtype=x.dtype)
uv run pytest -k test_rope
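The reason RoPE is applied to queries and keys is that the dot product of two rotated vectors depends only on the difference of their positions. The following plain-Python sketch checks that property for one pair of vectors. It uses the same pairing and frequency convention as the module above; the function names rope and dot and the sample vectors are illustrative assumptions.

```python
import math

def rope(x, pos, theta=10000.0):
    """Rotate each pair (x[2k], x[2k+1]) by the angle pos * theta ** (-2k / d)."""
    d = len(x)
    if d % 2 != 0:
        raise ValueError("vector length must be even")
    out = []
    for k in range(d // 2):
        angle = pos * theta ** (-2.0 * k / d)
        c, s = math.cos(angle), math.sin(angle)
        x_even, x_odd = x[2 * k], x[2 * k + 1]
        out += [x_even * c - x_odd * s, x_even * s + x_odd * c]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.75, 1.0]
near = dot(rope(q, 5), rope(k, 3))     # positions 5 and 3, offset 2
far = dot(rope(q, 12), rope(k, 10))    # positions 12 and 10, offset 2
print(near, far)
```

Both scores agree up to floating-point error because the offset is 2 in each case, and since every pair undergoes a pure rotation the vector norms are unchanged.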
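6. Problem (softmax): Implement softmax (1 point)
Deliverable: Write a function that applies the softmax operation to a tensor.
To avoid numerical stability problems, use the following trick: along dimension i, subtract the maximum value along that dimension from every element, then compute the softmax.
import torch
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Numerically stable softmax over a given dimension.
    """
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing
    x_max = torch.amax(x, dim=dim, keepdim=True)
    z = x - x_max
    exp_z = torch.exp(z)
    sum_exp = torch.sum(exp_z, dim=dim, keepdim=True)
    return exp_z / sum_exp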
uv run pytest -k test_softmax_matches_pytorch
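The max-subtraction trick matters as soon as the logits get large. This plain-Python sketch contrasts a naive softmax with the stable version on inputs around 1000; the function names and values are illustrative assumptions.

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]   # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]

big = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(big)
    naive_overflowed = False
except OverflowError:
    # math.exp(1000.0) exceeds the range of a double
    naive_overflowed = True

probs = stable_softmax(big)
reference = naive_softmax([0.0, 1.0, 2.0])   # same logits shifted by 1000
print(naive_overflowed, probs)
```

Because softmax is invariant to adding a constant to every input, the stable result on the large logits matches the naive result on the shifted ones.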
7. Problem (scaled_dot_product_attention): Implement scaled dot-product attention (5 points)
Deliverable: Implement the scaled dot-product attention function.
Your implementation must also support an optional, user-provided boolean mask.
import math
import torch
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Scaled dot-product attention.
    In the mask, True means the query may attend to that key; False positions get zero weight.
    """
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError("query/key/value must have shape (..., seq_len, d_*)")
    if query.shape[:-2] != key.shape[:-2] or query.shape[:-2] != value.shape[:-2]:
        raise ValueError("batch dimensions of query, key, value must match")
    d_k = query.shape[-1]
    if d_k != key.shape[-1]:
        raise ValueError("query and key must have the same d_k")
    # Compute in float32 for stability, then cast back at the end
    q = query.to(torch.float32)
    k = key.to(torch.float32)
    v = value.to(torch.float32)
    scale = 1.0 / math.sqrt(d_k)
    logits = torch.einsum("... s d, ... t d -> ... s t", q, k) * scale
    if mask is not None:
        if mask.dtype != torch.bool:
            raise TypeError("mask must be a boolean tensor")
        # Use the most negative finite float32 rather than -inf so fully masked rows do not produce NaN
        neg_inf = torch.finfo(torch.float32).min
        logits = torch.where(mask.to(device=logits.device), logits, neg_inf)
    # softmax is the function from Problem 6
    probs = softmax(logits, dim=-1)
    if mask is not None:
        # Zero out masked positions explicitly (relevant when a whole row is masked)
        probs = probs * mask.to(device=probs.device, dtype=probs.dtype)
    out = torch.einsum("... s t, ... t d -> ... s d", probs, v)
    return out.to(dtype=value.dtype)
uv run pytest -k test_scaled_dot_product_attention
uv run pytest -k test_4d_scaled_dot_product_attention
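A tiny hand-checkable case helps confirm the mask convention (True means "may attend"). The sketch below is a plain-Python, single-head version over nested lists. Unlike the PyTorch code above it masks with -inf, so it assumes every query has at least one allowed key; the name attention and the sample matrices are illustrative.

```python
import math

def attention(Q, K, V, mask=None):
    """Q: (n, d_k), K: (m, d_k), V: (m, d_v) as nested lists; mask[i][j] True = query i sees key j."""
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        if mask is not None:
            scores = [s if mask[i][j] else -math.inf for j, s in enumerate(scores)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]   # exp(-inf) == 0.0
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out

Q = K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
causal = [[True, False],
          [True, True]]
out = attention(Q, K, V, causal)
print(out)
```

The first query can only see the first key, so its output is exactly the first value row; the second query mixes both rows and leans toward the key it matches.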
8. Problem (multihead_self_attention): Implement causal multi-head self-attention (5 points)
Deliverable: Implement a causal multi-head self-attention module.
Following [Vaswani+ 2017], set d_k = d_v = d_model / h.
import math
import torch
from torch import nn
class CausalMultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention (no RoPE).
    """
    def __init__(self, d_model: int, num_heads: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.head_dim = self.d_model // self.num_heads
        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.o_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        # Lower-triangular: position i may attend to positions j <= i
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.size(-1) != self.d_model:
            raise ValueError(f"Expected last dim {self.d_model}, got {x.size(-1)}")
        seq_len = x.size(-2)
        device = x.device
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # (..., seq_len, d_model) -> (..., num_heads, seq_len, head_dim)
        new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
        q = q.view(new_shape).transpose(-3, -2)
        k = k.view(new_shape).transpose(-3, -2)
        v = v.view(new_shape).transpose(-3, -2)
        mask = self._causal_mask(seq_len, device=device)
        out = scaled_dot_product_attention(q, k, v, mask=mask)
        # Merge the heads back: (..., num_heads, seq_len, head_dim) -> (..., seq_len, d_model)
        out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
        return self.o_proj(out)
class CausalMultiHeadSelfAttentionWithRoPE(nn.Module):
    """Causal multi-head self-attention with RoPE applied to Q and K (not V).
    """
    def __init__(self, d_model: int, num_heads: int, theta: float, max_seq_len: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.head_dim = self.d_model // self.num_heads
        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        # RoPE acts per head, so its rotation dimension is head_dim rather than d_model
        self.rope = RoPE(theta=theta, d_k=self.head_dim, max_seq_len=max_seq_len, device=device)
    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        if x.size(-1) != self.d_model:
            raise ValueError(f"Expected last dim {self.d_model}, got {x.size(-1)}")
        seq_len = x.size(-2)
        device = x.device
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # (..., seq_len, d_model) -> (..., num_heads, seq_len, head_dim)
        new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
        q = q.view(new_shape).transpose(-3, -2)
        k = k.view(new_shape).transpose(-3, -2)
        v = v.view(new_shape).transpose(-3, -2)
        # Rotate queries and keys only; values carry no positional rotation
        q = self.rope(q, token_positions)
        k = self.rope(k, token_positions)
        mask = self._causal_mask(seq_len, device=device)
        out = scaled_dot_product_attention(q, k, v, mask=mask)
        out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
        return self.output_proj(out)
uv run pytest -k test_multihead_self_attention
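Two bookkeeping steps in these modules are easy to verify by hand: the lower-triangular causal mask and the way a d_model-wide vector is cut into num_heads contiguous chunks of head_dim. A minimal plain-Python sketch, with hypothetical helper names causal_mask and split_heads:

```python
def causal_mask(seq_len):
    """mask[i][j] is True when query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

def split_heads(vec, num_heads):
    """Cut one d_model-wide vector into num_heads contiguous chunks of head_dim."""
    d_model = len(vec)
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    head_dim = d_model // num_heads
    return [vec[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]

mask = causal_mask(3)
heads = split_heads(list(range(6)), num_heads=3)
gpt2_xl_head_dim = 1600 // 25   # d_k = d_model / h for the GPT-2 XL-sized configuration
print(mask, heads, gpt2_xl_head_dim)
```

The contiguous chunking mirrors what view(..., num_heads, head_dim) does on the last dimension before the transpose moves the heads axis ahead of the sequence axis.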
9. Problem (transformer_block): Implement the Transformer block (3 points)
Deliverable: Implement the pre-norm Transformer block as described.
Structure: y = x + Attn(RMSNorm(x)), z = y + FFN(RMSNorm(y)).
import torch
from torch import nn
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block.
    Structure (pre-norm): y = x + Attn(RMSNorm(x)); z = y + FFN(RMSNorm(y))
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, *, max_seq_len: int, theta: float, eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        self.d_ff = int(d_ff)
        self.ln1 = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.attn = CausalMultiHeadSelfAttentionWithRoPE(
            d_model=self.d_model, num_heads=self.num_heads, theta=theta, max_seq_len=max_seq_len, device=device, dtype=dtype
        )
        self.ln2 = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.ffn = SwiGLU(self.d_model, d_ff=self.d_ff, device=device, dtype=dtype)
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        # Attention sublayer: normalize first, then add the result back onto the residual stream
        h = self.ln1(x)
        x = x + self.attn(h, token_positions)
        # Feed-forward sublayer, same pattern
        h = self.ln2(x)
        x = x + self.ffn(h)
        return x
uv run pytest -k test_transformer_block
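The pre-norm wiring can be isolated from the actual sublayers. In this plain-Python sketch the norms, attention, and FFN are passed in as toy callables (all hypothetical stand-ins), which makes the residual structure easy to check: if both sublayers output zeros, the block is the identity.

```python
def pre_norm_block(x, norm1, attn, norm2, ffn):
    """y = x + attn(norm1(x)); z = y + ffn(norm2(y)), on a flat list of floats."""
    y = [xi + ai for xi, ai in zip(x, attn(norm1(x)))]
    z = [yi + fi for yi, fi in zip(y, ffn(norm2(y)))]
    return z

def identity(v):
    return list(v)

def zeros(v):
    return [0.0] * len(v)

x = [1.0, -2.0, 3.0]
passthrough = pre_norm_block(x, identity, zeros, identity, zeros)
# With identity sublayers, each residual add doubles the stream: x -> 2x -> 4x
doubled_twice = pre_norm_block(x, identity, identity, identity, identity)
print(passthrough, doubled_twice)
```

Because the normalization sits inside each residual branch, the stream x itself is never normalized within the block, which is the defining property of the pre-norm layout.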
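10. Problem (transformer_lm): Implementing the Transformer LM (3 points)
Now we put all the modules together. First embed the input, then feed the result through num_layers Transformer blocks, and finally pass the output through the three output layers to obtain a distribution over the vocabulary. The implementation below returns unnormalized logits and leaves the softmax to the caller.
At a minimum, your implementation must accept all of the Transformer block construction parameters above, plus the following additional parameters:
vocab_size: int: size of the vocabulary
context_length: int: maximum context length
num_layers: int: number of Transformer blocks to use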
import torch
from torch import nn
from cs336_basics.modules import Embedding, Linear, RMSNorm, TransformerBlock
class TransformerLM(nn.Module):
    """A Transformer language model composed of: token embedding -> N pre-norm Transformer blocks -> final RMSNorm -> LM head.
    """
    def __init__(self, vocab_size: int, context_length: int, d_model: int, num_layers: int, num_heads: int, d_ff: int, *, rope_theta: float, max_seq_len: int | None = None, eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.vocab_size = int(vocab_size)
        self.context_length = int(context_length)
        self.d_model = int(d_model)
        self.num_layers = int(num_layers)
        # RoPE tables default to covering the full context window
        self.max_seq_len = int(max_seq_len if max_seq_len is not None else context_length)
        self.token_embeddings = Embedding(self.vocab_size, self.d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model, num_heads=num_heads, d_ff=d_ff,
                max_seq_len=self.max_seq_len, theta=rope_theta, eps=eps, device=device, dtype=dtype
            ) for _ in range(self.num_layers)
        ])
        self.ln_final = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.lm_head = Linear(self.d_model, self.vocab_size, device=device, dtype=dtype)
    def forward(self, in_indices: torch.Tensor) -> torch.Tensor:
        if in_indices.dim() != 2:
            raise ValueError(f"in_indices must have shape (batch, seq_len), got {tuple(in_indices.shape)}")
        batch, seq_len = in_indices.shape
        if seq_len > self.context_length:
            raise ValueError(f"seq_len={seq_len} exceeds context_length={self.context_length}")
        # Positions 0 .. seq_len - 1, shared by every sequence in the batch
        token_positions = torch.arange(seq_len, device=in_indices.device, dtype=torch.long).view(1, seq_len)
        token_positions = token_positions.expand(batch, seq_len)
        x = self.token_embeddings(in_indices)
        for block in self.layers:
            x = block(x, token_positions)
        x = self.ln_final(x)
        # Unnormalized logits of shape (batch, seq_len, vocab_size)
        logits = self.lm_head(x)
        return logits
uv run pytest -k test_transformer_lm
11. Problem (transformer_accounting): Transformer LM resource accounting (5 points)
It is very helpful to understand how much compute and memory each part of a Transformer consumes.
Rule: Given matrices A ∈ R^{m×n} and B ∈ R^{n×p}, the matrix product AB requires 2mnp FLOPs.
(a) GPT-2 XL configuration
vocab_size: 50,257
context_length: 1,024
num_layers: 48
d_model: 1,600
num_heads: 25
d_ff: 6,400
Parameter count: about 2.13B trainable parameters. With every parameter stored as float32 (4 bytes), loading the model alone takes about 8.51 GB.
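The 2.13B figure can be reproduced by summing the weight shapes of the modules implemented above. This is a minimal sketch that assumes the embedding and LM head are untied and that there are no bias terms, as in the code in this article; count_params is a hypothetical helper.

```python
def count_params(vocab_size, num_layers, d_model, d_ff):
    """Trainable parameters of the TransformerLM above (no biases, untied LM head)."""
    embedding = vocab_size * d_model
    attn = 4 * d_model * d_model      # q, k, v, and output projections
    ffn = 3 * d_model * d_ff          # W1, W2, W3 of SwiGLU
    norms = 2 * d_model               # two RMSNorm gains per block
    per_layer = attn + ffn + norms
    final_norm = d_model
    lm_head = d_model * vocab_size
    return embedding + num_layers * per_layer + final_norm + lm_head

total = count_params(vocab_size=50257, num_layers=48, d_model=1600, d_ff=6400)
gigabytes = total * 4 / 1e9           # float32 = 4 bytes per parameter
print(total, gigabytes)
```

The RoPE cos and sin tables are registered as buffers rather than parameters, so they do not enter the count.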
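(b) FLOPs accounting
Let S = context_length, d = d_model, and V = vocab_size. For one forward pass over a single sequence of length S, each layer costs:
Q/K/V/output projections (4 matmuls): F_proj = 4 · 2Sd² = 8Sd²
Attention scores QKᵀ (summed over all heads; each head costs 2S²·d_k and h·d_k = d): F_QKT = 2S²d
Attention-weighted values AV (summed over all heads): F_AV = 2S²d
So the two attention matmuls together cost: F_attn = 4S²d
SwiGLU FFN, three matmuls (W1, W3, W2): F_ffn = 6Sd·d_ff
The LM head is applied once, after the last layer: F_lm = 2SdV
Per-layer total ≈ 9.05969664 × 10¹⁰ FLOPs
All 48 layers ≈ 4.3486543872 × 10¹² FLOPs
LM head ≈ 1.646821376 × 10¹¹ FLOPs
Total forward FLOPs: ≈ 4.51 × 10¹² (about 4.51 TFLOPs)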
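These totals follow directly from the 2mnp rule. The sketch below recomputes them for the GPT-2 XL configuration, counting matrix multiplies only (norms, softmax, and elementwise operations are ignored); forward_flops is a hypothetical helper.

```python
def forward_flops(S, d, d_ff, V, num_layers):
    """Matmul FLOPs for one forward pass over a single sequence of length S."""
    proj = 4 * (2 * S * d * d)        # Q, K, V, and output projections
    attn = 2 * (2 * S * S * d)        # QK^T and AV, summed over all heads
    ffn = 3 * (2 * S * d * d_ff)      # W1, W3, W2
    per_layer = proj + attn + ffn
    lm_head = 2 * S * d * V
    return {
        "proj": proj,
        "attn": attn,
        "ffn": ffn,
        "per_layer": per_layer,
        "lm_head": lm_head,
        "total": num_layers * per_layer + lm_head,
    }

flops = forward_flops(S=1024, d=1600, d_ff=6400, V=50257, num_layers=48)
for name, value in flops.items():
    print(f"{name:>9}: {value:.4e}")
```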
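(c) Main consumers
In this configuration the SwiGLU FFN is the largest consumer (about 62.9B FLOPs per layer), followed by the Q/K/V/O projections (about 21.0B) and then the two attention matmuls, QKᵀ + AV (about 6.7B). The lm_head accounts for a comparatively small share.
(d) Comparison across model sizes
As the model grows (more layers, larger width), the share taken by the FFN and the projections, which scale with d·d_ff and d², rises. The share of the attention matmuls, which scale with S²·d, falls, and the share of the lm_head (~S·d·V) also drops noticeably.
(e) Effect of extending the context length
When context_length grows 16×, the FLOPs of the terms linear in S (projections, FFN, LM head) also grow 16×, while the attention matmuls (QKᵀ, AV) grow quadratically, that is, 256×. Total FLOPs rise from ≈ 4.51 × 10¹² to ≈ 1.495 × 10¹⁴ (about a 33.1× increase). The relative contributions shift as well: the attention matmuls become the dominant term (about 55%), and the shares of the FFN and the projections fall.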
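The 33.1× and 55% figures can be checked with the same matmul-only accounting, evaluated at S = 1,024 and S = 16,384. This is a minimal sketch; total_flops is a hypothetical helper with the GPT-2 XL dimensions as defaults.

```python
def total_flops(S, d=1600, d_ff=6400, V=50257, num_layers=48):
    """Return (total, attention_matmul_total) forward FLOPs for sequence length S."""
    linear_per_layer = 8 * S * d * d + 6 * S * d * d_ff   # projections + FFN, linear in S
    attn_per_layer = 4 * S * S * d                        # QK^T + AV, quadratic in S
    lm_head = 2 * S * d * V
    attn_total = num_layers * attn_per_layer
    total = num_layers * linear_per_layer + attn_total + lm_head
    return total, attn_total

base_total, base_attn = total_flops(1024)
long_total, long_attn = total_flops(16 * 1024)

ratio = long_total / base_total
attn_share_base = base_attn / base_total
attn_share_long = long_attn / long_total
print(f"growth: {ratio:.1f}x, attention share: {attn_share_base:.1%} -> {attn_share_long:.1%}")
```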