PyTorch 模型训练完整工作流程详解
机器学习和深度学习的本质是从历史数据中发现一般模式,然后用发现的模式预测未来的数据。在本文中,我们将以一个学习直线方程的例子说明用 PyTorch 训练模型的工作流程。
本文介绍了基于 PyTorch 进行机器学习模型训练的完整流程。涵盖数据准备与划分、模型构建(继承 nn.Module)、损失函数与优化器选择、训练循环实现、推理模式设置以及模型持久化保存与加载。通过线性回归示例演示了从数据生成到参数优化的全过程,并提供了完整的可运行代码结构,帮助开发者掌握 PyTorch 核心工作流及最佳实践。

机器学习和深度学习的本质是从历史数据中发现一般模式,然后用发现的模式预测未来的数据。在本文中,我们将以一个学习直线方程的例子说明用 PyTorch 训练模型的工作流程。
机器学习中的数据含义很广泛:文本、图像、视频、音频、表格、甚至是蛋白质结构都是数据。
我们创建一个线性回归数据集来演示流程:
import torch
# 设置随机种子以保证结果可复现
torch.manual_seed(42)
# 生成 50 个样本的 X 数据 (范围 0-1)
x = torch.rand(50, 1)
# 生成对应的 y 数据:y = 3x + 2 + noise
true_w = 3.0
true_b = 2.0
noise = torch.randn(50, 1) * 0.1
y = x * true_w + true_b + noise
机器学习最重要的一步是将你的数据集分成训练集、验证集(有时不需要)和测试集。
这里,我们仅将数据集分为训练集和测试集。在实际工作中,数据集在项目开始之前就被分好了,我们可以多次使用训练集,但是只能在最终训练完后,使用一次测试集测试模型的最终泛化性能。
from sklearn.model_selection import train_test_split
# 将数据划分为训练集和测试集 (8:2)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
print(f"训练集样本数:{len(x_train)}")
print(f"测试集样本数:{len(x_test)}")
现在我们要建立一个可以根据输入数据预测输出数据的模型。
我们首先看一下一些 PyTorch 基础模块,它们几乎都来自 torch.nn 模块:
torch.nn:包含计算图的所有构建模块。torch.nn.Parameter:存储可与 nn.Module 一起使用的张量参数。如果 requires_grad=True 则自动计算梯度。torch.nn.Module:所有神经网络模块的基类。如果你在 PyTorch 中构建神经网络,你的模型应该是 nn.Module 的子类,并且需要实现 forward() 方法。torch.optim:包含各种优化算法。forward() 方法:定义了对传递给特定 nn.Module 的数据进行的计算。通过子类化 nn.Module 创建 PyTorch 模型的基本构建模块。对于 nn.Module 子类的对象,必须定义 forward() 方法。
import torch.nn as nn
class LinearModel(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
# 定义线性层,等价于 y = wx + b
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
# 实例化模型
model = LinearModel(input_dim=1, output_dim=1)
# 查看模型参数
for name, param in model.named_parameters():
print(f"{name}: {param.shape}, requires_grad={param.requires_grad}")
由于我们使用了默认初始化,weights 和 bias 都是随机张量。
我们将测试集 X_test 的数据通过 forward() 方法得到模型的预测结果 y_preds。
with torch.inference_mode():
y_preds = model(x_test)
print("初始预测值:", y_preds[:5])
print("真实值:", y_test[:5])
在旧版本的 PyTorch 中,我们使用的是 torch.no_grad()。无论是哪种,我们都抑制了 PyTorch 自动计算梯度的功能,这有助于加速 forward 流程。由于我们模型参数都是随机初始化的,而且模型又没有经过训练,所以现在模型的预测性能显然很糟糕。
训练模型的过程其实就是不断更新我们之前由默认初始化或 nn.Parameter 初始化的模型参数 weights 和 bias。
损失函数(loss function) 负责衡量模型的预测值(y_preds)和真实值(y_tests)之间的差异,通常越小越好。torch.nn 模块内置了很多损失函数。一般对于回归任务,使用平均绝对误差(torch.nn.L1Loss());对于二分类任务,使用二元交叉熵损失(torch.nn.BCELoss())。
优化器(optimizer) 负责告诉模型如何调整内部参数以最小化损失函数。torch.optim 模块实现了许多优化函数。常见的优化函数有随机梯度下降(torch.optim.SGD())和 Adam(torch.optim.Adam())。
我们需要根据不同的任务选择合适的损失函数和优化器。由于我们现在是在预测数值,所以我们选择 torch.nn.L1Loss()。
# 定义损失函数
loss_fn = nn.L1Loss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
平均绝对误差测量两点(预测和标签)之间的绝对差,然后取所有样本的平均值。
我们将选择 torch.optim.SGD(params, lr):
现在我们可以创建训练循环(training loop)和测试循环(testing loop)。
模型在训练循环中遍历训练集中的样本,学习样本的特征和标签之间的关系。
典型的训练循环如下所示,首先我们需要设定模型训练的循环次数 epochs。然后在每个 epoch(也就是每次循环中):
epochs = 100
for epoch in range(epochs):
# 1. 前向传播
y_pred = model(x_train)
# 2. 计算损失
loss = loss_fn(y_pred, y_train)
# 3. 梯度清零
optimizer.zero_grad()
# 4. 反向传播
loss.backward()
# 5. 更新参数
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
模型在测试循环中遍历测试集中的样本,衡量模型学习到关系在未知数据集中的表现。
测试循环就简单了,还是在每个 epoch(也就是每次循环中):
model.eval() 说明现在是在评估状态。torch.inference_mode() 中:隔几个 epoch(一般是 10)输出一下平均损失。
模型训练好之后就可以用来预测了,有的地方也称为推理(inference)。但是在此之前,我们需要做好以下三件事:
model.eval()。with torch.inference_mode(): … 上下文管理器中进行预测。model.eval()
with torch.inference_mode():
final_preds = model(x_test)
print("最终预测值:", final_preds[:5].detach())
print("真实值:", y_test[:5])
看起来结果已经是相当接近了。
有三种主要的方法可以保存和加载 PyTorch 模型:
torch.save:使用 Python 的 pickle 工具将对象序列化后保存到磁盘。模型、张量和其他 Python 对象都可以用这种方法保存。torch.load:将 torch.save 保存的文件反序列化之后加载到内存,我们还可以选择加载到 GPU 还是 CPU。torch.nn.Module.load_state_dict:加载保存在磁盘上的模型参数字典(用 model.state_dict() 方法可以获得)。不过根据 pickle 的官方文档介绍,pickle 是不安全的,所以你必须要相信你要加载的对象是安全的。
保存和加载模型的推荐方式是通过模型的 state_dict() 方法。
torch.save 的 obj 参数是模型的参数字典、f 参数是保存的路径。一般我们将 PyTorch 模型保存为以 .pt 或 .pth 这样的后缀。
不过以这种方式保存的模型仅仅是保存了模型训练之后的参数字典,而不是整个模型。所以我们首先需要初始化一个模型实例,然后用保存在磁盘上的模型参数字典更新这个初始化的模型参数。
# 保存模型参数
torch.save(model.state_dict(), 'linear_model.pth')
print("模型已保存")
看到上面的提示说明加载成功。
加载成功之后,我们再次用这个新加载的模型做预测,并将预测结果和之前训练的模型进行比较。
# 重新实例化模型
new_model = LinearModel(input_dim=1, output_dim=1)
# 加载参数
new_model.load_state_dict(torch.load('linear_model.pth'))
new_model.eval()
# 验证一致性
with torch.inference_mode():
loaded_preds = new_model(x_test)
print("加载后预测值:", loaded_preds[:5].detach())
print("原预测值:", final_preds[:5].detach())
结果显然是一致的。
根据官方文档,保存整个模型的缺点是序列化数据绑定到特定的类以及保存模型时使用的确切目录结构。这样做的原因是因为 pickle 不保存模型类本身。相反,它保存包含该类的文件的路径,该路径在加载时使用。因此,在其他项目中使用或重构后,您的代码可能会以各种方式损坏。
下面将前面所有的步骤整合到一个完整的脚本中,并支持设备(CPU/GPU)自动切换。
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
# 1. 确定当前可用的设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 2. 创建数据集
torch.manual_seed(42)
x = torch.rand(50, 1).to(device)
y = x * 3.0 + 2.0 + torch.randn(50, 1).to(device) * 0.1
# 3. 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
# 4. 建立模型
class LinearModel(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
model = LinearModel(input_dim=1, output_dim=1).to(device)
# 5. 定义损失函数和优化器
loss_fn = nn.L1Loss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 6. 训练循环
epochs = 1000
for epoch in range(epochs):
model.train()
y_pred = model(x_train)
loss = loss_fn(y_pred, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 100 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
# 7. 测试与预测
model.eval()
with torch.inference_mode():
test_pred = model(x_test)
print(f"Test Loss: {loss_fn(test_pred, y_test).item():.4f}")
# 8. 保存模型
torch.save(model.state_dict(), 'final_model.pth')
print("Training and Saving Complete.")
当我们将训练 epochs 从 100 增加到 1000 之后,模型学到的参数很逼近真实参数值了。现在红色点几乎和蓝色点重叠了,非常好!
本文详细拆解了 PyTorch 模型训练的标准工作流。从数据预处理、模型定义、损失计算、优化器配置到训练循环的实现,再到最终的模型持久化。掌握这一流程是深入深度学习开发的基础。开发者可根据具体任务调整网络结构、损失函数及超参数,以适应更复杂的场景。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online