当前位置: 首页 > article >正文

PyTorch 保存和加载模型状态和优化器状态

以下示例代码展示了如何在 PyTorch 中保存和加载模型状态和优化器状态,以便训练中断后可以继续训练。

1. 保存模型和优化器状态

假设模型训练了一段时间后,我们想要保存模型和优化器的状态,确保下次可以从同一位置继续训练。

2. 加载模型和优化器状态

加载保存的状态后,可以从保存的 epoch 继续训练。

示例代码

import torch
import torch.nn as nn
import torch.optim as optim

# 假设我们定义了一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 创建模型和优化器
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 模拟的训练代码片段
num_epochs = 20
checkpoint_path = "model_checkpoint.pth"

# 保存模型和优化器状态
def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)
    print(f"Checkpoint saved at epoch {epoch}.")

# 加载模型和优化器状态
def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Checkpoint loaded, starting at epoch {start_epoch}.")
    return start_epoch

# 尝试加载已保存的检查点
try:
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)
except FileNotFoundError:
    start_epoch = 0
    print("No checkpoint found, starting training from scratch.")

# 继续训练
for epoch in range(start_epoch, num_epochs):
    # 模拟训练步骤
    # output = model(input) ...
    # loss = loss_fn(output, target) ...
    # optimizer.zero_grad()
    # loss.backward()
    # optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs} completed.")

    # 每 5 个 epoch 保存一次模型状态
    if (epoch + 1) % 5 == 0:
        save_checkpoint(epoch + 1, model, optimizer, checkpoint_path)

解释

  1. 保存save_checkpoint 函数会在指定的 epoch 保存模型和优化器状态。
  2. 加载load_checkpoint 函数会加载模型和优化器状态,并返回上次的 epoch,以便继续训练。
  3. 训练控制start_epoch 变量控制了是否继续从之前的检查点继续训练,确保模型在中断后可以接着训练。


http://www.kler.cn/news/366068.html

相关文章:

  • Flutter鸿蒙next 状态管理高级使用:深入探讨 Provider
  • fpga系列 HDL: 竞争和冒险 01
  • 【HarmonyOS Next】原生沉浸式界面
  • 学习webservice的心得
  • 分布式光伏发电系统电气一次部分设计(开题报告3)
  • 中酱集团:黑松露酱油,天然配方定义健康生活
  • win10系统家庭版.net framework 3.5sp1启动错误如何解决
  • idea 集成maven
  • Maven(解决思路)
  • TCP标志位在网络故障排查中的作用
  • C语言与C++语言对比:为何C语言不支持函数重载而C++支持?
  • 【图论】Kruskal重构树
  • 《探索 HarmonyOS NEXT(5.0):开启构建模块化项目架构奇幻之旅 —— 模块化基础篇》
  • golang中的函数和结构体
  • Android H5页面性能分析与优化策略
  • 头歌——人工智能(机器学习 --- 决策树2)
  • SpringSecurity 简单使用,实现登录认证,通过过滤器实现自定义异常处理
  • 从汇编角度看C/C++函数指针与函数的调用差异
  • 甘特图代做服务
  • 银河麒麟相关
  • 若依 spring boot +vue3 前后端分离
  • Java语言的充电桩系统-云快充协议1.5-1.6
  • 互联网系统的微观与宏观架构
  • 重构案例:将纯HTML/JS项目迁移到Webpack
  • Nginx、Tomcat等项目部署问题及解决方案详解
  • 从零开始:构建一个高效的开源管理系统——使用 React 和 Ruoyi-Vue-Plus 的实战指南