1. 模型加载与保存
1.1 模型的加载与保存
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整个模型
torch.save(net, path_model)
# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
保存模型可以有两种方式:1.保存整个模型;2.保存模型参数;官网推荐第二种方式,保存模型参数。
模型加载也分两种方式:
1.如果是保存整个模型,则直接加载整个模型
# ================================== load net =========================== #
flag = 1
# flag = 0
if flag:
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
2.如果是保存模型参数,则加载模型参数,并更新网络模型的模型参数
# ================================== load state_dict ===========================
flag = 1
# flag = 0
if flag:
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load.keys())
# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:
net_new = LeNet2(classes=2019)
print("加载前:", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后:", net_new.features[0].weight[0, ...])
1.2 断点续训练
为什么要使用断点续训练?
因为在训练模型过程中可能会因为断电等意外中断,训练停止后,我们不想重新训练模型。可以通过保存断点的方式,重新接着上次停止的地方继续训练模型。
保存断点:

