跳到主要内容
极客日志极客日志
首页博客AI提示词GitHub精选代理工具
搜索
|注册
博客列表
PythonAI算法

PyTorch 实战:加载模型权重与 ONNX 推理部署

PyTorch 模型训练完成后,如何加载权重并进行跨平台推理?本文演示了实例化网络结构、加载 state_dict 以及调用 eval() 进入评估模式的关键步骤。针对多语言环境部署需求,通过 torch.onnx.export 导出模型为 ONNX 格式,并利用 onnxruntime 在 Python 环境中完成推理预测。结合 FashionMNIST 数据集示例,展示了从模型加载到输出分类结果的完整流程,解决了不同框架间模型共享与加速的问题。

猫巷少女发布于 2025/1/19更新于 2026/5/13 浏览
PyTorch 实战:加载模型权重与 ONNX 推理部署

加载和运行模型预测

加载模型

这一节我们重点看如何加载持久化的参数状态,并进行模型推断。首先需要定义模型类,它包含了神经网络的结构信息。

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

实例化模型类后,就可以加载保存好的权重文件了。这里有个关键点:在推理前务必调用 model.eval()。这会将 Dropout 和批量归一化层切换到评估模式,否则结果可能会因为随机性而不一致。

model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

模型推理与 ONNX 导出

把神经网络放到各种平台和硬件上运行往往很麻烦,不同框架间的性能调优也很耗时。ONNX (Open Neural Network Exchange) 提供了一种通用格式,支持跨语言和跨设备推理。比如你可以用它在 Java、C# 或 ML.NET 上跑模型。

PyTorch 原生支持导出 ONNX。由于 PyTorch 是动态图,导出时需要传入一个固定大小的张量作为输入来追踪计算图。通常创建一个合适尺寸的零张量即可。

import torch.onnx as onnx

input_image = torch.zeros((1, 28, 28))
onnx_model = 
onnx.export(model, input_image, onnx_model)
'data/model.onnx'

使用 ONNX Runtime 进行预测

导出成功后,我们可以用 onnxruntime 来加载模型并跑通一次完整的预测流程。这里以 FashionMNIST 数据集为例,取第一张图进行测试。

import onnxruntime
from torchvision import datasets
from torchvision.transforms import ToTensor

test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor()
)
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

x, y = test_data[0][0], test_data[0][1]
session = onnxruntime.InferenceSession(onnx_model, None)

input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
result = session.run([output_name], {input_name: x.numpy()})

predicted = classes[result[0][0].argmax(0)]
actual = classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

输出结果如下:

Predicted: "Ankle boot", Actual: "Ankle boot"

可以看到,模型成功识别出了短靴。通过这种方式,我们实现了在不同环境下复用训练好的模型。

完整代码参考

为了方便大家直接运行,这里汇总了上述所有步骤的代码。

import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
from torchvision import datasets
from torchvision.transforms import ToTensor

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# 加载模型
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

# 导出 ONNX
input_image = torch.zeros((1, 28, 28))
onnx.export(model, input_image, 'data/model.onnx')

# 推理
from torchvision import datasets
from torchvision.transforms import ToTensor

test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor()
)
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

x, y = test_data[0][0], test_data[0][1]
session = onnxruntime.InferenceSession('data/model.onnx', None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
result = session.run([output_name], {input_name: x.numpy()})

predicted = classes[result[0][0].argmax(0)]
actual = classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

总结

本节完成了模型从加载到跨平台推理的闭环。我们不仅掌握了 load_state_dict 和 eval() 的使用细节,还学会了利用 ONNX 解决部署兼容性问题。结合之前的张量操作、数据预处理和自动微分知识,现在你已经具备了构建基本图像分类模型的能力。

目录

  1. 加载和运行模型预测
  2. 加载模型
  3. 模型推理与 ONNX 导出
  4. 使用 ONNX Runtime 进行预测
  5. 完整代码参考
  6. 加载模型
  7. 导出 ONNX
  8. 推理
  9. 总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • GPT-5.5 超高智商模型1元抵1刀ChatGPT中转购买
  • 代充Chatgpt Plus/pro 帐号了解详情
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • 开源模型 Mistral 与 Qwen Prompt 实验报告
  • 国内环境部署 n8n 与私有 AI 模型实战指南
  • 基于 Coze Skills 与 OpenClaw 的 AI 智能体自动化实战指南
  • C++可变参数队列与压栈顺序:模板语法及汇编调用约定
  • LeetCode 160 相交链表
  • 2025 年 Java 与 AI 技术融合学习路线
  • 基于 Rust+Tauri 构建带安全沙箱的跨平台清理工具
  • SpringBoot 整合 Neo4j 图数据库实战指南
  • YOLO 模型 TensorRT C++ 推理实战指南
  • Scapy 详细安装教程、功能介绍与快速上手
  • Python 趣味小游戏代码示例:吃金币、打乒乓等 13 款
  • 大模型应用元年,有哪些场景可以实际落地?
  • Python 依赖注入(DI)实战:三种实现方式、代价权衡与可测试性案例
  • 学习 Python 对软件测试有哪些优势?
  • MATLAB 实现基于强制导向函数法(PFA)的无人机三维路径规划项目实例
  • 使用 Python SDK 调用 Coze 工作流详解
  • 基于 DeepSeek 和 Cursor 构建智能代码审查工具实践
  • 机器人轨迹规划详解:从概念到常用方法
  • 无线联邦学习:隐私保护下的 AI 协同训练
  • 华为交换机首次开局配置完整步骤(Console + Web)

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online