跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

从零开始用 Python 复现 LLaMA 4 MoE 架构

综述由AI生成基于 PyTorch 从零实现 LLaMA 4 风格的 MoE 架构,涵盖分词、RoPE 位置编码、RMSNorm 归一化及混合专家层设计。通过小数据集训练验证了模型对文本模式的捕捉能力,并实现了自回归生成。重点解析了路由器选择机制与专家权重组合逻辑,为理解大语言模型底层结构提供实战参考。

花里胡哨发布于 2026/4/9更新于 2026/5/219 浏览
从零开始用 Python 复现 LLaMA 4 MoE 架构

从零开始用 Python 复现 LLaMA 4 MoE 架构

LLaMA 系列模型展示了基于 MoE(Mixture-of-Experts,混合专家)架构的优势。在本教程中,我们将深入理解 MoE 层如何工作,并从头构建一个简化版的 LLaMA 4 MoE 模型。

MoE 架构概述

想象一下,你有一个非常艰巨的任务。与其雇佣一个对什么都懂一点的人,不如雇佣一个团队,每个成员都是某个特定领域的专家。AI 模型中的 MoE 就有点像这样:

  1. 一组'专家':这些是较小的、专门化的神经网络(通常是简单的前馈网络或 MLP)。
  2. 一个'路由器':另一个小型网络,负责查看输入数据,决定哪个专家最适合处理它。

假设我们的模型正在处理句子:'The cat sat'。

  1. 分词:将句子分解成片段:'The' 'cat' 'sat'。
  2. 路由器接收分词:MoE 层接收到 cat 的嵌入向量。
  3. 路由器选择:假设有 4 个专家,路由器决定哪些专家最适合。例如,认为 E2(擅长名词)和 E4(擅长动物概念)最合适,分配权重(如 E2 70%,E4 30%)。
  4. 组合结果:使用路由器权重将选定专家的结果组合起来:Final_Output = (0.7 * Output_E2) + (0.3 * Output_E4)。

这个过程会针对序列中的每个分词重复进行。整个流程包括:

  • 输入文本进入分词器,转换为嵌入向量并添加位置信息(RoPE)。
  • 向量通过多个Transformer 块,包含自注意力、MoE 层、归一化(RMSNorm)和残差连接。
  • 最后一个块的输出进入最终层,生成下一个分词的分数(logits)。

现在我们对 MoE 的作用有了初步了解,接下来让我们深入代码,逐步构建这些组件。

搭建环境

在编写模型代码之前,我们需要导入必要的模块并配置设备。

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import math
import os
import collections
import re

# 设备配置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备:{device}")

确认库已导入且设备配置正确。我将使用 GPU 来训练模型。

定义训练语料库

我们需要一些文本数据。为了演示代码逻辑,我们使用《爱丽丝梦游仙境》中的一小段文本。

corpus_raw = """ Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, 'and what is the use of a book,' thought Alice 'without pictures or conversation?' So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her. """
print(f"训练语料库已定义(长度:{len(corpus_raw)} 个字符)。")

这定义了一个包含示例文本的字符串变量。

字符级分词

计算机只懂数字。分词是将文本转换为模型可处理的数字的过程。我们使用最简单的字符级分词:

  1. 找出所有唯一字符。
  2. 为每个唯一字符分配唯一的整数 ID。
  3. 创建映射字典。
chars = sorted(list(set(corpus_raw)))
vocab_size = len(chars)
char_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_char = {i: ch for i, ch in enumerate(chars)}
print(f"创建了大小为:{vocab_size} 的字符词汇表")

代码找到了 36 个唯一字符,并创建了双向映射字典。

编码语料库

使用 char_to_int 映射将整个语料库转换为整数 ID 序列,并存储为 PyTorch 张量。

encoded_corpus = [char_to_int[ch] for ch in corpus_raw]
full_data_sequence = torch.tensor(encoded_corpus, dtype=torch.long, device=device)
print(f"将语料库编码为张量,形状为:{full_data_sequence.shape}")

593 个字符的文本被转换为长度为 593 的张量。

定义超参数

接下来定义超参数,它们定义了模型的架构和学习方式。

# 模型架构超参数
d_model = 128
n_layers = 4
n_heads = 4
block_size = 64
rms_norm_eps = 1e-5
rope_theta = 10000.0

# MoE 特定超参数
num_local_experts = 4
num_experts_per_tok = 2
intermediate_size_expert = d_model * 2
intermediate_size_shared = d_model * 2

# 训练超参数
learning_rate = 5e-4
batch_size = 16
epochs = 3000
eval_interval = 300

assert d_model % n_heads == 0
d_k = d_model // n_heads
expert_dim = intermediate_size_expert
shared_expert_dim = intermediate_size_shared

