加载和运行模型预测
加载模型
这一节我们重点看如何加载持久化的参数状态,并进行模型推断。首先需要定义模型类,它包含了神经网络的结构信息。
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU()
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
实例化模型类后,就可以加载保存好的权重文件了。这里有个关键点:在推理前务必调用 model.eval()。这会将 Dropout 和批量归一化层切换到评估模式,否则结果可能会因为随机性而不一致。
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
模型推理与 ONNX 导出
把神经网络放到各种平台和硬件上运行往往很麻烦,不同框架间的性能调优也很耗时。ONNX (Open Neural Network Exchange) 提供了一种通用格式,支持跨语言和跨设备推理。比如你可以用它在 Java、C# 或 ML.NET 上跑模型。
PyTorch 原生支持导出 ONNX。由于 PyTorch 是动态图,导出时需要传入一个固定大小的张量作为输入来追踪计算图。通常创建一个合适尺寸的零张量即可。
import torch.onnx as onnx
input_image = torch.zeros((1, 28, 28))
onnx_model =
onnx.export(model, input_image, onnx_model)

