当前位置: 首页 > article >正文

torch.save的用法

文章目录

  • 介绍
  • 基本用法
  • 常见用法
    • 保存张量
    • 保存模型的参数
    • 保存整个模型(不推荐)
  • 注意事项
  • 将网络模型保存到文件中,这将保存网络的结构和参数
  • 将网络的状态字典保存到文件中,状态字典包含了网络的参数

介绍

torch.save 是 PyTorch 中用于保存对象(如模型、张量、字典等)的函数。它可以将数据序列化并保存到文件中,方便后续加载和使用。

基本用法

torch.save(obj, f)

参数说明:

  • obj:要保存的对象,可以是模型、张量、字典等。
  • f:保存的目标文件路径,可以是:
    • 文件路径字符串(如 ‘model.pth’)。
    • 一个文件对象(如 open(‘model.pth’, ‘wb’))。
    • 一个 torch.ByteIO 对象(用于保存到内存中)。

常见用法

保存张量

import torch  

# 创建一个张量  
tensor = torch.tensor([1, 2, 3, 4])  

# 保存张量到文件  
torch.save(tensor, 'tensor.pth')  

# 加载张量  
loaded_tensor = torch.load('tensor.pth')  
print(loaded_tensor)  # 输出:tensor([1, 2, 3, 4])

保存模型的参数

保存模型的参数(state_dict)是 PyTorch 推荐的保存模型的方式,因为它只保存模型的权重和偏置,而不保存整个模型结构。

import torch  
import torch.nn as nn  

# 定义一个简单的模型  
model = nn.Linear(10, 1)  

# 保存模型的参数  
torch.save(model.state_dict(), 'model.pth')  

# 加载模型的参数  
model2 = nn.Linear(10, 1)  # 需要重新定义模型结构  
model2.load_state_dict(torch.load('model.pth'))  
print(model2.state_dict())  # 输出模型的参数

保存整个模型(不推荐)

可以直接保存整个模型(包括模型结构和参数),但这种方式依赖于保存时的代码环境,可能在不同版本的 PyTorch 或不同的代码结构中无法加载。

# 保存整个模型  
torch.save(model, 'entire_model.pth')  

# 加载整个模型  
loaded_model = torch.load('entire_model.pth')  
print(loaded_model)

注意事项

  • 推荐保存 state_dict
    • 保存 state_dict(模型参数)比保存整个模型更灵活,因为它不依赖于保存时的代码环境。
    • 加载时需要重新定义模型结构,然后加载参数。
  • 文件扩展名
    • 通常使用 .pth 或 .pt 作为保存文件的扩展名,但这只是约定俗成,PyTorch 并不强制要求。
  • GPU 和 CPU 的兼容性
    • 如果保存的模型是在 GPU 上,但加载时在 CPU 上,需要显式指定 map_location 参数。

将网络模型保存到文件中,这将保存网络的结构和参数

torch.save(net, ‘./data/net.pkl’)

将网络的状态字典保存到文件中,状态字典包含了网络的参数

torch.save(net.state_dict(), ‘./data/net_params.pkl’)


http://www.kler.cn/a/458920.html

相关文章:

  • HTML——28.音频的引入
  • Chapter4.2:Normalizing activations with layer normalization
  • 普及组集训数据结构--并查集
  • UE5.3 虚幻引擎 Windows插件开发打包(带源码插件打包、无源码插件打包)
  • SQLALchemy如何将SQL语句编译为特定数据库方言
  • 服务器等保测评日志策略配置
  • 机器学习中的常用特征选择方法及其应用案例
  • 【Qt】多元素控件:QListWidget、QTableWidget、QTreeWidget
  • I2C(一):存储器模式:stm32作为主机对AT24C02写读数据
  • 2024年12月28日人工智能与科技新闻速递
  • 使用 MediaDevices API 录制和下载视频教程
  • 基于Spring Boot + Vue3实现的在线预约看房管理系统源码+文档
  • 软硬件开发相关标准汇总
  • 联邦协作训练大模型的一些研究进展
  • 【LC】3159. 查询数组中元素的出现位置
  • mac docker部署jar包流程
  • 循环服务器
  • [Bert] 提取特征之后训练模型报梯度图错误
  • Effective C++ 条款42:了解 typename 的双重意义
  • 玉米中的元基因调控网络突出了功能上相关的调控相互作用。\ca.19a5.R
  • vue项目利用webpack进行优化案例
  • 小米路由器开启SSH,配置阿里云ddns,开启外网访问SSH和WEB管理界面
  • SAP-MM-物资库存调度调剂清单
  • 深入探讨C++中的互斥锁管理:`std::lock_guard`与`std::unique_lock`
  • C++ 设计模式:模板方法(Template Method)
  • Zookeeper中version-2目录下存放数据