CS336 Building Language Models from Scratch: Transformer LM Architecture Implementation
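Summary (AI-generated): This article documents a from-scratch implementation of the Transformer language model for Assignment 1 of Stanford's CS336 course. It covers the code for the linear layer, embedding layer, RMSNorm, SwiGLU feed-forward network, RoPE positional encoding, softmax, scaled dot-product attention, and multi-head self-attention. These are then assembled into a complete Transformer block and Transformer LM, followed by an accounting of the parameter count and forward-pass FLOPs at GPT-2 XL scale that identifies each component's share of the compute.
神经兮兮 · Published 2026/4/6 · Updated 2026/5/20
Preface
This article documents the Transformer Language Model Architecture part of CS336 Assignment 1: Basics, from the basic modules up to the complete model.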
Assignment 1: https://github.com/stanford-cs336/assignment1-basics/tree/main
1. Problem (linear): Implementing the linear module (1 point)
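Deliverable: Implement a Linear class that inherits from torch.nn.Module and performs a linear transformation. Your implementation should follow the interface of PyTorch's built-in nn.Linear module, except that it has no bias argument or bias parameter.
The recommended interface is:
def __init__(self, in_features, out_features, device=None, dtype=None)
Constructs a linear transformation module. This function should accept the following parameters:
in_features: int: final dimension of the input
out_features: int: final dimension of the output
device: torch.device | None = None: device to store the parameters on
dtype: torch.dtype | None = None: data type of the parameters
def forward(self, x: torch.Tensor) -> torch.Tensor
Applies the linear transformation to the input tensor.
Make sure to:
Subclass nn.Module
Call the superclass constructor (super().__init__())
Construct and store the parameter as W (not W^T) for memory-ordering reasons, and wrap it in an nn.Parameter
Do not use nn.Linear or nn.functional.linear
For initialization, use the settings given earlier in the assignment together with torch.nn.init.trunc_normal_. In the code below this means a standard deviation of sqrt(2 / (in_features + out_features)), truncated at ±3 standard deviations.
The implementation is as follows: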
import math
import torch
from torch import nn

class Linear(nn.Module):
    """A bias-free linear layer that matches torch.nn.Linear's interface (except it has no bias).
    Stores the weight as W with shape (out_features, in_features).
    """
    def __init__(self, in_features: int, out_features: int,
                 device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.weight = nn.Parameter(
            torch.empty((self.out_features, self.in_features), device=device, dtype=dtype)
        )
        # Truncated normal with variance 2 / (d_in + d_out), clipped at +/- 3 sigma
        sigma = math.sqrt(2.0 / (self.in_features + self.out_features))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3.0 * sigma, b=3.0 * sigma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T, contracting over the last (input) dimension
        return torch.einsum("... i, o i -> ... o", x, self.weight)
The test adapter [adapters.run_linear] is implemented as follows:
def run_linear(
    d_in: int,
    d_out: int,
    weights: Float[Tensor, "d_out d_in"],
    in_features: Float[Tensor, "... d_in"],
) -> Float[Tensor, "... d_out"]:
    """Given the weights of a Linear layer, compute the transformation of a batched input."""
    from cs336_basics.modules import Linear
    layer = Linear(d_in, d_out, device=in_features.device, dtype=in_features.dtype)
    layer.load_state_dict({"weight": weights.to(device=in_features.device, dtype=in_features.dtype)})
    return layer(in_features)
Run uv run pytest -k test_linear to check the implementation (the output screenshot from the original post is not reproduced here).
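To make the weight layout concrete, here is a minimal pure-Python sketch (no PyTorch) of what the einsum pattern "... i, o i -> ... o" computes for a single input vector: each output element is the dot product of x with one row of W, i.e. y = x W^T with W stored as (out_features, in_features). The helper name linear_no_bias is hypothetical and exists only for this illustration.

```python
def linear_no_bias(x: list[float], weight: list[list[float]]) -> list[float]:
    """Compute y[o] = sum_i x[i] * weight[o][i].

    `weight` has shape (out_features, in_features), matching the layout
    used by the Linear module above. Illustration only, not the module itself.
    """
    return [sum(x_i * w_oi for x_i, w_oi in zip(x, row)) for row in weight]


W = [[1.0, 0.0, 2.0],    # row 0 produces output feature 0
     [0.0, 1.0, -1.0]]   # row 1 produces output feature 1
x = [3.0, 4.0, 5.0]

y = linear_no_bias(x, W)
print(y)  # [13.0, -1.0]
```

Storing W as (out_features, in_features) means each output feature reads one contiguous row of the weight matrix.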
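2. Problem (embedding): Implement the embedding module (1 point)
Deliverable: Implement an Embedding class that inherits from torch.nn.Module and performs an embedding lookup. Your implementation should follow the interface of PyTorch's built-in nn.Embedding module.
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None)
num_embeddings: int: size of the vocabulary
embedding_dim: int: dimension of the embedding vectors, i.e. d_model
device: torch.device | None = None: device to store the parameters on
dtype: torch.dtype | None = None: data type of the parameters
def forward(self, token_ids: torch.Tensor) -> torch.Tensor
Looks up and returns the embedding vectors for the given token IDs.
Make sure to:
Subclass nn.Module
Call the superclass constructor (super().__init__())
Initialize the embedding matrix and store it as an nn.Parameter
Store the embedding matrix with d_model as the final dimension
Do not use nn.Embedding or nn.functional.embedding
For initialization, again use the settings given earlier in the assignment together with torch.nn.init.trunc_normal_. In the code below this is a unit-variance normal truncated to [-3, 3].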
import torch
from torch import nn

class Embedding(nn.Module):
    """A learnable embedding lookup table, equivalent to torch.nn.Embedding.
    This module maps integer token IDs to continuous vectors of fixed dimensionality (embedding_dim).
    The embedding matrix is stored as a learnable parameter of shape (num_embeddings, embedding_dim).
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.weight = nn.Parameter(
            torch.empty((self.num_embeddings, self.embedding_dim), device=device, dtype=dtype)
        )
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3.0, b=3.0)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Integer-tensor indexing gathers one row per token ID
        return self.weight[token_ids]
The test adapter [adapters.run_embedding] is implemented as follows:
def run_embedding(
    vocab_size: int,
    d_model: int,
    weights: Float[Tensor, "vocab_size d_model"],
    token_ids: Int[Tensor, "..."],
) -> Float[Tensor, "... d_model"]:
    """Given the weights of an Embedding layer, get the embeddings for a batch of token ids."""
    from cs336_basics.modules import Embedding
    layer = Embedding(vocab_size, d_model, device=weights.device, dtype=weights.dtype)
    layer.load_state_dict({"weight": weights.to(device=weights.device, dtype=weights.dtype)})
    return layer(token_ids)
Run uv run pytest -k test_embedding to check the implementation (the output screenshot from the original post is not reproduced here).
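3. Problem (rmsnorm): Root Mean Square Layer Normalization (1 point)
Deliverable: Implement RMSNorm as a torch.nn.Module. The recommended interface is:
def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None)
Constructs the RMSNorm module. This function should accept the following parameters:
d_model: int: hidden dimension of the model
eps: float = 1e-5: epsilon value for numerical stability
device: torch.device | None = None: device to store the parameters on
dtype: torch.dtype | None = None: data type of the parameters
def forward(self, x: torch.Tensor) -> torch.Tensor
Processes an input tensor of shape (batch_size, sequence_length, d_model) and returns a tensor of the same shape.
Note: Remember to upcast the input to torch.float32 before normalizing, and to cast the result back to the original dtype afterwards, as described earlier in the assignment.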
import torch
from torch import nn

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm). For an input vector a in R^{d_model}:
        RMS(a) = sqrt(mean(a^2) + eps)
        RMSNorm(a) = (a / RMS(a)) * g
    """
    def __init__(self, d_model: int, eps: float = 1e-5,
                 device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones((self.d_model,), device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # Upcast so that squaring does not overflow in low-precision dtypes
        x_fp32 = x.to(torch.float32)
        rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        y = (x_fp32 / rms) * self.weight.to(torch.float32)
        return y.to(in_dtype)
The test adapter [adapters.run_rmsnorm] is implemented as follows:
def run_rmsnorm(
    d_model: int,
    eps: float,
    weights: Float[Tensor, "d_model"],
    in_features: Float[Tensor, "... d_model"],
) -> Float[Tensor, "... d_model"]:
    """Given the weights of a RMSNorm affine transform, return the output of running RMSNorm on the input features."""
    from cs336_basics.modules import RMSNorm
    layer = RMSNorm(d_model=d_model, eps=eps, device=in_features.device, dtype=weights.dtype)
    layer.load_state_dict({"weight": weights.to(device=in_features.device, dtype=weights.dtype)})
    return layer(in_features)
Run uv run pytest -k test_rmsnorm to check the implementation (the output screenshot from the original post is not reproduced here).
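As a small worked example of the RMSNorm formula, the sketch below applies it to a single vector in pure Python (no PyTorch). With a gain of all ones the output has root-mean-square close to 1, and rescaling the input by a constant barely changes the result. The function name rmsnorm_vector is a hypothetical helper for this illustration.

```python
import math


def rmsnorm_vector(a: list[float], gain: list[float], eps: float = 1e-5) -> list[float]:
    """RMSNorm(a) = a / sqrt(mean(a^2) + eps) * gain, for one vector."""
    mean_sq = sum(v * v for v in a) / len(a)
    rms = math.sqrt(mean_sq + eps)
    return [v / rms * g for v, g in zip(a, gain)]


a = [3.0, -4.0, 0.0, 0.0]
ones = [1.0] * len(a)

out = rmsnorm_vector(a, ones)
out_scaled = rmsnorm_vector([10.0 * v for v in a], ones)

# Root-mean-square of the normalized vector is approximately 1
out_rms = math.sqrt(sum(v * v for v in out) / len(out))
print(out_rms)

# Normalization is (almost) invariant to rescaling the input; eps causes a tiny difference
print(max(abs(p - q) for p, q in zip(out, out_scaled)))
```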
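4. Problem (positionwise_feedforward): Implement the position-wise feed-forward network (2 points)
Deliverable: Implement the SwiGLU feed-forward network, composed of a SiLU activation function and a GLU (gated linear unit).
Note: In this particular case, you may use torch.sigmoid directly in your implementation for numerical stability.
Set the inner dimension d_ff to approximately 8/3 × d_model, while ensuring that the inner feed-forward dimension is a multiple of 64 to make good use of the hardware.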
import math
import torch
from torch import nn

def round_up_to_multiple(x: int, multiple: int) -> int:
    """Round x up to the nearest positive multiple of `multiple`."""
    if multiple <= 0:
        raise ValueError("multiple must be a positive integer")
    return int(((x + multiple - 1) // multiple) * multiple)

def default_d_ff(d_model: int, multiple_of: int = 64) -> int:
    """Compute the recommended SwiGLU hidden size.
    We use d_ff ~= (8/3) * d_model and then round up to a hardware-friendly multiple (typically 64).
    """
    raw = int(math.ceil((8.0 * d_model) / 3.0))
    return round_up_to_multiple(raw, multiple_of)

class SwiGLU(nn.Module):
    """Position-wise feed-forward network using the SwiGLU nonlinearity.
    The transformation is: FFN(x) = W2( SiLU(W1 x) ⊙ (W3 x) )
    where SiLU(z) = z * sigmoid(z), and ⊙ is elementwise multiplication.
    """
    def __init__(self, d_model: int, d_ff: int | None = None, *, multiple_of: int = 64,
                 device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.d_ff = int(d_ff) if d_ff is not None else default_d_ff(self.d_model, multiple_of)
        self.w1 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)  # gate projection
        self.w2 = Linear(self.d_ff, self.d_model, device=device, dtype=dtype)  # down projection
        self.w3 = Linear(self.d_model, self.d_ff, device=device, dtype=dtype)  # up projection

    @staticmethod
    def silu(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.w1(x)
        b = self.w3(x)
        gated = self.silu(a) * b
        return self.w2(gated)
The test adapter [adapters.run_swiglu] is implemented as follows:
def run_swiglu(
    d_model: int,
    d_ff: int,
    w1_weight: Float[Tensor, "d_ff d_model"],
    w2_weight: Float[Tensor, "d_model d_ff"],
    w3_weight: Float[Tensor, "d_ff d_model"],
    in_features: Float[Tensor, "... d_model"],
) -> Float[Tensor, "... d_model"]:
    """Given the weights of a SwiGLU network, return the output of your implementation with these weights."""
    from cs336_basics.modules import SwiGLU
    swiglu = SwiGLU(d_model=d_model, d_ff=d_ff, device=in_features.device, dtype=w1_weight.dtype)
    swiglu.load_state_dict({
        "w1.weight": w1_weight.to(device=in_features.device, dtype=w1_weight.dtype),
        "w2.weight": w2_weight.to(device=in_features.device, dtype=w2_weight.dtype),
        "w3.weight": w3_weight.to(device=in_features.device, dtype=w3_weight.dtype),
    })
    return swiglu(in_features)
Run uv run pytest -k test_swiglu to check the implementation (the output screenshot from the original post is not reproduced here).
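The 8/3 factor is commonly motivated by parameter matching: a two-matrix FFN with d_ff = 4 × d_model has 8 × d_model² weights, and a three-matrix SwiGLU with d_ff = 8/3 × d_model has the same count before rounding. The sketch below checks this arithmetic and shows the rounded hidden sizes for a few widths. The function names are hypothetical and mirror, rather than import, the helpers above.

```python
import math


def swiglu_hidden_size(d_model: int, multiple_of: int = 64) -> int:
    """ceil(8/3 * d_model), rounded up to a multiple of `multiple_of`."""
    raw = math.ceil(8 * d_model / 3)
    return ((raw + multiple_of - 1) // multiple_of) * multiple_of


def ffn_weight_count(d_model: int, d_ff: int, num_matrices: int) -> int:
    """Each FFN matrix has d_model * d_ff entries (no biases)."""
    return num_matrices * d_model * d_ff


for d_model in (512, 768, 1600):
    d_ff = swiglu_hidden_size(d_model)
    swiglu = ffn_weight_count(d_model, d_ff, num_matrices=3)
    classic = ffn_weight_count(d_model, 4 * d_model, num_matrices=2)
    print(d_model, d_ff, swiglu, classic)
```

Where 8/3 × d_model is already a multiple of 64 (for example d_model = 768, giving 2048) the two counts agree exactly. Otherwise rounding up makes SwiGLU slightly larger.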
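5. Problem (rope): Implement RoPE (2 points)
Deliverable: Implement a RotaryPositionalEmbedding class that applies RoPE (rotary position embedding) to the input tensor. The recommended interface is:
def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None)
Constructs the RoPE module and creates buffers if needed. The constructor should accept the following parameters:
theta: float: the constant Θ used in RoPE
d_k: int: dimension of the query and key vectors
max_seq_len: int: maximum sequence length that will be input
device: torch.device | None = None: device to store the buffers on
def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor
Processes an input tensor of shape (..., seq_len, d_k) and returns a tensor of the same shape.
Your implementation should support an arbitrary number of batch dimensions, i.e. x may have any number of batch dimensions before seq_len
You may assume that token_positions is a tensor of shape (..., seq_len) that specifies the position of each token along the sequence dimension
Use token_positions to slice your (possibly precomputed) cos and sin tensors along the sequence dimension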
import torch
from torch import nn

class RoPE(nn.Module):
    """Rotary Positional Embeddings (RoPE).
    Applies a position-dependent rotation to the last dimension (d_k) of an input tensor.
    The rotation is applied pairwise on (x[..., 0], x[..., 1]), (x[..., 2], x[..., 3]), ...
    This module has no learnable parameters. It precomputes and caches the cos/sin tables.
    """
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        super().__init__()
        if d_k % 2 != 0:
            raise ValueError(f"d_k must be even for RoPE, got d_k={d_k}")
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.theta = float(theta)
        self.d_k = int(d_k)
        self.max_seq_len = int(max_seq_len)
        # One frequency per pair of dimensions: theta^(-2k/d_k) for k = 0 .. d_k/2 - 1
        pair_idx = torch.arange(0, self.d_k, 2, device=device, dtype=torch.float32)
        inv_freq = self.theta ** (-pair_idx / self.d_k)
        positions = torch.arange(self.max_seq_len, device=device, dtype=torch.float32)
        angles = positions[:, None] * inv_freq[None, :]  # (max_seq_len, d_k / 2)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        # Non-persistent buffers: moved with the module but excluded from state_dict
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (..., seq_len, d_k)
            token_positions: Tensor of shape (..., seq_len) with integer positions
        Returns:
            Tensor of shape (..., seq_len, d_k) after applying RoPE.
        """
        if x.size(-1) != self.d_k:
            raise ValueError(f"Expected x.size(-1)==d_k=={self.d_k}, got {x.size(-1)}")
        pos = token_positions.to(device=x.device)
        # Gather the cached rows for the requested positions: (..., seq_len, d_k / 2)
        cos = self.cos.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
        sin = self.sin.index_select(0, pos.reshape(-1)).reshape(*pos.shape, -1)
        x_fp32 = x.to(torch.float32)
        cos = cos.to(torch.float32)
        sin = sin.to(torch.float32)
        x_even = x_fp32[..., ::2]
        x_odd = x_fp32[..., 1::2]
        # Insert singleton dims (e.g. for heads) just before seq_len until cos/sin broadcast against x
        while cos.dim() < x_even.dim():
            cos = cos.unsqueeze(cos.dim() - 2)
            sin = sin.unsqueeze(sin.dim() - 2)
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Re-interleave the even and odd halves back into the last dimension
        out = torch.stack((out_even, out_odd), dim=-1).flatten(-2)
        return out.to(dtype=x.dtype)
The test adapter [adapters.run_rope] is implemented as follows:
def run_rope(
    d_k: int,
    theta: float,
    max_seq_len: int,
    in_query_or_key: Float[Tensor, "... sequence_length d_k"],
    token_positions: Int[Tensor, "... sequence_length"],
) -> Float[Tensor, "... sequence_length d_k"]:
    """Run RoPE for a given input tensor."""
    from cs336_basics.modules import RoPE
    rope = RoPE(theta=theta, d_k=d_k, max_seq_len=max_seq_len, device=in_query_or_key.device)
    return rope(in_query_or_key, token_positions)
Run uv run pytest -k test_rope to check the implementation (the output screenshot from the original post is not reproduced here).
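A useful property to check is that, after RoPE, the query-key dot product depends only on the relative offset between the two positions. The pure-Python sketch below rotates one query and one key at positions (7, 3) and (107, 103) and compares the scores. It uses the same pairwise rotation and frequencies theta^(-2k/d_k) as the module above. The helper names and the sample vectors are made up for illustration.

```python
import math


def rope_rotate(x: list[float], position: int, theta: float = 10000.0) -> list[float]:
    """Rotate consecutive pairs (x[2k], x[2k+1]) by angle position * theta^(-2k/d_k)."""
    d_k = len(x)
    out: list[float] = []
    for pair in range(d_k // 2):
        angle = position * theta ** (-2 * pair / d_k)
        c, s = math.cos(angle), math.sin(angle)
        x_even, x_odd = x[2 * pair], x[2 * pair + 1]
        out += [x_even * c - x_odd * s, x_even * s + x_odd * c]
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.3, -0.7, 1.0]

# Same relative offset (4), different absolute positions
score_near = dot(rope_rotate(q, 7), rope_rotate(k, 3))
score_far = dot(rope_rotate(q, 107), rope_rotate(k, 103))
print(score_near, score_far)

# Rotations also preserve the vector norm
print(dot(q, q), dot(rope_rotate(q, 42), rope_rotate(q, 42)))
```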
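6. Problem (softmax): Implement softmax (1 point)
Deliverable: Write a function that applies the softmax operation to a tensor. Your function should take two arguments: an input tensor and a dimension index i, and apply softmax along the i-th dimension of the input.
The output tensor should have the same shape as the input, but its i-th dimension now holds a normalized probability distribution. To avoid numerical stability issues, subtract the maximum value along the i-th dimension from every element of that dimension before computing softmax.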
import torch

def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Numerically stable softmax over a given dimension.
    This implementation subtracts the maximum value along `dim` before exponentiation
    to improve numerical stability.
    """
    x_max = torch.amax(x, dim=dim, keepdim=True)
    z = x - x_max  # largest entry becomes 0, so exp never overflows
    exp_z = torch.exp(z)
    sum_exp = torch.sum(exp_z, dim=dim, keepdim=True)
    return exp_z / sum_exp
The test adapter [adapters.run_softmax] is implemented as follows:
def run_softmax(in_features: Float[Tensor, "..."], dim: int) -> Float[Tensor, "..."]:
    """Given a tensor of inputs, return the output of softmaxing the given `dim` of the input."""
    from cs336_basics.modules import softmax
    return softmax(in_features, dim)
Run uv run pytest -k test_softmax to check the implementation (the output screenshot from the original post is not reproduced here).
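To see why the max-subtraction matters, the following pure-Python sketch compares a naive softmax with the stable variant on large logits. In double precision, math.exp(1000.0) overflows, whereas the shifted version only ever exponentiates values that are at most 0. The function names are hypothetical helpers for this illustration.

```python
import math


def softmax_naive(xs: list[float]) -> list[float]:
    exps = [math.exp(v) for v in xs]  # overflows for large v
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]  # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 1001.0, 1002.0]

try:
    softmax_naive(logits)
    naive_succeeded = True
except OverflowError:
    naive_succeeded = False

stable = softmax_stable(logits)
print(naive_succeeded)
print(stable)

# Softmax is shift-invariant, so this matches the result for [0, 1, 2]
reference = softmax_naive([0.0, 1.0, 2.0])
```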
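7. Problem (scaled_dot_product_attention): Implement scaled dot-product attention (5 points)
Deliverable: Implement the scaled dot-product attention function. Your implementation should handle the following inputs:
Query and Key of shape (batch_size, …, seq_len, d_k)
Value of shape (batch_size, …, seq_len, d_v)
Here '…' stands for any number of additional batch-like dimensions. The function should return an output tensor of shape (batch_size, …, seq_len, d_v).
Your implementation should also support an optional, user-provided boolean mask of shape (seq_len, seq_len). The attention probabilities at positions where the mask is True should jointly sum to 1 along that dimension, and the attention probabilities at positions where the mask is False should be 0.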
import math
import torch

def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Scaled dot-product attention."""
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError("query/key/value must have shape (..., seq_len, d_*)")
    if query.shape[:-2] != key.shape[:-2] or query.shape[:-2] != value.shape[:-2]:
        raise ValueError("batch dimensions of query, key, value must match")
    d_k = query.shape[-1]
    if d_k != key.shape[-1]:
        raise ValueError("query and key must have the same d_k")
    # Compute in float32 for stability, cast back at the end
    q = query.to(torch.float32)
    k = key.to(torch.float32)
    v = value.to(torch.float32)
    scale = 1.0 / math.sqrt(d_k)
    logits = torch.einsum("... s d, ... t d -> ... s t", q, k) * scale
    if mask is not None:
        if mask.dtype != torch.bool:
            raise TypeError("mask must be a boolean tensor")
        # Use the most negative finite float32 rather than -inf for masked positions
        neg_inf = torch.finfo(torch.float32).min
        logits = torch.where(mask.to(device=logits.device), logits, neg_inf)
    probs = softmax(logits, dim=-1)
    if mask is not None:
        # Force masked positions to exactly zero probability
        probs = probs * mask.to(device=probs.device, dtype=probs.dtype)
    out = torch.einsum("... s t, ... t d -> ... s d", probs, v)
    return out.to(dtype=value.dtype)
The test adapter [adapters.run_scaled_dot_product_attention] is implemented as follows:
def run_scaled_dot_product_attention(
    Q: Float[Tensor, "... queries d_k"],
    K: Float[Tensor, "... keys d_k"],
    V: Float[Tensor, "... values d_v"],
    mask: Bool[Tensor, "... queries keys"] | None = None,
) -> Float[Tensor, "... queries d_v"]:
    """Given key (K), query (Q), and value (V) tensors, return the output of your scaled dot product attention implementation."""
    from cs336_basics.modules import scaled_dot_product_attention
    return scaled_dot_product_attention(query=Q, key=K, value=V, mask=mask)
Run uv run pytest -k test_scaled_dot_product_attention and uv run pytest -k test_4d_scaled_dot_product_attention to check the implementation (the output screenshots from the original post are not reproduced here).
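The effect of the boolean mask is easiest to see on a tiny example. The pure-Python sketch below runs single-head attention over three positions with a lower-triangular (causal) mask. It uses -inf for masked scores, whereas the implementation above uses the most negative finite float32. It also assumes every row of the mask has at least one True entry. All names and values are made up for illustration.

```python
import math


def attention(q, k, v, mask=None):
    """Single-head scaled dot-product attention on nested lists.

    q: (n_queries, d_k), k: (n_keys, d_k), v: (n_keys, d_v),
    mask[i][j] is True where query i may attend to key j.
    """
    d_k = len(q[0])
    outputs = []
    for i, q_i in enumerate(q):
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d_k) for k_j in k]
        if mask is not None:
            scores = [s if mask[i][j] else -math.inf for j, s in enumerate(scores)]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # exp(-inf) == 0.0
        total = sum(exps)
        probs = [e / total for e in exps]
        outputs.append([sum(p * v_j[c] for p, v_j in zip(probs, v)) for c in range(len(v[0]))])
    return outputs


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]
v = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
causal = [[j <= i for j in range(3)] for i in range(3)]  # True on and below the diagonal

out = attention(q, k, v, mask=causal)
print(out[0])  # position 0 can only see itself, so this is v[0]: [10.0, 0.0]
```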
8. Problem (multihead_self_attention): Implement causal multi-head self-attention (5 points)
Deliverable: Implement causal multi-head self-attention as a torch.nn.Module. Your implementation should accept at least the following parameters:
d_model: int: dimensionality of the Transformer block inputs
num_heads: int: number of heads to use in multi-head self-attention
Following [Vaswani+ 2017], set d_k = d_v = d_model / h, where h is the number of heads.
import math
import torch
from torch import nn

class CausalMultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention (no RoPE). This module computes:
        Q = W_Q x, K = W_K x, V = W_V x
        heads = SDPA(Q_heads, K_heads, V_heads, causal_mask)
        out = W_O concat(heads)
    """
    def __init__(self, d_model: int, num_heads: int,
                 device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.head_dim = self.d_model // self.num_heads
        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.o_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Build a (seq_len, seq_len) causal mask where True means 'allowed'."""
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (..., seq_len, d_model)
        Returns:
            Tensor of shape (..., seq_len, d_model)
        """
        if x.size(-1) != self.d_model:
            raise ValueError(f"Expected last dim {self.d_model}, got {x.size(-1)}")
        seq_len = x.size(-2)
        device = x.device
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # (..., seq_len, d_model) -> (..., num_heads, seq_len, head_dim)
        new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
        q = q.view(new_shape).transpose(-3, -2)
        k = k.view(new_shape).transpose(-3, -2)
        v = v.view(new_shape).transpose(-3, -2)
        mask = self._causal_mask(seq_len, device=device)
        out = scaled_dot_product_attention(q, k, v, mask=mask)
        # Merge heads back: (..., num_heads, seq_len, head_dim) -> (..., seq_len, d_model)
        out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
        return self.o_proj(out)

class CausalMultiHeadSelfAttentionWithRoPE(nn.Module):
    """Causal multi-head self-attention with RoPE applied to Q and K (not V).
    Q, K, and V use separate projections (q_proj, k_proj, v_proj).
    """
    def __init__(self, d_model: int, num_heads: int, theta: float, max_seq_len: int,
                 device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.head_dim = self.d_model // self.num_heads
        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
        # RoPE acts on each head separately, so its dimension is head_dim
        self.rope = RoPE(theta=theta, d_k=self.head_dim, max_seq_len=max_seq_len, device=device)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Build a (seq_len, seq_len) causal mask where True means 'allowed'."""
        return torch.tril(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool))

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (..., seq_len, d_model)
            token_positions: Tensor of shape (..., seq_len)
        Returns:
            Tensor of shape (..., seq_len, d_model)
        """
        if x.size(-1) != self.d_model:
            raise ValueError(f"Expected last dim {self.d_model}, got {x.size(-1)}")
        seq_len = x.size(-2)
        device = x.device
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # (..., seq_len, d_model) -> (..., num_heads, seq_len, head_dim)
        new_shape = q.shape[:-1] + (self.num_heads, self.head_dim)
        q = q.view(new_shape).transpose(-3, -2)
        k = k.view(new_shape).transpose(-3, -2)
        v = v.view(new_shape).transpose(-3, -2)
        # Rotate queries and keys only; values are left unchanged
        q = self.rope(q, token_positions)
        k = self.rope(k, token_positions)
        mask = self._causal_mask(seq_len, device=device)
        out = scaled_dot_product_attention(q, k, v, mask=mask)
        out = out.transpose(-3, -2).contiguous().view(x.shape[:-1] + (self.d_model,))
        return self.output_proj(out)
The test adapters [adapters.run_multihead_self_attention] and [adapters.run_multihead_self_attention_with_rope] are implemented as follows:
def run_multihead_self_attention(
    d_model: int,
    num_heads: int,
    q_proj_weight: Float[Tensor, "d_k d_in"],
    k_proj_weight: Float[Tensor, "d_k d_in"],
    v_proj_weight: Float[Tensor, "d_v d_in"],
    o_proj_weight: Float[Tensor, "d_model d_v"],
    in_features: Float[Tensor, "... sequence_length d_in"],
) -> Float[Tensor, "... sequence_length d_out"]:
    """Given the key, query, and value projection weights of a naive unbatched implementation of multi-head attention, return the output of an optimized batched implementation."""
    from cs336_basics.modules import CausalMultiHeadSelfAttention
    mha = CausalMultiHeadSelfAttention(
        d_model=d_model, num_heads=num_heads, device=in_features.device, dtype=q_proj_weight.dtype
    )
    mha.load_state_dict({
        "q_proj.weight": q_proj_weight.to(device=in_features.device, dtype=q_proj_weight.dtype),
        "k_proj.weight": k_proj_weight.to(device=in_features.device, dtype=k_proj_weight.dtype),
        "v_proj.weight": v_proj_weight.to(device=in_features.device, dtype=v_proj_weight.dtype),
        "o_proj.weight": o_proj_weight.to(device=in_features.device, dtype=o_proj_weight.dtype),
    })
    return mha(in_features)

def run_multihead_self_attention_with_rope(
    d_model: int,
    num_heads: int,
    max_seq_len: int,
    theta: float,
    q_proj_weight: Float[Tensor, "d_k d_in"],
    k_proj_weight: Float[Tensor, "d_k d_in"],
    v_proj_weight: Float[Tensor, "d_v d_in"],
    o_proj_weight: Float[Tensor, "d_model d_v"],
    in_features: Float[Tensor, "... sequence_length d_in"],
    token_positions: Int[Tensor, "... sequence_length"] | None = None,
) -> Float[Tensor, "... sequence_length d_out"]:
    """Given the key, query, and value projection weights of a naive unbatched implementation of multi-head attention, return the output of an optimized batched implementation. This version of MHA should include RoPE."""
    from cs336_basics.modules import CausalMultiHeadSelfAttentionWithRoPE
    if token_positions is None:
        # Default to positions 0 .. seq_len - 1, broadcastable over the batch dimensions
        seq_len = in_features.size(-2)
        token_positions = torch.arange(seq_len, device=in_features.device, dtype=torch.long)
        token_positions = token_positions.view(*([1] * (in_features.dim() - 2)), seq_len)
    mha = CausalMultiHeadSelfAttentionWithRoPE(
        d_model=d_model, num_heads=num_heads, theta=theta, max_seq_len=max_seq_len,
        device=in_features.device, dtype=q_proj_weight.dtype,
    )
    mha.load_state_dict({
        "q_proj.weight": q_proj_weight.to(device=in_features.device, dtype=q_proj_weight.dtype),
        "k_proj.weight": k_proj_weight.to(device=in_features.device, dtype=k_proj_weight.dtype),
        "v_proj.weight": v_proj_weight.to(device=in_features.device, dtype=v_proj_weight.dtype),
        "output_proj.weight": o_proj_weight.to(device=in_features.device, dtype=o_proj_weight.dtype),
    })
    return mha(in_features, token_positions)
Run uv run pytest -k test_multihead_self_attention to check the implementation (the output screenshot from the original post is not reproduced here).
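9. Problem (transformer_block): Implement the Transformer block (3 points)
Implement the pre-norm Transformer block as described in §1.5 and illustrated in Figure 2. Your Transformer block should accept at least the following parameters:
d_model: int: dimensionality of the Transformer block inputs
num_heads: int: number of heads to use in multi-head self-attention
d_ff: int: dimensionality of the position-wise feed-forward inner layer
To test your implementation, implement the adapter [adapters.run_transformer_block], then run:
uv run pytest -k test_transformer_block
Deliverable: Transformer block code that passes the provided tests.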
import torch
from torch import nn

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block. Structure (pre-norm):
        y = x + Attn(RMSNorm(x))
        z = y + FFN(RMSNorm(y))
    This block uses causal multi-head self-attention with RoPE.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, *, max_seq_len: int, theta: float,
                 eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = int(d_model)
        self.num_heads = int(num_heads)
        self.d_ff = int(d_ff)
        self.ln1 = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.attn = CausalMultiHeadSelfAttentionWithRoPE(
            d_model=self.d_model, num_heads=self.num_heads, theta=theta,
            max_seq_len=max_seq_len, device=device, dtype=dtype,
        )
        self.ln2 = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.ffn = SwiGLU(self.d_model, d_ff=self.d_ff, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, d_model)
            token_positions: Tensor of shape (batch, seq_len) or broadcastable to it
        Returns:
            Tensor of shape (batch, seq_len, d_model)
        """
        # Attention sub-layer with residual connection
        h = self.ln1(x)
        x = x + self.attn(h, token_positions)
        # Feed-forward sub-layer with residual connection
        h = self.ln2(x)
        x = x + self.ffn(h)
        return x
The test adapter [adapters.run_transformer_block] is implemented as follows:
def run_transformer_block(
    d_model: int,
    num_heads: int,
    d_ff: int,
    max_seq_len: int,
    theta: float,
    weights: dict[str, Tensor],
    in_features: Float[Tensor, "batch sequence_length d_model"],
) -> Float[Tensor, "batch sequence_length d_model"]:
    """Given the weights of a pre-norm Transformer block and input features, return the output of running the Transformer block on the input features. This function should use RoPE."""
    from cs336_basics.modules import TransformerBlock
    block = TransformerBlock(
        d_model=d_model, num_heads=num_heads, d_ff=d_ff, max_seq_len=max_seq_len, theta=theta,
        device=in_features.device, dtype=in_features.dtype,
    )
    sd = {k: v.to(device=in_features.device, dtype=in_features.dtype) for k, v in weights.items()}
    block.load_state_dict(sd)
    batch, seq_len, _ = in_features.shape
    token_position = torch.arange(seq_len, device=in_features.device, dtype=torch.long).view(1, seq_len).expand(batch, seq_len)
    return block(in_features, token_position)
Run uv run pytest -k test_transformer_block to check the implementation (the output screenshot from the original post is not reproduced here).
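10. Problem (transformer_lm): Implementing the Transformer LM (3 points)
We now put all the modules together, following the high-level diagram in Figure 1. As described in §1.1.1, we first embed the input, then feed the result through num_layers Transformer blocks, and finally pass the output through the three output layers to obtain a distribution over the vocabulary.
Time to put it all together! Implement the Transformer language model as described in §1.1 and illustrated in Figure 1. At minimum, your implementation should accept all of the construction parameters of the Transformer block above, plus these additional parameters:
vocab_size: int: size of the vocabulary, which determines the dimensions of the token embedding matrix
context_length: int: maximum context length, which determines the dimensions of the position embedding matrix
num_layers: int: number of Transformer blocks to use
To test your implementation against the provided tests, first implement the adapter [adapters.run_transformer_lm], then run:
uv run pytest -k test_transformer_lm
Deliverable: A Transformer LM module that passes the above tests.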
import torch
from torch import nn
from cs336_basics.modules import Embedding, Linear, RMSNorm, TransformerBlock

class TransformerLM(nn.Module):
    """A Transformer language model composed of:
        token embedding -> N pre-norm Transformer blocks -> final RMSNorm -> LM head.
    This implementation uses RoPE inside each TransformerBlock's attention module.
    """
    def __init__(self, vocab_size: int, context_length: int, d_model: int, num_layers: int,
                 num_heads: int, d_ff: int, *, rope_theta: float, max_seq_len: int | None = None,
                 eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.vocab_size = int(vocab_size)
        self.context_length = int(context_length)
        self.d_model = int(d_model)
        self.num_layers = int(num_layers)
        # The RoPE tables default to covering the full context length
        self.max_seq_len = int(max_seq_len if max_seq_len is not None else context_length)
        self.token_embeddings = Embedding(self.vocab_size, self.d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model, num_heads=num_heads, d_ff=d_ff, max_seq_len=self.max_seq_len,
                theta=rope_theta, eps=eps, device=device, dtype=dtype,
            )
            for _ in range(self.num_layers)
        ])
        self.ln_final = RMSNorm(self.d_model, eps=eps, device=device, dtype=dtype)
        self.lm_head = Linear(self.d_model, self.vocab_size, device=device, dtype=dtype)

    def forward(self, in_indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            in_indices: LongTensor of shape (batch, seq_len)
        Returns:
            logits: Tensor of shape (batch, seq_len, vocab_size)
        """
        if in_indices.dim() != 2:
            raise ValueError(f"in_indices must have shape (batch, seq_len), got {tuple(in_indices.shape)}")
        batch, seq_len = in_indices.shape
        if seq_len > self.context_length:
            raise ValueError(f"seq_len={seq_len} exceeds context_length={self.context_length}")
        token_positions = torch.arange(seq_len, device=in_indices.device, dtype=torch.long).view(1, seq_len)
        token_positions = token_positions.expand(batch, seq_len)
        x = self.token_embeddings(in_indices)
        for block in self.layers:
            x = block(x, token_positions)
        x = self.ln_final(x)
        # Returns unnormalized logits; softmax is applied by the caller or the loss
        logits = self.lm_head(x)
        return logits
The test adapter [adapters.run_transformer_lm] is implemented as follows:
def run_transformer_lm(
    vocab_size: int,
    context_length: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    d_ff: int,
    rope_theta: float,
    weights: dict[str, Tensor],
    in_indices: Int[Tensor, "batch_size sequence_length"],
) -> Float[Tensor, "batch_size sequence_length vocab_size"]:
    r"""Given the weights of a Transformer language model and input indices, return the output of running a forward pass on the input indices. This function should use RoPE."""
    from cs336_basics.transformer_lm import TransformerLM
    model = TransformerLM(
        vocab_size=vocab_size, context_length=context_length, d_model=d_model, num_layers=num_layers,
        num_heads=num_heads, d_ff=d_ff, rope_theta=rope_theta,
        device=in_indices.device, dtype=torch.float32,
    )
    sd = {k: v.to(device=in_indices.device, dtype=torch.float32) for k, v in weights.items()}
    model.load_state_dict(sd)
    return model(in_indices)
Run uv run pytest -k test_transformer_lm to check the implementation (the output screenshot from the original post is not reproduced here).
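11. Problem (transformer_accounting): Transformer LM resource accounting (5 points)
It is useful to understand how much compute and memory each part of the Transformer consumes. We will go through a basic FLOPs (floating-point operations) accounting in a few steps.
Since the vast majority of FLOPs in a Transformer come from matrix multiplications, the core approach is simple:
List all the matrix multiplications in a Transformer forward pass
Convert each matrix multiplication into the number of FLOPs it requires
Rule: Given matrices A ∈ R^(m×n) and B ∈ R^(n×p), the matrix product AB requires 2mnp FLOPs.
This is because (AB)[i, j] = A[i, :] · B[:, j], and this dot product takes n additions and n multiplications, i.e. 2n FLOPs. Since AB has m × p entries, the total is (2n)(mp) = 2mnp FLOPs.
Before moving on to the next question, it helps to go through each component of your Transformer block and Transformer LM and list all of the matrix multiplications together with their FLOPs costs.
(a) Consider GPT-2 XL, which has the following configuration:
vocab_size: 50,257
context_length: 1,024
num_layers: 48
d_model: 1,600
num_heads: 25
d_ff: 6,400
How many trainable parameters does this model have?
If every parameter is stored in single-precision floating point (float32), how much memory is needed just to load the model?
Parameter breakdown (d = d_model, V = vocab_size):
Token embedding: V · d
LM head: V · d
Attention per layer (Q/K/V/O): 4 · d^2
SwiGLU FFN per layer (W1/W3 up-projections + W2 down-projection): 3 · d · d_ff
RMSNorm per layer (ln1/ln2): 2 · d
Final RMSNorm: d
Total parameters:
N = 48 · (4d^2 + 3·d·d_ff + 2d) + 2Vd + d = 2,127,057,600
Deliverable: The model has about 2.13B trainable parameters. At 4 bytes per float32 parameter, just loading it takes about 8.51 GB.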
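The parameter count above can be reproduced with a few lines of arithmetic. The sketch below follows the same breakdown (untied token embedding and LM head, bias-free linears, one gain vector per RMSNorm). The function name count_parameters is a hypothetical helper for this illustration.

```python
def count_parameters(vocab_size: int, d_model: int, num_layers: int, d_ff: int) -> int:
    """Trainable parameters of the Transformer LM described in this article."""
    attention = 4 * d_model * d_model   # Q, K, V, O projections
    ffn = 3 * d_model * d_ff            # W1, W3 (up) and W2 (down)
    norms = 2 * d_model                 # ln1 and ln2 gains
    per_layer = attention + ffn + norms
    embeddings = 2 * vocab_size * d_model  # token embedding + LM head (not tied)
    final_norm = d_model
    return num_layers * per_layer + embeddings + final_norm


n_params = count_parameters(vocab_size=50_257, d_model=1_600, num_layers=48, d_ff=6_400)
bytes_fp32 = 4 * n_params

print(n_params)            # 2127057600
print(bytes_fp32 / 1e9)    # size in GB (10^9 bytes)
print(bytes_fp32 / 2**30)  # size in GiB (2^30 bytes)
```

The 8.51 GB figure uses decimal gigabytes. In binary units the same weights take a little under 8 GiB.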
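(b) Identify all of the matrix multiplications required for one forward pass of the GPT-2 XL-sized model, assuming the input sequence length equals context_length. How many FLOPs do these matrix multiplications require in total?
Deliverable: A list of the matrix multiplications (with brief descriptions) and the total number of FLOPs required.
Notation: S = context_length, d = d_model, h = num_heads, V = vocab_size. All counts are for a single sequence.
1. Q/K/V/O projections, four matmuls per layer
(S, d) · (d, d) -> (S, d)
FLOPs per matmul: 2·S·d·d
Per layer: F_proj = 4 · 2Sd^2 = 8Sd^2
2. Attention scores QK^T
Per head: (S, d_k) · (d_k, S) -> (S, S); summed over all heads, d_k · h = d
F_QKT = 2S^2·d
3. Attention-weighted values AV
Per head: (S, S) · (S, d_v) -> (S, d_v); summed over all heads, d_v · h = d
F_AV = 2S^2·d
So the two attention matmuls together cost F_attn = 4S^2·d
4. SwiGLU FFN, three matmuls (W1, W3, W2)
(S, d) · (d, d_ff) twice + (S, d_ff) · (d_ff, d) once
F_ffn = 6·S·d·d_ff
5. LM head
(S, d) · (d, V) -> (S, V)
F_lm = 2·S·d·V
Plugging in S = 1,024, d = 1,600, d_ff = 6,400, V = 50,257:
Per layer:
F_proj = 20.97152e9
F_attn = 6.7108864e9
F_ffn = 62.91456e9
Per-layer total ≈ 90.5969664e9
All 48 layers:
F_layers ≈ 4.3486543872e12
LM head:
F_lm ≈ 1.646821376e11
Total forward FLOPs:
F_total ≈ 4.5133365248 × 10^12 FLOPs (≈ 4.51 TFLOPs)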
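The totals above follow directly from the 2mnp rule. The sketch below recomputes them for the GPT-2 XL configuration. It counts matrix multiplications only, for a single sequence, and ignores norms, softmax, and elementwise operations. The function name forward_matmul_flops is a hypothetical helper for this illustration.

```python
def forward_matmul_flops(vocab_size: int, context_length: int, num_layers: int,
                         d_model: int, d_ff: int) -> dict[str, int]:
    """Matmul FLOPs of one forward pass over a single sequence, grouped by component."""
    s, d = context_length, d_model
    per_layer = {
        "projections": 4 * (2 * s * d * d),   # Q, K, V, O
        "attention": 2 * (2 * s * s * d),     # QK^T and AV, summed over heads
        "ffn": 3 * (2 * s * d * d_ff),        # W1, W3, W2
    }
    flops = {name: num_layers * value for name, value in per_layer.items()}
    flops["lm_head"] = 2 * s * d * vocab_size
    return flops


xl = forward_matmul_flops(vocab_size=50_257, context_length=1_024, num_layers=48,
                          d_model=1_600, d_ff=6_400)
total = sum(xl.values())

for name, value in xl.items():
    print(f"{name:12s} {value:.4e}  {100 * value / total:5.2f}%")
print(f"{'total':12s} {total:.4e}")
```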
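(c) Based on your analysis above, which parts of the model require the most FLOPs?
Deliverable: In this setting the SwiGLU FFN dominates (about 62.9B FLOPs per layer), followed by the Q/K/V/O projections (about 21.0B per layer), then the two attention matmuls QKᵀ + AV (about 6.7B per layer). The lm_head accounts for a comparatively small share (about 165B in total, under 4% of the forward pass).
(d) Repeat the analysis for the following models:
GPT-2 small: 12 layers, d_model = 768, 12 attention heads
GPT-2 medium: 24 layers, d_model = 1024, 16 attention heads
GPT-2 large: 36 layers, d_model = 1280, 20 attention heads
As the model size increases, which parts of the Transformer LM take up a larger share of the total FLOPs?
Which parts take up a smaller share?
Deliverable: For each model, a breakdown of each component's FLOPs as a proportion of the total forward-pass FLOPs, plus a one-to-two sentence description of how changing the model size changes the proportional FLOPs of each component.
The percentages below are consistent with S = 1,024, V = 50,257, and d_ff = 4 × d_model.
GPT-2 small (L=12, d=768, h=12)
Projections: 16.58%
Attention matmuls: 11.06%
FFN: 49.75%
LM head: 22.61%
GPT-2 medium (L=24, d=1024, h=16)
Projections: 19.96%
Attention matmuls: 9.98%
FFN: 59.87%
LM head: 10.20%
GPT-2 large (L=36, d=1280, h=20)
Projections: 21.40%
Attention matmuls: 8.56%
FFN: 64.20%
LM head: 5.84%
As the model grows in depth and width, the FFN and projection shares (proportional to L·d·d_ff and L·d^2) rise. The attention matmul share (proportional to L·S^2·d) falls, and the lm_head share (proportional to S·d·V, independent of the number of layers) falls sharply.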
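The per-model shares can be recomputed with the same matmul accounting. The sketch below assumes d_ff = 4 × d_model, vocab_size = 50,257, and context_length = 1,024 for all three models. The article does not state d_ff for these sizes, so this is an assumption, chosen because it reproduces the listed percentages. The helper name component_shares is hypothetical.

```python
def component_shares(num_layers: int, d_model: int, d_ff: int,
                     vocab_size: int = 50_257, context_length: int = 1_024) -> dict[str, float]:
    """Percentage of forward-pass matmul FLOPs spent in each component."""
    s, d = context_length, d_model
    flops = {
        "projections": num_layers * 8 * s * d * d,
        "attention": num_layers * 4 * s * s * d,
        "ffn": num_layers * 6 * s * d * d_ff,
        "lm_head": 2 * s * d * vocab_size,
    }
    total = sum(flops.values())
    return {name: 100 * value / total for name, value in flops.items()}


# (num_layers, d_model); d_ff = 4 * d_model is an assumption, see the text above
configs = {"small": (12, 768), "medium": (24, 1024), "large": (36, 1280)}
shares = {
    name: component_shares(num_layers=layers, d_model=d, d_ff=4 * d)
    for name, (layers, d) in configs.items()
}

for name, parts in shares.items():
    print(name, {k: round(v, 2) for k, v in parts.items()})
```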
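(e) Take GPT-2 XL and increase context_length to 16,384:
How does the total FLOPs for one forward pass change?
How does the relative contribution of each model component to the FLOPs change?
Deliverable: When context_length grows by 16×, the terms linear in S (projections, FFN, LM head) also grow by 16×, while the attention matmuls (QKᵀ, AV) grow quadratically, i.e. by 256×. Total FLOPs rise from ≈ 4.51×10¹² to ≈ 1.495×10¹⁴ (about 33.1×). The relative contributions shift as well: the attention matmuls become the dominant term (about 55%), while the other shares fall (FFN about 32%, projections about 11%, LM head about 1.8%).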
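The long-context numbers can be checked by evaluating the same formulas at S = 1,024 and S = 16,384. This is a minimal sketch under the matmul-only accounting used in this section. The helper name gpt2_xl_flops is hypothetical.

```python
def gpt2_xl_flops(context_length: int) -> dict[str, int]:
    """Forward-pass matmul FLOPs for the GPT-2 XL configuration at a given sequence length."""
    s, d, d_ff, vocab, layers = context_length, 1_600, 6_400, 50_257, 48
    return {
        "projections": layers * 8 * s * d * d,   # linear in S
        "attention": layers * 4 * s * s * d,     # quadratic in S
        "ffn": layers * 6 * s * d * d_ff,        # linear in S
        "lm_head": 2 * s * d * vocab,            # linear in S
    }


short = gpt2_xl_flops(1_024)
long = gpt2_xl_flops(16_384)
total_short, total_long = sum(short.values()), sum(long.values())

print(f"total: {total_short:.3e} -> {total_long:.3e} ({total_long / total_short:.1f}x)")
for name in short:
    growth = long[name] // short[name]
    print(f"{name:12s} x{growth:<4d} share {100 * short[name] / total_short:5.2f}% -> "
          f"{100 * long[name] / total_long:5.2f}%")
```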
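That covers the full implementation for the Transformer Language Model Architecture part of the assignment.
Conclusion
This article implemented the complete Transformer language model from CS336 Assignment 1. Starting from the most basic linear, embedding, and normalization modules, it built up a full Transformer block with RoPE, causal multi-head self-attention, a SwiGLU feed-forward network, and a pre-norm structure, and then stacked these blocks into a language model that runs an end-to-end forward pass.
Beyond any single module, the main takeaway is how the components work together inside a real model. The handling of batch-like dimensions, the tensor rearrangement of Q/K/V in attention, and the application of RoPE across the head dimension are details whose complexity and constraints only become apparent once everything is wired together.
In addition, the parameter and FLOPs accounting for a GPT-2 XL-sized model shows where the Transformer's main compute and memory costs come from, and gives a quantitative baseline for later reasoning about training throughput, memory footprint, parallelization strategies, and model scaling.
At this point we have a Transformer LM implementation with a correct structure, numerically stable computation, and clear interfaces. The following parts of the assignment move on to training itself, including the loss function, optimizer, data loading, and training efficiency.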