详解如何复现 LLaMA 4:从零开始利用 Python 构建
LLaMA 4 展示了基于 MoE(Mixture-of-Experts,混合专家)模型的优势。本文从零开始构建 LLaMA 4 的 MoE 架构,以了解它是如何实际构建的。
LLaMA 4 MoE 架构概述
想象一下,你有一个非常艰巨的任务。与其雇佣一个对什么都懂一点的人,不如雇佣一个团队,每个成员都是某个特定领域的专家。AI 模型中的 MoE 就有点像这样:
- 一组'专家':这些是较小的、专门化的神经网络(通常是简单的前馈网络或 MLP)。
- 一个'路由器'(经理):这是另一个小型网络。它的任务是查看输入数据,并决定哪个专家最适合处理它。

假设我们的模型正在处理句子:'The cat sat。'
- 分词:首先,我们将句子分解成片段(分词):'The' 'cat' 'sat'。
- 路由器接收分词:MoE 层接收到分词
cat。路由器查看这个cat向量。 - 路由器选择:假设我们有 4 个专家(
E1、E2、E3、E4)。路由器决定哪些专家最适合处理cat。 - 权重分配:假设它认为
E2和E4是最合适的选择。它会给这些选择分配分数或'权重'。

cat 向量只发送给 Expert 2 和 Expert 4。E2 处理 cat 并生成其结果。E4 处理 cat 并生成其结果。现在,我们使用路由器权重将选定专家的结果组合起来:Final_Output = (0.7 * Output_E2) + (0.3 * Output_E4)。
当我们的模型处理像 "The cat sat." 这样的文本时,整个流程如下所示:

