跳到主要内容从零实现 LLaMA 架构:构建轻量级大语言模型 | 极客日志PythonAI算法
从零实现 LLaMA 架构:构建轻量级大语言模型
LLaMA 大语言模型的核心架构设计,包括 RMSNorm 归一化、SwiGLU 激活函数、RoPE 位置编码及 Pre-Norm 结构。通过 Python 代码从零实现轻量级 LLaMA-like 模型,涵盖配置管理、基础层、注意力机制及主模型构建。实战测试验证了自回归生成逻辑的正确性,为理解大模型底层原理及后续训练部署奠定基础。
活在当下0 浏览 一、LLaMA 核心设计亮点
先梳理 LLaMA 相对于经典 Transformer 的核心改进(也是本文实现的核心),为后续代码解析铺垫:
| 优化点 | 传统 Transformer | LLaMA 设计 | 优势 |
|---|
| 归一化层 | LayerNorm(含均值中心化 + 偏置) | RMSNorm(仅均方根归一化) | 计算更快,训练稳定性相当 |
| 前馈网络激活 | ReLU/GELU + 单线性层 | SwiGLU(门控激活) | 提升模型表达能力 |
| 位置编码 | 绝对位置编码 | 旋转位置编码(RoPE) | 更好的长序列泛化能力 |
| 归一化位置 | Post-Norm(注意力 / FFN 后) | Pre-Norm(注意力 / FFN 前) | 训练更稳定,梯度传播更顺畅 |
| 线性层偏置 | 带 bias | 无 bias | 减少参数规模,提升推理速度 |
二、代码架构总览
我们将模型拆解为 5 个职责清晰的核心文件,从基础组件到完整模型再到测试,层层递进:
| 文件名称 | 核心功能 |
|---|
config.py | 模型超参数管理(类型安全的 dataclass) |
layers.py | 基础层实现(RMSNorm、SwiGLU FeedForward、RoPE) |
attention.py | 因果自注意力层(集成 RoPE+Flash Attention) |
model.py | Transformer 块封装 + 完整 LLM 模型 |
main.py | 前向传播测试 + 自回归文本生成 |
三、逐模块解析代码
3.1 配置模块:config.py
模型超参数是大模型的'骨架',用 dataclass 可以简洁、类型安全地管理这些参数,方便后续扩展和修改:
from dataclasses import dataclass
@dataclass
class LLMConfig:
vocab_size: int = 32000
hidden_size: int = 1024
num_layers: int = 12
num_heads: =
intermediate_size: =
max_seq_len: =
rms_norm_eps: =
dropout: =
微信扫一扫,关注极客日志
微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
int
16
int
2816
int
2048
float
1e-5
float
0.1
intermediate_size:FFN 中间层维度选择 hidden_size * 8/3 是 LLaMA 的经验值,平衡模型表达能力和参数量;
rms_norm_eps:极小值(1e-5)避免均方根计算时除以 0;
max_seq_len:决定模型能处理的最长文本长度,也影响 RoPE 频率矩阵的预计算范围。
3.2 基础层模块:layers.py
这是模型的'基础组件库',实现了 LLaMA 最核心的三个基础层:RMSNorm、SwiGLU FeedForward、RoPE。
3.2.1 均方根归一化(RMSNorm)
RMSNorm 是 LLaMA 的核心优化之一,数学公式为:y=E[x^2]+ϵx×γ。相比 LayerNorm,它去掉了均值中心化和偏置项,计算效率更高:
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
"""均方根归一化 (Root Mean Square Normalization)"""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
3.2.2 SwiGLU 前馈网络
LLaMA 的 FeedForward 层使用 SwiGLU 激活(替代传统 GELU),公式为:SwiGLU(x)=Swish(xW1)⊗(xW3)W2。其中 ⊗ 是逐元素相乘,SwiGLU 通过'门控机制'提升模型的非线性表达能力:
class FeedForward(nn.Module):
"""采用 SwiGLU 激活的基于门控的前馈神经网络"""
def __init__(self, config):
super().__init__()
self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
设计细节:所有线性层都去掉了 bias,这是 LLaMA 的核心设计之一,减少参数的同时提升训练稳定性。
3.2.3 旋转位置编码(RoPE)
RoPE 的核心是将位置信息编码为复数旋转角度,让 Query/Key 在注意力计算时随位置'旋转',既保留绝对位置信息,又具备相对位置的泛化能力。
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""预计算 RoPE 的频率矩阵(复数形式)"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
"""将 RoPE 应用到 Query/Key 上"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
precompute_freqs_cis:提前计算所有位置的旋转频率(复用性高,无需每次前向都计算);
apply_rotary_emb:将 Q/K 按两个维度为一组拆分为复数,与频率矩阵相乘实现'旋转',再转回实数。
3.3 注意力模块:attention.py
因果自注意力是 Transformer 的核心,LLaMA 的注意力层做了两大优化:QKV 合并映射(工程高效)、集成 Flash Attention(PyTorch 2.0 + 内置)。
import math
import torch
import torch.nn as nn
from layers import apply_rotary_emb
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.hidden_size % config.num_heads == 0
self.n_heads = config.num_heads
self.head_dim = config.hidden_size // config.num_heads
self.wqkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
self.wo = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
def forward(self, x, freqs_cis, mask=None):
B, T, C = x.size()
qkv = self.wqkv(x)
q, k, v = qkv.split(C, dim=2)
q = q.view(B, T, self.n_heads, self.head_dim)
k = k.view(B, T, self.n_heads, self.head_dim)
v = v.view(B, T, self.n_heads, self.head_dim)
q, k = apply_rotary_emb(q, k, freqs_cis[:T])
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=mask,
dropout_p=self.attn_dropout.p if self.training else 0.0,
is_causal=True if mask is None else False
)
y = y.transpose(1, 2).contiguous().view(B, T, C)
return self.resid_dropout(self.wo(y))
scaled_dot_product_attention:PyTorch 2.0 + 内置接口,自动启用 Flash Attention,大幅降低显存占用、提升计算速度;
is_causal=True:自动生成因果掩码,避免手动构造掩码矩阵,代码更简洁。
3.4 Transformer 块与主模型:model.py
将注意力层和 FFN 层组合成 Transformer 块,再堆叠为完整的 LLM 模型,核心是Pre-Norm 架构和残差连接。
import torch
import torch.nn as nn
from config import LLMConfig
from layers import RMSNorm, FeedForward, precompute_freqs_cis
from attention import CausalSelfAttention
class TransformerBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = CausalSelfAttention(config)
self.feed_forward = FeedForward(config)
self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, x, freqs_cis):
h = x + self.attention(self.attention_norm(x), freqs_cis)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class LLM(nn.Module):
def __init__(self, config: LLMConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_heads, config.max_seq_len)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, tokens, targets=None):
B, T = tokens.size()
h = self.tok_embeddings(tokens)
for layer in self.layers:
h = layer(h, self.freqs_cis[:T])
h = self.norm(h)
logits = self.output(h)
loss = None
if targets is not None:
loss = nn.functional.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
return logits, loss
register_buffer:将 RoPE 频率矩阵注册为非训练参数,避免每次前向都重新计算;
Pre-Norm:归一化层在注意力 / FFN 之前,相比 Post-Norm,训练时梯度更稳定,无需额外的初始化技巧;
- 残差连接:每个子层(注意力 / FFN)的输出都与输入相加,保证梯度能有效传播到浅层。
temperature:温度越高,生成的随机性越强(logits 除以 temperature 后,概率分布更平缓);
torch.multinomial:多项式采样(相比 argmax 的'贪心采样',生成结果更丰富);
- 序列裁剪:每次生成前裁剪序列到
max_seq_len,避免超出模型的上下文长度限制。
四、实战运行与结果解读
将所有代码文件放在同一目录,运行 main.py,输出示例如下:
正在初始化 LLM 模型 (类 LLaMA 架构)...
模型参数量:0.85 M
前向传播测试:Loss = 6.9078, Logits Shape = torch.Size([1, 5, 1000])
开始生成文本...
原始输入:[10, 20, 30, 40, 50]
生成结果:[10, 20, 30, 40, 50, 88, 123, 45, 789, 23, 90, 111, 56, 89]
- 模型参数量约 0.85M,属于轻量级,可在 CPU 上快速测试;
- 初始 Loss≈6.9,符合预期(随机初始化的模型,Loss 接近
ln(vocab_size)=ln(1000)≈6.9);
- 生成的 token 序列是随机的(模型未训练),但验证了自回归生成逻辑的正确性。
五、总结
本文从 LLaMA 的核心设计出发,拆解并实现了一个轻量级的 LLaMA-like 模型,覆盖了 RMSNorm、SwiGLU、RoPE、因果自注意力等关键组件。
大模型看似复杂,但本质是'简单组件的有序组合'—— 掌握这些核心设计,就能理解大模型的底层逻辑,为后续的模型训练、优化和部署打下基础。
- 训练模型:用小数据集(如 WikiText)训练模型,观察 Loss 的下降趋势;
- 扩展参数:将
hidden_size 调至 4096、num_layers 调至 32,实现原版 LLaMA 7B 的架构;
- 部署推理:将模型导出为 ONNX/TensorRT,提升推理速度。