【深度学习】Pytorch:导入导出模型参数
PyTorch 是深度学习领域中广泛使用的框架,熟练掌握其模型参数的管理对于模型训练、推理以及部署非常重要。本文将全面讲解 PyTorch 中关于模型参数的操作,包括如何导出、导入以及如何下载模型参数。
什么是模型参数
模型参数是指深度学习模型中需要通过训练来优化的变量,如神经网络中的权重和偏置。这些参数存储在 PyTorch 的 torch.nn.Module
对象中,通过以下方式访问:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
model = SimpleModel()
print(model.state_dict())
state_dict
是 PyTorch 用于保存模型参数的标准 Python 字典。字典中的键是参数的名称,值是对应的张量。
导出模型参数
导出模型参数有助于保存训练好的模型,以便以后重新加载或部署。
保存为文件
PyTorch 推荐使用 .pt
或 .pth
扩展名保存模型参数。以下是保存模型参数的步骤:
# 保存模型参数
torch.save(model.state_dict(), "model.pth")
上述代码会将 state_dict
保存到名为 model.pth
的文件中。
自定义保存
你还可以将模型参数与其他信息一起保存,例如优化器的状态:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
save_path = "checkpoint.pth"
# 保存模型参数和优化器状态
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': 10,
}, save_path)
导入模型参数
导入模型参数时需要确保模型结构与保存时一致,否则可能会发生错误。
加载参数到模型
# 加载模型参数
model = SimpleModel()
model.load_state_dict(torch.load("model.pth"))
model.eval() # 切换到推理模式
load_state_dict
方法会将保存的参数加载到当前模型中。加载后,使用 eval()
方法将模型切换到推理模式以禁用 dropout 和 batch normalization 的训练行为。
加载带优化器的检查点
如果保存时包括了优化器状态,加载时可以这样操作:
# 加载模型参数和优化器状态
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
这对于继续训练时尤其有用。
下载预训练模型
PyTorch 提供了一些常见模型的预训练参数,可以通过 torchvision
或其他库直接下载。
使用 torchvision
torchvision.models
提供了许多计算机视觉模型:
from torchvision import models
# 下载 ResNet-18 的预训练模型
model = models.resnet18(pretrained=True)
pretrained=True
会自动下载并加载预训练参数。
从官方链接下载
官方模型仓库:https://pytorch.org/vision/stable/models.html 提供了每个模型的下载链接。
可以通过 URL 下载并加载:
import torch
url = "https://download.pytorch.org/models/resnet18-f37072fd.pth"
state_dict = torch.hub.load_state_dict_from_url(url)
# 加载到模型中
model = models.resnet18()
model.load_state_dict(state_dict)
自定义模型下载
如果使用非官方模型,可以使用 torch.hub
来加载:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
检查和管理参数
查看参数名和形状
for name, param in model.named_parameters():
print(f"Name: {name}, Shape: {param.shape}")
冻结部分参数
冻结参数可以减少计算量:
for param in model.parameters():
param.requires_grad = False
冻结部分层:
for name, param in model.named_parameters():
if "layer4" not in name: # 仅训练 layer4
param.requires_grad = False
总结
本文详细介绍了 PyTorch 中模型参数的导出、导入以及下载的各种方法,同时提供了参数管理的技巧。在实际开发中,熟练运用这些方法可以帮助你高效地保存、加载和迁移模型,从而提升工作效率。