当前位置: 首页 > article >正文

深入理解 PyTorch .pth 文件和 Python pickle 模块:功能、应用及实际示例

深入理解 PyTorch .pth 文件和 Python pickle 模块:功能、应用及实际示例

在深入理解Python的pickle模块和PyTorch的.pth文件,以及pickle.pth文件中的应用前,我们首先需要掌握序列化和反序列化的基本概念。

序列化和反序列化

序列化是指将程序中的对象转换为一个字节序列的过程,这样就可以将其存储到磁盘上或通过网络传输到其他位置。反序列化是序列化的逆过程,即将字节序列恢复为原始对象。这两个过程是数据持久化和远程计算通信的基础。

Python的pickle模块

pickle是Python的标准库之一,提供了一个简单的方法用于序列化和反序列化Python对象结构。任何Python对象都可以通过pickle进行序列化,只要它们是pickle支持的类型。

核心功能

  • pickle.dump(obj, file):将对象obj序列化并写入到文件对象file中。
  • pickle.load(file):从文件对象file中读取序列化的对象并反序列化。
  • pickle.dumps(obj):将对象obj序列化为一个字节对象,不写入文件。
  • pickle.loads(bytes_object):将字节对象bytes_object反序列化为一个Python对象。

pickle的序列化过程不仅包括对象当前的状态(例如,数字,字符串,或复杂对象的集合),也包括对象的类型信息和结构。

PyTorch的.pth文件

在PyTorch中,.pth文件扩展通常用于保存模型的权重或整个模型。这些文件通过使用torch.save()函数创建,它内部使用pickle来序列化对象。.pth文件通常包含模型的状态字典(state_dict),这是一个从每个层的参数名称映射到其张量值的字典。

核心用途

  • 模型持久化:保存训练后的模型状态,以便将来可以重新加载和使用模型,不需要重新训练。
  • 模型迁移:将训练好的模型参数迁移到新的模型结构或平台上。

pickle.pth文件中的应用

当使用torch.save()来保存一个PyTorch模型或张量时,pickle用于将对象和它的层次结构转换为一个字节流,然后这个字节流被写入到一个.pth文件中。在加载模型时,torch.load()使用pickle来反序列化这个字节流,重建模型或张量。

示例

import torch
import torchvision.models as models

# 实例化一个预训练的ResNet模型
model = models.resnet18(pretrained=True)

# 保存模型状态字典
torch.save(model.state_dict(), 'model_weights.pth')

# 加载模型状态字典
loaded_state_dict = torch.load('model_weights.pth')
new_model = models.resnet18(pretrained=False)
new_model.load_state_dict(loaded_state_dict)

# 打印以验证加载
print(new_model)

在这个示例中,torch.save()内部使用pickle来序列化model.state_dict(),并将其保存为model_weights.pth。然后,我们使用torch.load()来加载这个.pth文件,其中pickle负责反序列化文件内容,并恢复为Python对象(在这种情况下是模型的状态字典)。最后,状态字典被用于初始化一个新的模型实例。

通过这种方式,pickle在PyTorch的模型保存和加载过程中扮演了核心角色,使得模型的状态可以在不同的计算环境中被迁移和复用。


http://www.kler.cn/a/418872.html

相关文章:

  • 第144场双周赛:移除石头游戏、两个字符串得切换距离、零数组变换 Ⅲ、最多可收集的水果数目
  • Android习题第7章广播
  • HarmonyOS(61) 组件间状态共享的分类以及状态选择器的选取优先级
  • Next.js-样式处理
  • 记一次 Vue3 中 ref 初始化未完成导致方法未触发的解决方案
  • qt QGraphicsPolygonItem详解
  • 前端学习week8——vue.js
  • 支持向量机算法:原理、实现与应用
  • LeetCode题解:34.在排序数组中查找元素的第一个和最后一个位置【Python题解超详细,二分查找法、index法】,知识拓展:index方法详解
  • [MySQL]流程控制语句
  • SpringCloud书单推荐
  • 深度学习常见数据集处理方法
  • 爬虫专栏第一篇:深入探索爬虫世界:基础原理、类型特点与规范要点全解析
  • npm : 无法加载文件 D:\nodejs\npm.ps1,因为在此系统上禁止运行脚本
  • 云技术基础(泷羽sec)
  • ubuntu配置网络
  • 【论文投稿】国产游戏技术:迈向全球引领者的征途
  • 缓存算法FIFO的说说
  • 单片机蓝牙手机 APP
  • Matlab 绘制雷达图像完全案例和官方教程(亲测)
  • 云计算的发展历史与未来展望
  • 架构 | 基于 crontab 进程监控增强集群可用性
  • 十、Spring Boot集成Spring Security之HTTP请求授权
  • RabbitMQ 消息确认机制
  • OCR实现微信截图改名
  • 新版 Navicat Premium 17 安装教程 (亲测可用)