当前位置: 首页 > article >正文

动手学深度学习(pytorch)学习记录21-读写文件(模型与参数)[学习记录]

目录

  • 加载和保存张量
  • 加载和保存模型参数

保存模型的好处众多,涵盖了从开发到部署的整个机器学习生命周期。

  • 节省资源:训练模型可能需要大量的时间和计算资源。保存模型可以避免重复训练,从而节省时间和计算资源。
  • 快速部署:一旦模型被训练并保存,它可以迅速部署到生产环境中,加速产品上市时间。
  • 版本控制:保存不同版本的模型有助于跟踪模型的迭代过程,便于比较和回滚到之前的版本。
  • 离线使用:保存的模型可以在没有网络连接的情况下使用,这对于需要在本地设备上运行模型的应用程序非常有用。
  • 模型共享:研究人员和开发者可以共享他们的模型,促进合作和知识传播。
  • 模型评估:保存的模型可以在不同的数据集上进行评估,帮助验证模型的泛化能力和性能。
  • 实验复现:保存模型的状态使得其他研究者可以复现实验结果,增加研究的可验证性。
  • 业务连续性:在系统升级或迁移过程中,保存的模型可以确保业务的连续性,减少停机时间。
  • 法律合规:在某些行业,如医疗和金融,保存模型可能是必须的,以满足法律和合规要求。
  • 模型优化:保存的模型可以用于进一步的优化,如模型压缩、加速等,以适应不同的部署环境。
  • 模型监控:在模型部署后,保存的模型可以用于监控和比较,以检测模型性能随时间的变化。
  • 用户信任:提供透明的模型保存信息可以增加用户对模型决策的信任。
  • 教育和研究:保存的模型可以作为教育材料,帮助学生和研究人员学习模型的工作原理。
  • 灾难恢复:在发生系统故障时,保存的模型可以作为备份,快速恢复服务。
  • 长期维护:随着时间的推移,保存的模型可以用于维护和更新,以适应新的数据和需求。

加载和保存张量

# 保存张量
import torch
from torch import nn
from torch.nn import functional as F

x = torch.arange(4)
torch.save(x, 'x-file')

将存储在文件中的数据读回内存。

x2 = torch.load('x-file')
x2
tensor([0, 1, 2, 3])

存储一个张量列表,然后把它们读回内存。

y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

可以写入或读取从字符串映射到张量的字典。 当我们要读取或写入模型中的所有权重时,这很方便。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

加载和保存模型参数

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

将模型的参数存储在一个叫做“mlp.params”的文件中

torch.save(net.state_dict(), 'mlp.params')

为恢复模型,需实例化原始多层感知机模型的一个备份, 直接读取文件中存储的参数作为初始参数。

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

由于两个实例具有相同的模型参数,在输入相同的X时, 两个实例的计算结果应该相同。

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

保存整个模型

torch.save(net, 'net.pt')
net1 = torch.load('net.pt')
net1.eval()
MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

原模型和新加载的模型参数应该是相同的。

net.state_dict()['hidden.weight'].data == net1.state_dict()['hidden.weight'].data
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

封面图片来源

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。


http://www.kler.cn/a/287180.html

相关文章:

  • 【YOLOv8】安卓端部署-2-项目实战
  • Linux的目录结构
  • 领海基点的重要性-以黄岩岛(民主礁)的领海及专属经济区时空构建为例
  • Matplotlib | 理解直方图中bins表示的数据含义
  • 向量数据库FAISS之五:原理(LSH、PQ、HNSW、IVF)
  • Nginx Spring boot指定域名跨域设置
  • Oracle rac模式下undo表空间爆满的解决
  • 部署project_exam_system项目——及容器的编排
  • stm32开发之rt-thread使SysTick处于微妙级运行时,出现的问题记录
  • GraphPad Prism下载安装教程怎样中文汉化
  • 第3章-03-Python库Requests安装与讲解
  • 机器学习数学公式推导之线性回归
  • 系统监控和命令行环境
  • python中**字典的含义
  • MATLAB下的粒子滤波例程|三维非线性模型|组合导航|PF代码(无需下载,直接复制到MATLAB上即可运行)
  • http的三次握手和四次挥手
  • 制造企业SRM系统中如何进行供应商的管理
  • 质量小议43 - 提效
  • 如何通过选择合适的编程工具来提升编程效率
  • 零基础5分钟上手亚马逊云科技-高可用负载均衡器
  • 浅谈SpringMvc的核心流程与组件
  • 零基础学习Redis(7) -- hash类型命令使用
  • 【区块链 + 司法存证】数据存证区块链服务开放平台 | FISCO BCOS应用案例
  • Qt详解QHostInfo
  • MindSearch CPU-only 版部署
  • 华为云征文|部署内容管理系统 Joomla