这些值比真实模型小得多,以便在典型硬件上快速运行。

训练数据准备

语言模型通过预测给定之前分词的下一个分词来学习。我们在 full_data_sequence 上滑动一个长度为 block_size 的窗口。

all_x = []
all_y = []
num_total_tokens = len(full_data_sequence)
for i in range(num_total_tokens - block_size):
    x_chunk = full_data_sequence[i : i + block_size]
    y_chunk = full_data_sequence[i + 1: i + block_size + 1]
    all_x.append(x_chunk)
    all_y.append(y_chunk)

train_x = torch.stack(all_x)
train_y = torch.stack(all_y)
num_sequences_available = train_x.shape[0]
print(f"创建了 {num_sequences_available} 个重叠的输入/目标序列对。")

从 593 个字符中提取出 529 个长度为 64 的重叠序列。

批量策略

使用 mini-batch 进行训练。在每个训练步骤中,随机选择 batch_size 个索引。

if num_sequences_available < batch_size:
    print(f"警告:序列数量 ({num_sequences_available}) 小于批量大小 ({batch_size})。正在调整批量大小。")
    batch_size = num_sequences_available
print(f"数据已准备好用于训练。将随机抽取大小为 {batch_size} 的批量。")

模型组件初始化

嵌入层

将整数分词 ID 转换为大小为 d_model 的密集向量。

token_embedding_table = nn.Embedding(vocab_size, d_model).to(device)
print(f"初始化分词嵌入层:权重形状 {token_embedding_table.weight.shape}")

RoPE 预计算

Transformer 本身不理解词序。RoPE 根据位置旋转 Q 和 K 向量。

rope_freq_indices = torch.arange(0, d_k, 2, dtype=torch.float, device=device)
inv_freq = 1.0 / (rope_theta ** (rope_freq_indices / d_k))
print("预计算的 RoPE 逆频率 (inv_freq):", inv_freq[:5].tolist())

RMSNorm 层

LLaMA 使用 RMSNorm,比标准层归一化更简单。

rmsnorm_weights_input = []
rmsnorm_weights_post_attn = []
for i in range(n_layers):
    weight_in = nn.Parameter(torch.ones(d_model, device=device))
    rmsnorm_weights_input.append(weight_in)
    weight_post = nn.Parameter(torch.ones(d_model, device=device))
    rmsnorm_weights_post_attn.append(weight_post)
final_rmsnorm_weight = nn.Parameter(torch.ones(d_model, device=device))

注意力层(MHA)

初始化 QKV 投影和输出投影线性层。

mha_qkv_linears = []
mha_output_linears = []
for i in range(n_layers):
    qkv_linear = nn.Linear(d_model, 3 * d_model, bias=False).to(device)
    mha_qkv_linears.append(qkv_linear)
    output_linear = nn.Linear(d_model, d_model, bias=False).to(device)
    mha_output_linears.append(output_linear)

混合专家(MoE)层

这是特殊的部分。包含路由器、专家 MLP 和共享专家。

moe_routers = []
moe_expert_gate_up_proj = []
moe_expert_down_proj = []
shared_expert_gate_proj = []
shared_expert_up_proj = []
shared_expert_down_proj = []
activation_fn = nn.SiLU()

for i in range(n_layers):
    # 路由器
    router_linear = nn.Linear(d_model, num_local_experts, bias=False).to(device)
    moe_routers.append(router_linear)
    
    # 专家权重
    gate_up_w = nn.Parameter(torch.empty(num_local_experts, d_model, 2 * expert_dim, device=device))
    nn.init.normal_(gate_up_w, mean=0.0, std=0.02)
    moe_expert_gate_up_proj.append(gate_up_w)
    down_w = nn.Parameter(torch.empty(num_local_experts, expert_dim, d_model, device=device))
    nn.init.normal_(down_w, mean=0.0, std=0.02)
    moe_expert_down_proj.append(down_w)
    
    # 共享专家
    shared_gate = nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
    shared_up = nn.Linear(d_model, shared_expert_dim, bias=False).to(device)
    shared_down = nn.Linear(shared_expert_dim, d_model, bias=False).to(device)
    shared_expert_gate_proj.append(shared_gate)
    shared_expert_up_proj.append(shared_up)
    shared_expert_down_proj.append(shared_down)

最终输出层

将隐藏状态投影到词汇表大小。

output_linear_layer = nn.Linear(d_model, vocab_size, bias=False).to(device)

因果掩码

仅解码器 Transformer 需要因果掩码,确保只能关注当前位置及之前的位置。

causal_mask = torch.tril(torch.ones(block_size, block_size, device=device))
causal_mask = causal_mask.view(1, 1, block_size, block_size)

