[Pytorch] 保存模型与加载模型
1、保存模型
# 定义模型
model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output) #调用网络
# 保存模型
torch.save(model, 'BPNetModel0.pth')
2、加载模型
import torch
## 读取模型
model = torch.load('BPNetModel0.pth')
3、保存模型参数
#调用网络
model = BPNetModel(n_feature=n_feature,n_hidden=n_hidden,n_output=n_output)
# 保存模型
torch.save({'model': model.state_dict()}, 'BPNetModel0.pth')
4、加载参数
# 读取模型
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])