跳到主要内容
动手学 PyTorch:从线性回归到图像分类 | 极客日志
Python AI 算法
动手学 PyTorch:从线性回归到图像分类 从线性回归开始,一步步构建神经网络,并用卷积网络对 MNIST 手写数字进行分类。代码全部基于 PyTorch,配有数据生成、训练和可视化。最后对比了 Python 与 Rust 在 AI 开发中的差异,给出兼顾两者的学习建议。整个流程覆盖了 AI 入门的核心概念,适合有编程背景的读者快速上手。
动手学 PyTorch:从线性回归到图像分类
环境准备
先装好 PyTorch 和常用库:
pip install torch torchvision numpy matplotlib
线性回归:AI 版的 Hello World
最简单的起步是线性回归。生成一些带噪声的数据,用模型拟合一条线:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
x = torch.linspace(0 , 10 , 100 ).unsqueeze(1 )
y = 2 * x + 1 + torch.randn(100 , 1 ) * 0.5
class LinearModel (nn.Module):
def __init__ (self ):
super ().__init__()
self .linear = nn.Linear(1 , 1 )
def forward (self, x ):
return self .linear(x)
model = LinearModel()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 )
for epoch in range (100 ):
output = model(x)
loss = criterion(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1 ) % == :
( )
torch.no_grad():
predicted = model(x)
plt.scatter(x.numpy(), y.numpy(), label= )
plt.plot(x.numpy(), predicted.numpy(), , label= )
plt.legend()
plt.show()
( )
10
0
print
f'Epoch [{epoch+1 } /100], Loss: {loss.item():.4 f} '
with
'Original data'
'r-'
'Fitted line'
print
'线性回归训练完成'
这段代码展示了 PyTorch 的核心流程:定义模型、选损失函数和优化器,然后反复前向计算、反向传播、更新参数。损失从初始的几十迅速降到个位数,画出来的红线也基本贴合数据分布。
增加点非线性:神经网络 现实世界的数据通常不是直线能搞定的。用神经网络来拟合二次函数:
x = torch.linspace(-1 , 1 , 100 ).unsqueeze(1 )
y = x.pow (2 ) + 0.2 * torch.randn(100 , 1 )
class NeuralNet (nn.Module):
def __init__ (self ):
super ().__init__()
self .hidden = nn.Linear(1 , 10 )
self .output = nn.Linear(10 , 1 )
def forward (self, x ):
x = torch.relu(self .hidden(x))
x = self .output(x)
return x
model = NeuralNet()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 )
for epoch in range (1000 ):
out = model(x)
loss = criterion(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1 ) % 100 == 0 :
print (f'Epoch [{epoch+1 } /1000], Loss: {loss.item():.4 f} ' )
with torch.no_grad():
predicted = model(x)
plt.scatter(x.numpy(), y.numpy(), label='Original data' )
plt.plot(x.numpy(), predicted.numpy(), 'r-' , label='Neural network prediction' )
plt.legend()
plt.show()
隐藏层有10个神经元,用 ReLU 激活后,网络就能捕捉到曲线的弯曲。1000 轮训练后损失降到很低,预测曲线和实际抛物线基本吻合。相比线性回归,神经网络引入非线性变换,表达能力更强。
图像分类:卷积网络攻 MNIST MNIST 是入门图像分类的必备数据集,包含 28×28 的手写数字灰度图。我们用卷积神经网络来识别它们。
加载并预览数据 import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5 ,), (0.5 ,))
])
trainset = torchvision.datasets.MNIST(root='./data' , train=True , download=True , transform=transform)
trainloader = DataLoader(trainset, batch_size=64 , shuffle=True )
testset = torchvision.datasets.MNIST(root='./data' , train=False , download=True , transform=transform)
testloader = DataLoader(testset, batch_size=64 , shuffle=False )
import matplotlib.pyplot as plt
import numpy as np
def imshow (img ):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1 , 2 , 0 )))
plt.show()
dataiter = iter (trainloader)
images, labels = next (dataiter)
imshow(torchvision.utils.make_grid(images[:4 ]))
print ('真实标签:' , labels[:4 ].tolist())
搭建卷积网络 import torch.nn.functional as F
class Net (nn.Module):
def __init__ (self ):
super ().__init__()
self .conv1 = nn.Conv2d(1 , 32 , 3 , 1 )
self .conv2 = nn.Conv2d(32 , 64 , 3 , 1 )
self .pool = nn.MaxPool2d(2 , 2 )
self .fc1 = nn.Linear(64 * 7 * 7 , 128 )
self .fc2 = nn.Linear(128 , 10 )
def forward (self, x ):
x = self .pool(F.relu(self .conv1(x)))
x = self .pool(F.relu(self .conv2(x)))
x = x.view(-1 , 64 * 7 * 7 )
x = F.relu(self .fc1(x))
x = self .fc2(x)
return x
net = Net()
print (net)
注意:这里全连接层的输入是 64 * 7 * 7,因为 MNIST 图像 28×28,经过两次 2×2 最大池化后变成 7×7,通道数为 64。如果写成 64 * 12 * 12 会报尺寸不匹配的错误。
训练与测试 import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001 , momentum=0.9 )
epochs = 5
for epoch in range (epochs):
running_loss = 0.0
for i, data in enumerate (trainloader, 0 ):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99 :
print (f'[{epoch+1 } , {i+1 } ] loss: {running_loss / 100 :.3 f} ' )
running_loss = 0.0
print ('训练完成' )
训练时每100个batch打印一次平均损失,通常从2.x降到0.0x。5个epoch在MNIST上就能达到不错的效果。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max (outputs, 1 )
total += labels.size(0 )
correct += (predicted == labels).sum ().item()
print (f'测试准确率:{100 * correct / total:.2 f} %' )
我在自己机器上跑,5个epoch后准确率大概在98%上下。对于这么简单的网络,已经很不错了。
dataiter = iter (testloader)
images, labels = next (dataiter)
imshow(torchvision.utils.make_grid(images))
print ('真实标签:' , labels[:4 ].tolist())
outputs = net(images)
_, predicted = torch.max (outputs, 1 )
print ('预测标签:' , predicted[:4 ].tolist())
如果从 Rust 转战 Python AI 写惯了 Rust 再来写 Python AI,有几个地方需要适应:
开发效率与生态 :Python 不用自己管内存,又有 PyTorch 这种高度封装的库,几行代码就搭好训练流程。Rust 生态在 AI 方面还在追赶,绑定的库用起来不够顺手。
性能认知 :Python 本身慢,但 PyTorch 底层是 C++/CUDA,运算部分不慢。频繁的数据搬移和 Python 循环才是瓶颈,所以把数据尽量留在张量里,或者用 DataLoader 多线程加载。
类型观念 :动态类型写起来快,但 IDE 提示不如 Rust 精确。可以补上类型注解,配合 mypy 用起来会舒服很多。
工程化结合 :Rust 的性能和安全性很适合做数据处理、部署推理服务;Python 适合快速实验和原型。两者可以配合,PyO3 之类的工具能让 Rust 扩展模块直接在 Python 里用。
一点体会 走完线性回归、神经网络到卷积分类这一圈,基本的 AI 训练流程就熟悉了。PyTorch 的模式都差不多:准备数据、定义模型、选损失函数和优化器、写训练循环。动态类型刚开始容易写出低效的循环,得时刻想着'向量化'。Rust 用户可能会不适应 Python 的慢和'宽松',但别纠结,先把想法跑起来再说——毕竟连官方自己都强调'Write Python, run in C++ speed'。等你需要高性能单点或者给队友 Python 脚本加速时,再把 Rust 搬出来也不迟。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
随机西班牙地址生成器 随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
Gemini 图片去水印 基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online