当前位置: 首页 > article >正文

【深度学习】Pytorch:导入导出模型参数

PyTorch 是深度学习领域中广泛使用的框架,熟练掌握其模型参数的管理对于模型训练、推理以及部署非常重要。本文将全面讲解 PyTorch 中关于模型参数的操作,包括如何导出、导入以及如何下载模型参数。

什么是模型参数

模型参数是指深度学习模型中需要通过训练来优化的变量,如神经网络中的权重和偏置。这些参数存储在 PyTorch 的 torch.nn.Module 对象中,通过以下方式访问:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
print(model.state_dict())

state_dict 是 PyTorch 用于保存模型参数的标准 Python 字典。字典中的键是参数的名称,值是对应的张量。

导出模型参数

导出模型参数有助于保存训练好的模型,以便以后重新加载或部署。

保存为文件

PyTorch 推荐使用 .pt.pth 扩展名保存模型参数。以下是保存模型参数的步骤:

# 保存模型参数
torch.save(model.state_dict(), "model.pth")

上述代码会将 state_dict 保存到名为 model.pth 的文件中。

自定义保存

你还可以将模型参数与其他信息一起保存,例如优化器的状态:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
save_path = "checkpoint.pth"

# 保存模型参数和优化器状态
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 10,
}, save_path)

导入模型参数

导入模型参数时需要确保模型结构与保存时一致,否则可能会发生错误。

加载参数到模型

# 加载模型参数
model = SimpleModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()  # 切换到推理模式

load_state_dict 方法会将保存的参数加载到当前模型中。加载后,使用 eval() 方法将模型切换到推理模式以禁用 dropout 和 batch normalization 的训练行为。

加载带优化器的检查点

如果保存时包括了优化器状态,加载时可以这样操作:

# 加载模型参数和优化器状态
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

这对于继续训练时尤其有用。

下载预训练模型

PyTorch 提供了一些常见模型的预训练参数,可以通过 torchvision 或其他库直接下载。

使用 torchvision

torchvision.models 提供了许多计算机视觉模型:

from torchvision import models

# 下载 ResNet-18 的预训练模型
model = models.resnet18(pretrained=True)

pretrained=True 会自动下载并加载预训练参数。

从官方链接下载

官方模型仓库:https://pytorch.org/vision/stable/models.html 提供了每个模型的下载链接。

可以通过 URL 下载并加载:

import torch

url = "https://download.pytorch.org/models/resnet18-f37072fd.pth"
state_dict = torch.hub.load_state_dict_from_url(url)

# 加载到模型中
model = models.resnet18()
model.load_state_dict(state_dict)

自定义模型下载

如果使用非官方模型,可以使用 torch.hub 来加载:

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

检查和管理参数

查看参数名和形状

for name, param in model.named_parameters():
    print(f"Name: {name}, Shape: {param.shape}")

冻结部分参数

冻结参数可以减少计算量:

for param in model.parameters():
    param.requires_grad = False

冻结部分层:

for name, param in model.named_parameters():
    if "layer4" not in name:  # 仅训练 layer4
        param.requires_grad = False

总结

本文详细介绍了 PyTorch 中模型参数的导出、导入以及下载的各种方法,同时提供了参数管理的技巧。在实际开发中,熟练运用这些方法可以帮助你高效地保存、加载和迁移模型,从而提升工作效率。


http://www.kler.cn/a/507400.html

相关文章:

  • [Mac + Icarus Verilog + gtkwave] Mac运行Verilog及查看波形图
  • 静态综合路由实验
  • 工作中redis常用的5种场景
  • 从漏洞管理到暴露管理:网络安全的新方向
  • 传统摄像头普通形态的系统连接方式
  • 深度学习电影推荐-CNN算法
  • python mysql库的三个库mysqlclient mysql-connector-python pymysql如何选择,他们之间的区别
  • 【Linux】打破Linux神秘的面纱
  • 西门子【Library of Basic Controls (LBC)基本控制库”(LBC) 提供基本控制功能】
  • 神经网络基础-正则化方法
  • 机器学习-常用的三种梯度下降法
  • CSS 样式 margin:0 auto; 详细解读
  • Jackson 中的多态类型支持:@JsonTypeInfo 和 @JsonSubTypes 使用技巧
  • 蓝桥杯刷题第四天——字符排序
  • 基于智能物联网的肉鸡舍控制器:设计、实施、性能评估与优化
  • 个人vue3-学习笔记
  • 服务器数据恢复—EMC存储POOL中数据卷被删除的数据恢复案例
  • Qt类的提升(Python)
  • 大模型赋能医疗项目,深兰科技与武汉协和医院达成合作
  • deepin-如何在 ArchLinux 发行版上安装 DDE 桌面环境
  • 老centos7 升级docker.io为docker-ce 脚本
  • 【GIS操作】使用ArcGIS Pro进行海图的地理配准(附:墨卡托投影对比解析)
  • 七大排序算法
  • 网络协议基础--IP协议
  • 【Linux】gawk编辑器二
  • nginx 修改内置 404 页面、点击劫持攻击。