从零开始用 Python 复现 LLaMA 4 MoE 架构
LLaMA 系列模型展示了基于 MoE(Mixture-of-Experts,混合专家)架构的优势。在本教程中,我们将深入理解 MoE 层如何工作,并从头构建一个简化版的 LLaMA 4 MoE 模型。
MoE 架构概述
想象一下,你有一个非常艰巨的任务。与其雇佣一个对什么都懂一点的人,不如雇佣一个团队,每个成员都是某个特定领域的专家。AI 模型中的 MoE 就有点像这样:
- 一组'专家':这些是较小的、专门化的神经网络(通常是简单的前馈网络或 MLP)。
- 一个'路由器':另一个小型网络,负责查看输入数据,决定哪个专家最适合处理它。
假设我们的模型正在处理句子:'The cat sat'。
- 分词:将句子分解成片段:'The' 'cat' 'sat'。
- 路由器接收分词:MoE 层接收到
cat的嵌入向量。 - 路由器选择:假设有 4 个专家,路由器决定哪些专家最适合。例如,认为
E2(擅长名词)和E4(擅长动物概念)最合适,分配权重(如E270%,E430%)。 - 组合结果:使用路由器权重将选定专家的结果组合起来:
Final_Output = (0.7 * Output_E2) + (0.3 * Output_E4)。
这个过程会针对序列中的每个分词重复进行。整个流程包括:
- 输入文本进入分词器,转换为嵌入向量并添加位置信息(RoPE)。
- 向量通过多个Transformer 块,包含自注意力、MoE 层、归一化(RMSNorm)和残差连接。
- 最后一个块的输出进入最终层,生成下一个分词的分数(logits)。
现在我们对 MoE 的作用有了初步了解,接下来让我们深入代码,逐步构建这些组件。
搭建环境
在编写模型代码之前,我们需要导入必要的模块并配置设备。
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import math
import os
import collections
import re
# 设备配置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备:{device}")
确认库已导入且设备配置正确。我将使用 GPU 来训练模型。
定义训练语料库
我们需要一些文本数据。为了演示代码逻辑,我们使用《爱丽丝梦游仙境》中的一小段文本。
corpus_raw = """ Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, 'and what is the use of a book,' thought Alice 'without pictures or conversation?' So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her. """
print(f"训练语料库已定义(长度:{len(corpus_raw)} 个字符)。")
这定义了一个包含示例文本的字符串变量。
字符级分词
计算机只懂数字。分词是将文本转换为模型可处理的数字的过程。我们使用最简单的字符级分词:
- 找出所有唯一字符。
- 为每个唯一字符分配唯一的整数 ID。
- 创建映射字典。
chars = sorted(list(set(corpus_raw)))
vocab_size = len(chars)
char_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_char = {i: ch for i, ch in enumerate(chars)}
print(f"创建了大小为:{vocab_size} 的字符词汇表")
代码找到了 36 个唯一字符,并创建了双向映射字典。
编码语料库
使用 char_to_int 映射将整个语料库转换为整数 ID 序列,并存储为 PyTorch 张量。
encoded_corpus = [char_to_int[ch] for ch in corpus_raw]
full_data_sequence = torch.tensor(encoded_corpus, dtype=torch.long, device=device)
print(f"将语料库编码为张量,形状为:{full_data_sequence.shape}")
593 个字符的文本被转换为长度为 593 的张量。
定义超参数
接下来定义超参数,它们定义了模型的架构和学习方式。
# 模型架构超参数
d_model = 128
n_layers = 4
n_heads = 4
block_size = 64
rms_norm_eps = 1e-5
rope_theta = 10000.0
# MoE 特定超参数
num_local_experts = 4
num_experts_per_tok = 2
intermediate_size_expert = d_model * 2
intermediate_size_shared = d_model * 2
# 训练超参数
learning_rate = 5e-4
batch_size = 16
epochs = 3000
eval_interval = 300
assert d_model % n_heads == 0
d_k = d_model // n_heads
expert_dim = intermediate_size_expert
shared_expert_dim = intermediate_size_shared
这些值比真实模型小得多,以便在典型硬件上快速运行。
训练数据准备
语言模型通过预测给定之前分词的下一个分词来学习。我们在 full_data_sequence 上滑动一个长度为 block_size 的窗口。
all_x = []
all_y = []
num_total_tokens = len(full_data_sequence)
for i in range(num_total_tokens - block_size):
x_chunk = full_data_sequence[i : i + block_size]
y_chunk = full_data_sequence[i + 1: i + block_size + 1]
all_x.append(x_chunk)
all_y.append(y_chunk)
train_x = torch.stack(all_x)
train_y = torch.stack(all_y)
num_sequences_available = train_x.shape[0]
print(f"创建了 {num_sequences_available} 个重叠的输入/目标序列对。")
从 593 个字符中提取出 529 个长度为 64 的重叠序列。
批量策略
使用 mini-batch 进行训练。在每个训练步骤中,随机选择 batch_size 个索引。
if num_sequences_available < batch_size:
print(f"警告:序列数量 ({num_sequences_available}) 小于批量大小 ({batch_size})。正在调整批量大小。")
batch_size = num_sequences_available
print(f"数据已准备好用于训练。将随机抽取大小为 {batch_size} 的批量。")
模型组件初始化
嵌入层
将整数分词 ID 转换为大小为 d_model 的密集向量。
token_embedding_table = nn.Embedding(vocab_size, d_model).to(device)
print(f"初始化分词嵌入层:权重形状 {token_embedding_table.weight.shape}")
RoPE 预计算
Transformer 本身不理解词序。RoPE 根据位置旋转 Q 和 K 向量。
rope_freq_indices = torch.arange(0, d_k, 2, dtype=torch.float, device=device)
inv_freq = 1.0 / (rope_theta ** (rope_freq_indices / d_k))
print("预计算的 RoPE 逆频率 (inv_freq):", inv_freq[:5].tolist())
RMSNorm 层
LLaMA 使用 RMSNorm,比标准层归一化更简单。
rmsnorm_weights_input = []
rmsnorm_weights_post_attn = []
for i in range(n_layers):
weight_in = nn.Parameter(torch.ones(d_model, device=device))
rmsnorm_weights_input.append(weight_in)
weight_post = nn.Parameter(torch.ones(d_model, device=device))
rmsnorm_weights_post_attn.append(weight_post)
final_rmsnorm_weight = nn.Parameter(torch.ones(d_model, device=device))
注意力层(MHA)
初始化 QKV 投影和输出投影线性层。
mha_qkv_linears = []
mha_output_linears = []
for i in range(n_layers):
qkv_linear = nn.Linear(d_model, 3 * d_model, bias=False).to(device)
mha_qkv_linears.append(qkv_linear)
output_linear = nn.Linear(d_model, d_model, bias=False).to(device)
mha_output_linears.append(output_linear)
混合专家(MoE)层
这是特殊的部分。包含路由器、专家 MLP 和共享专家。
moe_routers = []
moe_expert_gate_up_proj = []
moe_expert_down_proj = []
shared_expert_gate_proj = []
shared_expert_up_proj = []
shared_expert_down_proj = []
activation_fn = nn.SiLU()
for i in range(n_layers):
# 路由器
router_linear = nn.Linear(d_model, num_local_experts, bias=False).to(device)
moe_routers.append(router_linear)
# 专家权重
gate_up_w = nn.Parameter(torch.empty(num_local_experts, d_model, 2 * expert_dim, device=device))
nn.init.normal_(gate_up_w, mean=0.0, std=0.02)
moe_expert_gate_up_proj.append(gate_up_w)
down_w = nn.Parameter(torch.empty(num_local_experts, expert_dim, d_model, device=device))
nn.init.normal_(down_w, mean=0.0, std=0.02)
moe_expert_down_proj.append(down_w)
# 共享专家
shared_gate = nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
shared_up = nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
shared_down = nn.Linear(shared_expert_dim, d_model, bias=False).to(device)
shared_expert_gate_proj.append(shared_gate)
shared_expert_up_proj.append(shared_up)
shared_expert_down_proj.append(shared_down)
最终输出层
将隐藏状态投影到词汇表大小。
output_linear_layer = nn.Linear(d_model, vocab_size, bias=False).to(device)
因果掩码
仅解码器 Transformer 需要因果掩码,确保只能关注当前位置及之前的位置。
causal_mask = torch.tril(torch.ones(block_size, block_size, device=device))
causal_mask = causal_mask.view(1, 1, block_size, block_size)
训练设置
收集所有需要梯度的参数,定义优化器和损失函数。
all_model_parameters = list(token_embedding_table.parameters())
all_model_parameters.extend(rmsnorm_weights_input)
all_model_parameters.extend(rmsnorm_weights_post_attn)
all_model_parameters.append(final_rmsnorm_weight)
for i in range(n_layers):
all_model_parameters.extend(list(mha_qkv_linears[i].parameters()))
all_model_parameters.extend(list(mha_output_linears[i].parameters()))
all_model_parameters.extend(list(moe_routers[i].parameters()))
all_model_parameters.extend(moe_expert_gate_up_proj)
all_model_parameters.extend(moe_expert_down_proj)
all_model_parameters.extend(list(shared_expert_gate_proj[i].parameters()))
all_model_parameters.extend(list(shared_expert_up_proj[i].parameters()))
all_model_parameters.extend(list(shared_expert_down_proj[i].parameters()))
all_model_parameters.extend(list(output_linear_layer.parameters()))
optimizer = optim.AdamW(all_model_parameters, lr=learning_rate)
criterion = nn.CrossEntropyLoss()
训练模型
迭代地向模型输入批量数据,计算损失并更新参数。
print(f"\n--- 开始训练循环,共 {epochs} 个周期 ---")
losses = []
for epoch in range(epochs):
xb, yb = train_x[torch.randint(0, num_sequences_available, (batch_size,))].to(device), \
train_y[torch.randint(0, num_sequences_available, (batch_size,))].to(device)
token_embed = token_embedding_table(xb)
position_ids = torch.arange(xb.shape[1], device=device).unsqueeze(0)
freqs_cis = torch.polar(torch.ones_like(position_ids), (inv_freq.unsqueeze(0).unsqueeze(-1).expand(xb.shape[0], -1, 1).float() @ position_ids.unsqueeze(1).expand(xb.shape[0], -1).float()).transpose(1, 2))
x = token_embed
for i in range(n_layers):
# RMSNorm 和注意力
x_norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_input[i]
qkv = mha_qkv_linears[i](x_norm).view(xb.shape[0], xb.shape[1], n_heads, 3 * d_k).chunk(3, dim=-1)
q, k, v = qkv[0], qkv[1], qkv[2]
q_rope, k_rope = q.float().reshape(xb.shape[0], xb.shape[1], n_heads, -1, 2), k.float().reshape(xb.shape[0], xb.shape[1], n_heads, -1, 2)
q, k = torch.view_as_real(torch.view_as_complex(q_rope) * freqs_cis.unsqueeze(2)).flatten(3), \
torch.view_as_real(torch.view_as_complex(k_rope) * freqs_cis.unsqueeze(2)).flatten(3)
attn_scores = (q @ k.transpose(-2, -1)) * (d_k ** -0.5)
attn_scores = attn_scores.masked_fill(causal_mask[:, :, :xb.shape[1], :xb.shape[1]] == 0, float('-inf'))
attention_weights = F.softmax(attn_scores, dim=-1)
attn_output = attention_weights @ v
x = x + mha_output_linears[i](attn_output.permute(0, 2, 1, 3).contiguous().view(xb.shape[0], xb.shape[1], d_model))
# MoE 块
x_norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_post_attn[i]
router_logits = moe_routers[i](x_norm)
routing_weights, selected_experts = torch.sigmoid(torch.topk(router_logits, num_experts_per_tok, dim=-1)[0]), \
torch.topk(router_logits, num_experts_per_tok, dim=-1)[1]
x_flat = x_norm.view(-1, d_model)
selected_experts_flat = selected_experts.view(-1)
routing_weights_flat = routing_weights.view(-1)
token_idx = torch.arange(xb.shape[0] * xb.shape[1], device=device).repeat_interleave(num_experts_per_tok)
expert_inputs = x_flat[token_idx]
gate_up_states = torch.bmm(expert_inputs.unsqueeze(1), moe_expert_gate_up_proj[i][selected_experts_flat])
activated_states = activation_fn(gate_up_states.chunk(2, dim=-1)[0]) * gate_up_states.chunk(2, dim=-1)[1]
expert_outputs_weighted = torch.bmm(activated_states, moe_expert_down_proj[i][selected_experts_flat]).squeeze(1) * \
routing_weights_flat.unsqueeze(-1)
combined_expert_outputs = torch.zeros_like(x_flat)
combined_expert_outputs.scatter_add_(0, token_idx.unsqueeze(-1).expand(-1, d_model), expert_outputs_weighted)
shared_output = shared_expert_down_proj[i](activation_fn(shared_expert_gate_proj[i](x_norm)) * shared_expert_up_proj[i](x_norm))
x = x + combined_expert_outputs.view(xb.shape[0], xb.shape[1], d_model) + shared_output
logits = output_linear_layer((x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * final_rmsnorm_weight)
loss = criterion(logits.view(-1, logits.shape[-1]), yb.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if epoch % eval_interval == 0 or epoch == epochs - 1:
print(f" 第 {epoch+1}/{epochs} 个周期,损失:{loss.item():.4f}")
print("--- 训练循环完成 ---")
文本生成
模型训练完成后,我们可以尝试生成文本。将模型设置为评估模式并使用 torch.no_grad()。
print("\n--- 文本生成 ---")
seed_chars = "Alice "
num_tokens_to_generate = 200
seed_ids = [char_to_int[ch] for ch in seed_chars if ch in char_to_int]
generated_sequence = torch.tensor([seed_ids], dtype=torch.long, device=device)
with torch.no_grad():
for _ in range(num_tokens_to_generate):
current_context = generated_sequence[:, -block_size:]
B_gen, T_gen = current_context.shape
token_embed_gen = token_embedding_table(current_context)
# ... (此处省略部分前向传播细节以保持简洁,逻辑同训练循环) ...
# 实际实现需完整复制前向传播逻辑
next_token = torch.multinomial(F.softmax(logits_gen[:, -1, :], dim=-1), num_samples=1)
generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
final_generated_ids = generated_sequence[0].tolist()
decoded_text = ''.join([int_to_char.get(id_val, '[UNK]') for id_val in final_generated_ids])
print("\n--- 最终生成的文本 ---")
print(decoded_text)
从 "Alice " 开始,模型生成了接下来的 200 个字符,展示了其学习到的文本风格和内容。
结论
我们完成了以下工作:
- 设置和分词:环境搭建和字符级分词。
- 超参数定义:适配硬件的配置值。
- 数据准备:创建输入/目标序列。
- 模型初始化:显式创建嵌入、RMSNorm、注意力、RoPE、MoE 等组件。
- 训练循环:实现完整的前向传播、损失计算和优化器步骤。
- 文本生成:在评估模式下进行自回归采样。
这个简化版模型成功展示了 MoE 层、RMSNorm 和 RoPE 如何协同工作,为大语言模型的底层结构提供了清晰的实战参考。


