torch.save的用法
文章目录
- 介绍
- 基本用法
- 常见用法
- 保存张量
- 保存模型的参数
- 保存整个模型(不推荐)
- 注意事项
- 将网络模型保存到文件中,这将保存网络的结构和参数
- 将网络的状态字典保存到文件中,状态字典包含了网络的参数
介绍
torch.save 是 PyTorch 中用于保存对象(如模型、张量、字典等)的函数。它可以将数据序列化并保存到文件中,方便后续加载和使用。
基本用法
torch.save(obj, f)
参数说明:
- obj:要保存的对象,可以是模型、张量、字典等。
- f:保存的目标文件路径,可以是:
- 文件路径字符串(如 ‘model.pth’)。
- 一个文件对象(如 open(‘model.pth’, ‘wb’))。
- 一个 torch.ByteIO 对象(用于保存到内存中)。
常见用法
保存张量
import torch
# 创建一个张量
tensor = torch.tensor([1, 2, 3, 4])
# 保存张量到文件
torch.save(tensor, 'tensor.pth')
# 加载张量
loaded_tensor = torch.load('tensor.pth')
print(loaded_tensor) # 输出:tensor([1, 2, 3, 4])
保存模型的参数
保存模型的参数(state_dict)是 PyTorch 推荐的保存模型的方式,因为它只保存模型的权重和偏置,而不保存整个模型结构。
import torch
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 1)
# 保存模型的参数
torch.save(model.state_dict(), 'model.pth')
# 加载模型的参数
model2 = nn.Linear(10, 1) # 需要重新定义模型结构
model2.load_state_dict(torch.load('model.pth'))
print(model2.state_dict()) # 输出模型的参数
保存整个模型(不推荐)
可以直接保存整个模型(包括模型结构和参数),但这种方式依赖于保存时的代码环境,可能在不同版本的 PyTorch 或不同的代码结构中无法加载。
# 保存整个模型
torch.save(model, 'entire_model.pth')
# 加载整个模型
loaded_model = torch.load('entire_model.pth')
print(loaded_model)
注意事项
- 推荐保存 state_dict
- 保存 state_dict(模型参数)比保存整个模型更灵活,因为它不依赖于保存时的代码环境。
- 加载时需要重新定义模型结构,然后加载参数。
- 文件扩展名
- 通常使用 .pth 或 .pt 作为保存文件的扩展名,但这只是约定俗成,PyTorch 并不强制要求。
- GPU 和 CPU 的兼容性
- 如果保存的模型是在 GPU 上,但加载时在 CPU 上,需要显式指定 map_location 参数。
将网络模型保存到文件中,这将保存网络的结构和参数
torch.save(net, ‘./data/net.pkl’)
将网络的状态字典保存到文件中,状态字典包含了网络的参数
torch.save(net.state_dict(), ‘./data/net_params.pkl’)