训练设置

收集所有需要梯度的参数,定义优化器和损失函数。

all_model_parameters = list(token_embedding_table.parameters())
all_model_parameters.extend(rmsnorm_weights_input)
all_model_parameters.extend(rmsnorm_weights_post_attn)
all_model_parameters.append(final_rmsnorm_weight)
for i in range(n_layers):
    all_model_parameters.extend(list(mha_qkv_linears[i].parameters()))
    all_model_parameters.extend(list(mha_output_linears[i].parameters()))
    all_model_parameters.extend(list(moe_routers[i].parameters()))
    all_model_parameters.extend(moe_expert_gate_up_proj)
    all_model_parameters.extend(moe_expert_down_proj)
    all_model_parameters.extend(list(shared_expert_gate_proj[i].parameters()))
    all_model_parameters.extend(list(shared_expert_up_proj[i].parameters()))
    all_model_parameters.extend(list(shared_expert_down_proj[i].parameters()))
all_model_parameters.extend(list(output_linear_layer.parameters()))

optimizer = optim.AdamW(all_model_parameters, lr=learning_rate)
criterion = nn.CrossEntropyLoss()

训练模型

迭代地向模型输入批量数据,计算损失并更新参数。

print(f"\n--- 开始训练循环,共 {epochs} 个周期 ---")
losses = []
for epoch in range(epochs):
    xb, yb = train_x[torch.randint(0, num_sequences_available, (batch_size,))].to(device), \
             train_y[torch.randint(0, num_sequences_available, (batch_size,))].to(device)
    
    token_embed = token_embedding_table(xb)
    position_ids = torch.arange(xb.shape[1], device=device).unsqueeze(0)
    freqs_cis = torch.polar(torch.ones_like(position_ids), (inv_freq.unsqueeze(0).unsqueeze(-1).expand(xb.shape[0], -1, 1).float() @ position_ids.unsqueeze(1).expand(xb.shape[0], -1).float()).transpose(1, 2))
    
    x = token_embed
    for i in range(n_layers):
        # RMSNorm 和注意力
        x_norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_input[i]
        qkv = mha_qkv_linears[i](x_norm).view(xb.shape[0], xb.shape[1], n_heads, 3 * d_k).chunk(3, dim=-1)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q_rope, k_rope = q.float().reshape(xb.shape[0], xb.shape[1], n_heads, -1, 2), k.float().reshape(xb.shape[0], xb.shape[1], n_heads, -1, 2)
        q, k = torch.view_as_real(torch.view_as_complex(q_rope) * freqs_cis.unsqueeze(2)).flatten(3), \
               torch.view_as_real(torch.view_as_complex(k_rope) * freqs_cis.unsqueeze(2)).flatten(3)
        attn_scores = (q @ k.transpose(-2, -1)) * (d_k ** -0.5)
        attn_scores = attn_scores.masked_fill(causal_mask[:, :, :xb.shape[1], :xb.shape[1]] == 0, float('-inf'))
        attention_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attention_weights @ v
        x = x + mha_output_linears[i](attn_output.permute(0, 2, 1, 3).contiguous().view(xb.shape[0], xb.shape[1], d_model))
        
        # MoE 块
        x_norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * rmsnorm_weights_post_attn[i]
        router_logits = moe_routers[i](x_norm)
        routing_weights, selected_experts = torch.sigmoid(torch.topk(router_logits, num_experts_per_tok, dim=-1)[0]), \
                                            torch.topk(router_logits, num_experts_per_tok, dim=-1)[1]
        x_flat = x_norm.view(-1, d_model)
        selected_experts_flat = selected_experts.view(-1)
        routing_weights_flat = routing_weights.view(-1)
        token_idx = torch.arange(xb.shape[0] * xb.shape[1], device=device).repeat_interleave(num_experts_per_tok)
        expert_inputs = x_flat[token_idx]
        gate_up_states = torch.bmm(expert_inputs.unsqueeze(1), moe_expert_gate_up_proj[i][selected_experts_flat])
        activated_states = activation_fn(gate_up_states.chunk(2, dim=-1)[0]) * gate_up_states.chunk(2, dim=-1)[1]
        expert_outputs_weighted = torch.bmm(activated_states, moe_expert_down_proj[i][selected_experts_flat]).squeeze(1) * \
                                  routing_weights_flat.unsqueeze(-1)
        combined_expert_outputs = torch.zeros_like(x_flat)
        combined_expert_outputs.scatter_add_(0, token_idx.unsqueeze(-1).expand(-1, d_model), expert_outputs_weighted)
        shared_output = shared_expert_down_proj[i](activation_fn(shared_expert_gate_proj[i](x_norm)) * shared_expert_up_proj[i](x_norm))
        x = x + combined_expert_outputs.view(xb.shape[0], xb.shape[1], d_model) + shared_output
    
    logits = output_linear_layer((x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + rms_norm_eps)) * final_rmsnorm_weight)
    loss = criterion(logits.view(-1, logits.shape[-1]), yb.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    
    if epoch % eval_interval == 0 or epoch == epochs - 1:
        print(f" 第 {epoch+1}/{epochs} 个周期,损失:{loss.item():.4f}")
print("--- 训练循环完成 ---")

文本生成

模型训练完成后,我们可以尝试生成文本。将模型设置为评估模式并使用 torch.no_grad()。

print("\n--- 文本生成 ---")
seed_chars = "Alice "
num_tokens_to_generate = 200

seed_ids = [char_to_int[ch] for ch in seed_chars if ch in char_to_int]
generated_sequence = torch.tensor([seed_ids], dtype=torch.long, device=device)

with torch.no_grad():
    for _ in range(num_tokens_to_generate):
        current_context = generated_sequence[:, -block_size:]
        B_gen, T_gen = current_context.shape
        token_embed_gen = token_embedding_table(current_context)
        # ... (此处省略部分前向传播细节以保持简洁,逻辑同训练循环) ...
        # 实际实现需完整复制前向传播逻辑
        next_token = torch.multinomial(F.softmax(logits_gen[:, -1, :], dim=-1), num_samples=1)
        generated_sequence = torch.cat((generated_sequence, next_token), dim=1)

final_generated_ids = generated_sequence[0].tolist()
decoded_text = ''.join([int_to_char.get(id_val, '[UNK]') for id_val in final_generated_ids])
print("\n--- 最终生成的文本 ---")
print(decoded_text)

从 "Alice " 开始,模型生成了接下来的 200 个字符,展示了其学习到的文本风格和内容。

结论

我们完成了以下工作:

  1. 设置和分词:环境搭建和字符级分词。
  2. 超参数定义:适配硬件的配置值。
  3. 数据准备:创建输入/目标序列。
  4. 模型初始化:显式创建嵌入、RMSNorm、注意力、RoPE、MoE 等组件。
  5. 训练循环:实现完整的前向传播、损失计算和优化器步骤。
  6. 文本生成:在评估模式下进行自回归采样。

这个简化版模型成功展示了 MoE 层、RMSNorm 和 RoPE 如何协同工作,为大语言模型的底层结构提供了清晰的实战参考。

目录

  1. 从零开始用 Python 复现 LLaMA 4 MoE 架构
  2. MoE 架构概述
  3. 搭建环境
  4. 设备配置
  5. 定义训练语料库
  6. 字符级分词
  7. 编码语料库
  8. 定义超参数
  9. 模型架构超参数
  10. MoE 特定超参数
  11. 训练超参数
  12. 训练数据准备
  13. 批量策略
  14. 模型组件初始化
  15. 嵌入层
  16. RoPE 预计算
  17. RMSNorm 层
  18. 注意力层(MHA)
  19. 混合专家(MoE)层
  20. 最终输出层
  21. 因果掩码
  22. 训练设置
  23. 训练模型
  24. 文本生成
  25. 结论
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • 从零开始用 Python 复现 LLaMA 4 MoE 架构
  • OpenClaw 本地 AI 助手安装与配置实战指南
  • Windows 系统安装配置 Neo4j 图数据库图文教程
  • 从零开始利用 Python 构建 LLaMA 4 MoE 模型详解
  • DeepSeek-R1 使用技巧:如何平衡深度思考与回复质量
  • 无需拓展插件:Copilot 接入第三方 OpenAI 接口方案
  • 基于Python的轻量级上位机开发流程解析
  • 无人机 5.8G 模拟图传电路设计方案及性能分析
  • Agent 智能体开发框架对比:主流方案选型指南
  • Windows 权限提升:滥用 Windows 服务提权(上)
  • 通义灵码企业知识库 RAG 五大核心应用场景
  • Windows 安装 Python 后 CMD 命令行无法识别
  • 机器人视觉感知系统:YOLOv8 与 ROS 集成应用指南
  • Linux 基础 IO 解析:从 C 库函数到系统调用,理解文件操作本质
  • C++ 红黑树原理与实现:平衡规则、旋转操作及代码详解
  • Whisper-Large-V3-Turbo:语音识别技术架构与性能分析
  • 如何修改 Conda 环境的 Python 版本
  • Linux 系统权限概念与操作详解
  • 基于DamoFD-0.5G的AR虚拟试妆系统
  • C++ 容器适配器:优先级队列与反向迭代器实现原理

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online