输入文本进入分词器。分词器将分词 ID 转换为有意义的数字向量(嵌入向量),并添加位置信息(稍后在注意力中使用 RoPE)。
这些向量通过多个Transformer 块。每个块包含:
- 自注意力(分词相互查看,由 RoPE 增强)。
- MoE 层(路由器将分词发送到特定的专家)。
- 归一化(RMSNorm)和残差连接有助于学习。
最后一个块的输出进入最终层。这一层为词汇表中每个可能的下一个分词生成分数(logits)。
我们将分数转换为概率,并预测下一个分词。
搭建舞台
在开始编写模型代码之前,我们需要导入我们将要使用的模块。
# 导入必要的库
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import math
import os
import collections
import re
# --- 设备配置 ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备:{device}")
定义训练语料库
我们需要一些文本数据来训练我们的语言模型。在我们的例子中,我们将使用刘易斯·卡罗尔的《爱丽丝梦游仙境》中的一个小段落。
# 定义原始文本语料库用于训练
corpus_raw = """ Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, 'and what is the use of a book,' thought Alice 'without pictures or conversation?' So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her. """
print(f"训练语料库已定义(长度:{len(corpus_raw)} 个字符)。")
字符级分词
计算机不懂字母,它只懂数字。分词是将文本转换为模型可以处理的数字的过程。
- 找出
corpus_raw中的所有唯一字符。 - 为每个唯一字符分配一个唯一的整数 ID。
- 创建映射(字典),将字符转换为 ID(
char_to_int)和将 ID 转换回字符(int_to_char)。
# 找出原始语料库中的所有唯一字符
chars = sorted(list(set(corpus_raw)))
vocab_size = len(chars)
# 创建字符到整数的映射(编码)
char_to_int = {ch: i for i, ch in enumerate(chars)}
# 创建整数到字符的映射(解码)
int_to_char = {i: ch for i, ch in enumerate(chars)}
print(f"创建了大小为:{vocab_size} 的字符词汇表")
print(f"词汇表:{''.join(chars)}")
编码语料库
现在我们使用刚才创建的 char_to_int 映射,将整个 corpus_raw 字符串转换为对应的整数 ID 序列。
# 将整个语料库编码为整数 ID 列表
encoded_corpus = [char_to_int[ch] for ch in corpus_raw]
# 将列表转换为 PyTorch 张量
full_data_sequence = torch.tensor(encoded_corpus, dtype=torch.long, device=device)
print(f"将语料库编码为张量,形状为:{full_data_sequence.shape}")
定义超参数
接下来,我们需要定义超参数设置,这些是在训练之前选择的。它们定义了模型的架构以及它是如何学习的。
# --- 模型架构超参数 ---
d_model = 128 # 嵌入维度(大幅降低)
n_layers = 4 # Transformer 块的数量(降低)
n_heads = 4 # 注意力头的数量
block_size = 64 # 最大上下文长度(序列长度)
rms_norm_eps = 1e-5 # RMSNorm 稳定性的微小值
rope_theta = 10000.0 # RoPE 的 theta 参数
# --- MoE 特定超参数 ---
num_local_experts = 4 # 每个 MoE 层中的专家数量
num_experts_per_tok = 2 # 每个分词路由到的专家数量(Top-K)
intermediate_size_expert = d_model * 2 # 专家 MLP 中的隐藏维度
intermediate_size_shared = d_model * 2 # 共享 MLP 中的隐藏维度
# --- 训练超参数 ---
learning_rate = 5e-4
batch_size = 16
epochs = 3000
eval_interval = 300
# --- 推导超参数 ---
assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
d_k = d_model // n_heads
expert_dim = intermediate_size_expert
shared_expert_dim = intermediate_size_shared
训练数据准备
像我们这样的语言模型是通过预测给定之前分词的下一个分词来学习的。为了准备数据,我们在 full_data_sequence 上滑动一个长度为 block_size 的窗口。
# 创建列表以保存所有可能的输入(x)和目标(y)序列
all_x = []
all_y = []
num_total_tokens = len(full_data_sequence)
for i in range(num_total_tokens - block_size):
x_chunk = full_data_sequence[i : i + block_size]
y_chunk = full_data_sequence[i + 1: i + block_size + 1]
all_x.append(x_chunk)
all_y.append(y_chunk)
# 将列表中的张量堆叠成单个大张量
train_x = torch.stack(all_x)
train_y = torch.stack(all_y)
num_sequences_available = train_x.shape[0]
print(f"创建了 {num_sequences_available} 个重叠的输入/目标序列对。")
print(f"train_x 的形状:{train_x.shape}")
print(f"train_y 的形状:{train_y.shape}")
批量策略(随机抽样)
一次性在整个数据集上进行训练通常会占用过多的内存。相反,我们使用 mini-batch 进行训练。
if num_sequences_available < batch_size:
print(f"警告:序列数量 ({num_sequences_available}) 小于批量大小 ({batch_size})。正在调整批量大小。")
batch_size = num_sequences_available
print(f"数据已准备好用于训练。将随机抽取大小为 {batch_size} 的批量。")
模型组件初始化
嵌入层初始化
这是模型的第一层。它将整数分词 ID 转换为大小为 d_model 的密集向量。
token_embedding_table = nn.Embedding(vocab_size, d_model).to(device)
print(f"初始化分词嵌入层:")
print(f" 输入词汇表大小:{vocab_size}")
print(f" 输出嵌入维度 (d_model):{d_model}")
print(f" 权重形状:{token_embedding_table.weight.shape}")
旋转位置嵌入(RoPE)预计算
Transformer 本身并不理解词序。位置编码会添加这种信息。RoPE 是像 LLaMA 这样的模型中使用的一种巧妙方法。
# 预计算 RoPE 的逆频率
rope_freq_indices = torch.arange(0, d_k, 2, dtype=torch.float, device=device)
inv_freq = 1.0 / (rope_theta ** (rope_freq_indices / d_k))
print("预计算的 RoPE 逆频率 (inv_freq):")
print(f" 形状:{inv_freq.shape}")
print(f" 值(前 5 个):{inv_freq[:5].tolist()}")
RMSNorm 层初始化
归一化层有助于稳定训练。LLaMA 使用 RMSNorm(Root Mean Square Normalization)。
rmsnorm_weights_input = []
rmsnorm_weights_post_attn = []
print(f"初始化 {n_layers} 层的 RMSNorm 权重...")
for i in range(n_layers):
weight_in = nn.Parameter(torch.ones(d_model, device=device))
rmsnorm_weights_input.append(weight_in)
weight_post = nn.Parameter(torch.ones(d_model, device=device))
rmsnorm_weights_post_attn.append(weight_post)
final_rmsnorm_weight = nn.Parameter(torch.ones(d_model, device=device))
print(f"初始化最终 RMSNorm 权重,形状:{final_rmsnorm_weight.shape}")
注意力层初始化(MHA)
Transformer 的核心是自注意力机制。我们使用的是多头注意力(MHA)。
mha_qkv_linears = []
mha_output_linears = []
print(f"初始化 {n_layers} 层的注意力(MHA)线性层...")
for i in range(n_layers):
qkv_linear = nn.Linear(d_model, 3 * d_model, bias=False).to(device)
mha_qkv_linears.append(qkv_linear)
output_linear = nn.Linear(d_model, d_model, bias=False).to(device)
mha_output_linears.append(output_linear)
print("注意力(MHA)线性层已初始化。")
混合专家(MoE)层初始化
在注意力块之后,我们使用了一个 MoE 层。
moe_routers = []
moe_expert_gate_up_proj = []
moe_expert_down_proj = []
shared_expert_gate_proj = []
shared_expert_up_proj = []
shared_expert_down_proj = []
print(f"初始化 {n_layers} 层的 MoE 和共享 MLP 组件...")
for i in range(n_layers):
router_linear = nn.Linear(d_model, num_local_experts, bias=False).to(device)
moe_routers.append(router_linear)
gate_up_w = nn.Parameter(torch.empty(num_local_experts, d_model, 2 * expert_dim, device=device))
nn.init.normal_(gate_up_w, mean=0.0, std=0.02)
moe_expert_gate_up_proj.append(gate_up_w)
down_w = nn.Parameter(torch.empty(num_local_experts, expert_dim, d_model, device=device))
nn.init.normal_(down_w, mean=0.0, std=0.02)
moe_expert_down_proj.append(down_w)
shared_gate = nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
shared_up = nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
shared_down = nn.Linear(shared_expert_dim, d_model, bias=False).to(device)
shared_expert_gate_proj.append(shared_gate)
shared_expert_up_proj.append(shared_up)
shared_expert_down_proj.append(shared_down)
activation_fn = nn.SiLU()
最终输出层初始化
经过所有 Transformer 层之后,最终的隐藏状态需要转换为下一个分词的预测。
output_linear_layer = nn.Linear(d_model, vocab_size, bias=False).to(device)
print(f"初始化最终输出线性层:")
print(f" 输入维度 (d_model):{d_model}")
print(f" 输出维度 (vocab_size):{vocab_size}")
print(f" 权重形状:{output_linear_layer.weight.shape}")
因果掩码预计算
在仅解码器 Transformer 中,当预测位置 t 的分词时,模型只能关注位置 0 到 t。
causal_mask = torch.tril(torch.ones(block_size, block_size, device=device))
causal_mask = causal_mask.view(1, 1, block_size, block_size)
print("预计算的因果注意力掩码:")
print(f" 形状:{causal_mask.shape}")
训练设置
优化器
all_model_parameters = list(token_embedding_table.parameters())
all_model_parameters.extend(rmsnorm_weights_input)
all_model_parameters.extend(rmsnorm_weights_post_attn)
all_model_parameters.append(final_rmsnorm_weight)
for i in range(n_layers):
all_model_parameters.extend(list(mha_qkv_linears[i].parameters()))
all_model_parameters.extend(list(mha_output_linears[i].parameters()))
all_model_parameters.extend(list(moe_routers[i].parameters()))
all_model_parameters.extend(moe_expert_gate_up_proj)
all_model_parameters.extend(moe_expert_down_proj)
all_model_parameters.extend(list(shared_expert_gate_proj[i].parameters()))
all_model_parameters.extend(list(shared_expert_up_proj[i].parameters()))
all_model_parameters.extend(list(shared_expert_down_proj[i].parameters()))
all_model_parameters.extend(list(output_linear_layer.parameters()))
num_param_groups = len(all_model_parameters)
total_params = sum(p.numel() for p in all_model_parameters if p.requires_grad)
optimizer = optim.AdamW(all_model_parameters, lr=learning_rate)
print("优化器设置:")
print(f" 优化器:{type(optimizer).__name__}")
print(f" 学习率:{learning_rate}")
print(f" 总可训练参数:{total_params:,}")
损失函数
criterion = nn.CrossEntropyLoss()
训练模型
我们将通过迭代地向模型输入批量数据,计算损失,并使用优化器更新参数来进行训练。
print(f"\n--- 开始训练循环,共 {epochs} 个周期 ---")
losses = []
for epoch in range(epochs):
xb, yb = train_x[torch.randint(0, num_sequences_available, (batch_size,))].to(device), \
train_y[torch.randint(0, num_sequences_available, (batch_size,))].to(device)
token_embed = token_embedding_table(xb)
position_ids = torch.arange(xb.shape[1], device=device).unsqueeze(0)
freqs_cis = torch.polar(torch.ones_like(position_ids), (inv_freq.unsqueeze(0).unsqueeze(-1).expand(xb.shape[0], -1, 1).float() @ position_ids.unsqueeze(1).expand(xb.shape[0], -1).float()).transpose(1, 2))
x = token_embed
for i in range(n_layers):
x_norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_input[i]
qkv = mha_qkv_linears[i](x_norm).view(xb.shape[0], xb.shape[1], n_heads, 3 * d_k).chunk(3, dim=-1)
q, k, v = qkv[0], qkv[1], qkv[2]
q_rope, k_rope = q.float().reshape(xb.shape[0], xb.shape[1], n_heads, -1, 2), k.float().reshape(xb.shape[0], xb.shape[1], n_heads, -1, 2)
q, k = torch.view_as_real(torch.view_as_complex(q_rope) * freqs_cis.unsqueeze(2)).flatten(3), \
torch.view_as_real(torch.view_as_complex(k_rope) * freqs_cis.unsqueeze(2)).flatten(3)
attn_scores = (q @ k.transpose(-2, -1)) * (d_k ** -0.5)
attn_scores = attn_scores.masked_fill(causal_mask[:, :, :xb.shape[1], :xb.shape[1]] == 0, float('-inf'))
attention_weights = F.softmax(attn_scores, dim=-1)
attn_output = attention_weights @ v
x = x + mha_output_linears[i](attn_output.permute(0, 2, 1, 3).contiguous().view(xb.shape[0], xb.shape[1], d_model))
x_norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_post_attn[i]
router_logits = moe_routers[i](x_norm)
routing_weights, selected_experts = torch.sigmoid(torch.topk(router_logits, num_experts_per_tok, dim=-1)[0]), \
torch.topk(router_logits, num_experts_per_tok, dim=-1)[1]
x_flat = x_norm.view(-1, d_model)
selected_experts_flat = selected_experts.view(-1)
routing_weights_flat = routing_weights.view(-1)
token_idx = torch.arange(xb.shape[0] * xb.shape[1], device=device).repeat_interleave(num_experts_per_tok)
expert_inputs = x_flat[token_idx]
gate_up_states = torch.bmm(expert_inputs.unsqueeze(1), moe_expert_gate_up_proj[i][selected_experts_flat])
activated_states = activation_fn(gate_up_states.chunk(2, dim=-1)[0]) * gate_up_states.chunk(2, dim=-1)[1]
expert_outputs_weighted = torch.bmm(activated_states, moe_expert_down_proj[i][selected_experts_flat]).squeeze(1) * \
routing_weights_flat.unsqueeze(-1)
combined_expert_outputs = torch.zeros_like(x_flat)
combined_expert_outputs.scatter_add_(0, token_idx.unsqueeze(-1).expand(-1, d_model), expert_outputs_weighted)
shared_output = shared_expert_down_proj[i](activation_fn(shared_expert_gate_proj[i](x_norm)) * shared_expert_up_proj[i](x_norm))
x = x + combined_expert_outputs.view(xb.shape[0], xb.shape[1], d_model) + shared_output
logits = output_linear_layer((x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * final_rmsnorm_weight)
loss = criterion(logits.view(-1, logits.shape[-1]), yb.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if epoch % eval_interval == 0 or epoch == epochs - 1:
print(f" 第 {epoch+1}/{epochs} 个周期,损失:{loss.item():.4f}")
print("--- 训练循环完成 ---")
文本生成
现在模型已经训练完成,让我们看看它能写出什么!
print("\n--- 第 7 步:文本生成 ---")
seed_chars = "Alice "
num_tokens_to_generate = 200
print(f"种子文本:'{seed_chars}'")
print(f"生成 {num_tokens_to_generate} 个新分词...")
seed_ids = [char_to_int[ch] for ch in seed_chars if ch in char_to_int]
generated_sequence = torch.tensor([seed_ids], dtype=torch.long, device=device)
print(f"初始上下文形状:{generated_sequence.shape}")
token_embedding_table.eval()
for i in range(n_layers):
mha_qkv_linears[i].eval()
mha_output_linears[i].eval()
moe_routers[i].eval()
shared_expert_gate_proj[i].eval()
shared_expert_up_proj[i].eval()
shared_expert_down_proj[i].eval()
output_linear_layer.eval()
print("已将模型组件设置为评估模式(适用时)。")
生成循环
print("开始生成循环...")
with torch.no_grad():
for _ in range(num_tokens_to_generate):
current_context = generated_sequence[:, -block_size:]
B_gen, T_gen = current_context.shape
token_embed_gen = token_embedding_table(current_context)
freqs_gen = torch.polar(torch.ones_like(position_ids_gen), (inv_freq.unsqueeze(0).unsqueeze(-1).expand(B_gen, -1, 1) @ position_ids_gen.float()).transpose(1, 2))
x_gen = token_embed_gen
for i in range(n_layers):
x_norm_gen = (x_gen.float() * torch.rsqrt(x_gen.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_input[i]
qkv_gen = mha_qkv_linears[i](x_norm_gen).view(B_gen, T_gen, n_heads, 3 * d_k).chunk(3, dim=-1)
q_rotated_gen = torch.view_as_real(torch.view_as_complex(qkv_gen[0].reshape(B_gen, T_gen, n_heads, -1, 2)) * freqs_gen.unsqueeze(2))
k_rotated_gen = torch.view_as_real(torch.view_as_complex(qkv_gen[1].reshape(B_gen, T_gen, n_heads, -1, 2)) * freqs_gen.unsqueeze(2))
attn_output_gen = (F.softmax((q_rotated_gen.permute(0, 2, 1, 3) @ k_rotated_gen.permute(0, 2, 1, 3).transpose(-2, -1)) * (d_k ** -0.5), dim=-1) @ qkv_gen[2].permute(0, 2, 1, 3)).view(B_gen, T_gen, d_model)
x_gen = x_gen + mha_output_linears[i](attn_output_gen)
x_norm_gen = (x_gen.float() * torch.rsqrt(x_gen.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_post_attn[i]
routing_weights_gen = torch.sigmoid(torch.topk(moe_routers[i](x_norm_gen), num_experts_per_tok, dim=-1)[0])
expert_outputs_gen = torch.bmm(activation_fn(torch.chunk(torch.bmm(x_norm_gen.view(-1, d_model), moe_expert_gate_up_proj[i][torch.topk(moe_routers[i](x_norm_gen), num_experts_per_tok, dim=-1)[1]]).squeeze(1), 2)[0]) * routing_weights_gen, moe_expert_down_proj[i][torch.topk(moe_routers[i](x_norm_gen), num_experts_per_tok, dim=-1)[1]]).squeeze(1)
x_gen = x_gen + expert_outputs_gen.view(B_gen, T_gen, d_model) + shared_expert_down_proj[i](activation_fn(shared_expert_gate_proj[i](x_norm_gen)) * shared_expert_up_proj[i](x_norm_gen))
logits_gen = output_linear_layer((x_gen.float() * torch.rsqrt(x_gen.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * final_rmsnorm_weight)
next_token = torch.multinomial(F.softmax(logits_gen[:, -1, :], dim=-1), num_samples=1)
generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
print("...生成循环完成。")
解码生成序列
final_generated_ids = generated_sequence[0].tolist()
decoded_text = ''.join([int_to_char.get(id_val, '[UNK]') for id_val in final_generated_ids])
print("\n--- 最终生成的文本 ---")
print(decoded_text)
结论
本文涵盖了以下内容:
- 设置和分词:基本的环境设置和字符级分词。
- 超参数定义:从大型模型中缩小的配置值。
- 数据准备:为下一个分词预测创建输入/目标序列。
- 模型初始化(内联):显式创建和初始化组件,如分词嵌入、RMSNorm 权重、注意力线性层、RoPE 频率基础、MoE 路由器、MoE 专家权重、共享专家 MLP 和最终输出层。
- 训练循环(内联):在循环中实现完整的前向传播,展示应用 RMSNorm、RoPE、MoE 前向传播、标准 Transformer 操作、损失计算、反向传播和优化器步骤。
- 文本生成:在评估模式下使用训练好的模型组件进行自回归采样。


