1. 模型加载与保存
1.1 模型的加载与保存
path_model =
path_state_dict =
torch.save(net, path_model)
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
PyTorch 训练技巧涵盖模型加载保存、断点续训、微调及 GPU 使用。支持保存整个模型或仅参数,推荐后者。断点记录优化器状态以便意外中断后恢复。微调包括固定预训练参数或使用不同学习率组替换全连接层。GPU 迁移通过 to() 函数实现。
path_model =
path_state_dict =
torch.save(net, path_model)
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
保存模型可以有两种方式:1.保存整个模型;2.保存模型参数;官网推荐第二种方式,保存模型参数。
模型加载也分两种方式:
1.如果是保存整个模型,则直接加载整个模型
# ================================== load net =========================== #
flag = 1
# flag = 0
if flag:
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
2.如果是保存模型参数,则加载模型参数,并更新网络模型的模型参数
# ================================== load state_dict ===========================
flag = 1
# flag = 0
if flag:
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load.keys())
# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:
net_new = LeNet2(classes=2019)
print("加载前:", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后:", net_new.features[0].weight[0, ...])
为什么要使用断点续训练?
因为在训练模型过程中可能会因为断电等意外中断,训练停止后,我们不想重新训练模型。可以通过保存断点的方式,重新接着上次停止的地方继续训练模型。
保存断点:
checkpoint = {"model_state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch}
path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)
断点恢复
# ============================ step 2/5 模型 ============================
net = LeNet(classes=2)
net.initialize_weights()
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()
# 选择损失函数
# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)
# 设置学习率下降策略
# ============================ step 5+/5 断点恢复 ============================
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
模型微调包括 3 个步骤:1.构建模型;2.加载参数;3.修改全连接 fc 层
# 3/3 替换 fc 层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
模型微调方法包括两种:
1.固定预训练的参数(requires_grad=False, lr=0)
2.Features Extractor 较小学习率(params_group)
# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()
# 2/3 加载参数
# flag = 0
flag = 1
if flag:
path_pretrained_model = "pretrained_model.pth" # 通用路径
state_dict_load = torch.load(path_pretrained_model)
resnet18_ft.load_state_dict(state_dict_load)
# 法 1 : 冻结卷积层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
for param in resnet18_ft.parameters():
param.requires_grad = False
print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))
# 3/3 替换 fc 层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
resnet18_ft.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()
# 选择损失函数
# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)
# 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
# 设置学习率下降策略
# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()
# 2/3 加载参数
# flag = 0
flag = 1
if flag:
path_pretrained_model = "pretrained_model.pth" # 通用路径
state_dict_load = torch.load(path_pretrained_model)
resnet18_ft.load_state_dict(state_dict_load)
# 3/3 替换 fc 层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
resnet18_ft.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()
# 选择损失函数
# ============================ step 4/5 优化器 ============================
# 法 2 : conv 小学习率
fc_params_id = list(map(id, resnet18_ft.fc.parameters()))
base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
optimizer = optim.SGD([{'params': base_params, 'lr': LR*0}, {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
# 设置学习率下降策略
使用第二种方法更加灵活。
pytorch 有两种数据类型:tensor 和 module。data=tensor+module.
从 CPU 到 GPU:data.to("cuda")
从 GPU 到 CPU:data.to("cpu")
to() 函数,数据类型转换或者设备转换。执行 to 函数是张量不执行 replace,module 执行 replace。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online