PyTorch 从零训练大模型实战:Transformer 架构与训练流程详解
基于 PyTorch 从零构建 Transformer 大模型的完整流程,涵盖前馈网络、归一化层设计、编码器与解码器堆叠、训练循环与推理测试,深入理解 Attention 机制及架构内部原理。

基于 PyTorch 从零构建 Transformer 大模型的完整流程,涵盖前馈网络、归一化层设计、编码器与解码器堆叠、训练循环与推理测试,深入理解 Attention 机制及架构内部原理。

前馈网络利用深度神经网络结构,通过两层线性变换来捕捉嵌入向量的特征。第一层将维度从 d_model 扩展到 d_ff,第二层将其映射回 d_model。通常 d_ff 设置为 4 * d_model。
在第一层之后,引入 ReLU 激活函数赋予模型非线性特性,并通过 dropout 技术减少过拟合的风险。
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
通过层归一化处理,可以确保网络中嵌入向量的值分布均衡,从而促进模型的稳定学习。此外,引入 gamma 和 beta 两个可学习的参数,以便对嵌入值进行动态的缩放和平移调整。
class LayerNorm(nn.Module):
def __init__(self, features, eps=1e-6):
super().__init__()
self.alpha = nn.Parameter(torch.ones(features))
self.beta = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.alpha * (x - mean) / (std + self.eps) + self.beta
此模块结合了跳跃连接和层归一化技术。在模型的前向传播中,跳跃连接帮助保留早期层学到的特征,使这些特征能在网络的深层中发挥作用。
在反向传播过程中,跳跃连接减少了梯度消失的问题,因为它允许梯度在反向传递时跳过某些层。
无论是编码器中的两次应用,还是解码器中的三次应用,加法归一化都先对输入进行归一化处理,再将其与前一层的输出相加,以此来丰富模型的表达能力。
class AddNorm(nn.Module):
def __init__(self, features, dropout, eps=1e-6):
super().__init__()
self.norm = LayerNorm(features, eps)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, mask=None):
normed = self.norm(x)
return self.dropout(normed) + x
编码器的核心是两个主要组件:多头注意力机制和前馈网络。除此之外,每个编码器块还包含两个加法归一化单元,负责调整和规范化信息流。我们依照注意力机制论文中的指导,将这些组件整合到 EncoderBlock 类中,并重复此结构 6 次以深化学习效果。
class EncoderBlock(nn.Module):
def __init__(self, features, heads, dropout, d_ff):
super().__init__()
self.attention = MultiHeadAttention(features, heads, dropout)
self.feed_forward = FeedForward(features, d_ff, dropout)
self.add_norm_1 = AddNorm(features, dropout)
self.add_norm_2 = AddNorm(features, dropout)
def forward(self, x, mask):
attn_output = self.attention(x, x, x, mask)
x = self.add_norm_1(x, attn_output)
ff_output = self.feed_forward(x)
x = self.add_norm_2(x, ff_output)
return x
在编码器块的基础上,进一步创建 Encoder 类,接收一系列编码器块并有序堆叠起来。这一整合过程不仅增强了信息的流通性,还确保了网络能够输出高质量的编码结果,为后续的解码过程打下坚实基础。
class Encoder(nn.Module):
def __init__(self, layers, features, heads, dropout, d_ff):
super().__init__()
self.layers = nn.ModuleList([layers(features, heads, dropout, d_ff) for _ in range(layers)])
self.norm = LayerNorm(features)
def forward(self, x, mask):
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
解码器块由三大核心组件构成:掩码多头注意力、标准多头注意力以及前馈网络。每个解码器块还包含三个加法归一化单元,用以优化信息处理流程。根据论文,将这些组件精心组合在 DecoderBlock 类中,并重复此结构 6 次,以丰富模型的解码能力。
class DecoderBlock(nn.Module):
def __init__(self, features, heads, dropout, d_ff):
super().__init__()
self.mask_attn = MaskedMultiHeadAttention(features, heads, dropout)
self.cross_attn = MultiHeadAttention(features, heads, dropout)
self.feed_forward = FeedForward(features, d_ff, dropout)
self.add_norm_1 = AddNorm(features, dropout)
self.add_norm_2 = AddNorm(features, dropout)
self.add_norm_3 = AddNorm(features, dropout)
def forward(self, x, enc_output, src_mask, tgt_mask):
attn1 = self.mask_attn(x, x, x, tgt_mask)
x = self.add_norm_1(x, attn1)
attn2 = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.add_norm_2(x, attn2)
ff_out = self.feed_forward(x)
x = self.add_norm_3(x, ff_out)
return x
在解码器块的基础上,构建 Decoder 类,接收一系列解码器块并进行有效堆叠,以实现连续的信息处理和特征提取,最终生成解码器的输出。
class Decoder(nn.Module):
def __init__(self, layers, features, heads, dropout, d_ff):
super().__init__()
self.layers = nn.ModuleList([layers(features, heads, dropout, d_ff) for _ in range(layers)])
self.norm = LayerNorm(features)
def forward(self, x, enc_output, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x, enc_output, src_mask, tgt_mask)
return self.norm(x)
解码器的最终输出将进入投影层进行进一步处理。在这一层,输出首先通过一个线性层进行变换,以适应模型的输出需求。紧接着,应用 softmax 函数将输出转化为词汇表上的概率分布,从而选出概率最高的标记作为模型的预测结果。
class Projection(nn.Module):
def __init__(self, vocab_size, d_model):
super().__init__()
self.proj = nn.Linear(d_model, vocab_size)
def forward(self, x):
return self.proj(x)
至此,已经完成了 Transformer 架构中的所有组件块的构建工作。剩余的任务是将它们全部组装起来。
首先,创建一个 Transformer 类,以初始化所有组件类的实例。在 Transformer 类中,先定义一个编码函数,该函数执行 Transformer 编码部分的所有任务,并生成编码器输出。
其次,定义一个解码函数,该函数执行 Transformer 解码部分的所有任务,并生成解码器输出。
第三,定义一个投影函数,接收解码器输出,并将输出映射到词汇表以进行预测。
现在,Transformer 架构准备就绪,可以通过定义函数来构建翻译语言模型(LLM),该函数接收如下代码中给出的所有必要参数。
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, projection):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.projection = projection
def encode(self, src, src_mask):
return self.encoder(self.src_embed(src), src_mask)
def decode(self, enc_output, src_mask, tgt, tgt_mask):
tgt_embed = self.tgt_embed(tgt)
return self.decoder(tgt_embed, enc_output, src_mask, tgt_mask)
def project(self, x):
return self.projection(x)
def forward(self, src, tgt, src_mask, tgt_mask):
enc_output = self.encode(src, src_mask)
dec_output = self.decode(enc_output, src_mask, tgt, tgt_mask)
return self.project(dec_output)
我们已经抵达了模型训练的关键阶段。这一过程其实颇为直接明了,使用之前在第三步中构建的训练 DataLoader 来执行训练任务。
鉴于训练数据集规模达到了百万,推荐在 GPU 上进行模型训练以提升效率。依据经验,完成 20 个 epoch 的训练大约耗时 5 小时。为了便于训练过程中的断点续传,我们在每个 epoch 结束时都会保存模型的权重和优化器的状态。
训练之后,紧接着的是验证环节。将动用规模为 2000 的验证 DataLoader 来执行这一任务,这一数据量设置是合理的。
在验证过程中,只在最初计算一次编码器的输出,随后便等待解码器输出句子结束的标记 [SEP]。这样的设计是因为在解码器接收到 [SEP] 标记之前,重复发送相同的编码器输出是无益的。
至于解码器的输入,则从句子的起始标记 [CLS] 开始。在每次预测之后,解码器的输入会追加上新生成的标记,直至遇到句子结束的 [SEP] 标记。最终,由投影层将这些输出映射转换为相应的文本表示,完成整个翻译过程。
def train_model(model, optimizer, criterion, loader, device):
model.train()
for batch_idx, batch in enumerate(loader):
src, tgt = batch['src'].to(device), batch['tgt'].to(device)
src_mask, tgt_mask = create_masks(src, tgt)
optimizer.zero_grad()
output = model(src, tgt, src_mask, tgt_mask)
loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f"Batch {batch_idx}, Loss: {loss.item()}")
torch.save(model.state_dict(), 'checkpoint.pth')
为这个翻译功能命名为'malaygpt',专门用来处理英文到马来文的翻译任务。这个函数设计得非常直观:用户只需输入英文文本,它便能智能地输出相应的马来文翻译。接下来,启动这个函数,亲自体验翻译效果。
def translate(text, model, tokenizer, device):
model.eval()
tokens = tokenizer.encode(text)
input_tensor = torch.tensor(tokens).unsqueeze(0).to(device)
src_mask = create_mask(input_tensor)
output_tokens = []
while True:
output = model(input_tensor, None, src_mask, None)
pred_token = output.argmax(dim=-1)[0, -1].item()
output_tokens.append(pred_token)
if pred_token == tokenizer.sep_token_id:
break
new_input = torch.tensor([pred_token]).unsqueeze(0).to(device)
input_tensor = torch.cat([input_tensor, new_input], dim=1)
return tokenizer.decode(output_tokens)
进行一些翻译测试: 输入:Hello world 输出:Helo dunia
翻译得似乎还不错。
基于 PyTorch 从零构建 Transformer 大模型的完整流程,涵盖前馈网络、归一化层设计、编码器与解码器堆叠、训练循环与推理测试,深入理解 Attention 机制及架构内部原理。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online