当前位置: 首页 > article >正文

深度学习在训练时更新和保存最佳训练结果的方法(字典方法,本地保存方法,模型深拷贝方法)

1.用参数字典 model.state_dict()更新最优参数

best_state_dict = model.state_dict()  # 训练前
best_state_dict = model.state_dict()  # 训练时更新最优state_dict

完整代码:

 # 初始化一个变量来保存最优的state_dict
  best_state_dict = model.state_dict()
  for epoch in range(epochs):
      model.train()
      # 训练集上训练模型权重
      for data, targets in tqdm.tqdm(train_dataloader):
          # 把数据加载到GPU上
          data = data.to(devices[0])
          targets = targets.to(devices[0])

          # 前向传播
          preds = model(data)
          loss = criterion(preds, targets)

          # 反向传播
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

      # 测试集上评估模型性能
      model.eval()
      num_correct = 0
      num_samples = 0
      with torch.no_grad():
          for x, y in tqdm.tqdm(test_dataloader):
              x = x.to(devices[0])
              y = y.to(devices[0])
              preds = model(x)
              predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引
              num_correct += (predictions == y).sum()
              num_samples += predictions.size(0)
          acc = (num_correct / num_samples).item()
          if acc > best_acc:
              best_acc = acc
              best_epoch = epoch+1
              # 保存模型最优准确率的参数
              best_state_dict = model.state_dict()  # 更新最优state_dict
      model.train()
  # 训练结束保存
  torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")

2.训练过程中保存最优参数

if acc > best_acc:
    best_acc = acc
    best_epoch = epoch+1
    torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")

3.对模型深拷贝方法保存最优模型

深拷贝方法介绍

copy模块可以用来创建一个对象的深拷贝。这意味着复制后的模型和原始模型是完全独立的,包括它们的参数。

import torch  
import copy  
import torch.nn as nn  
  
# 假设我们有一个模型实例  
original_model = nn.Sequential(  
    nn.Linear(10, 5),  
    nn.ReLU(),  
    nn.Linear(5, 2)  
)  
  
# 复制模型  
model_copy = copy.deepcopy(original_model)

深拷贝方法保存最优模型

best_model = copy.deepcopy(model.state_dict())  # 训练前
best_model = copy.deepcopy(model.state_dict())  # 训练时更新最优state_dict

代码案例:

   def fit_zsl(self):
        best_acc = 0
        mean_loss = 0
        last_loss_epoch = 1e8
        # 定义best_model
        best_model = copy.deepcopy(self.model.state_dict())
        for epoch in range(self.nepoch):
            for i in range(0, self.ntrain, self.batch_size):
                self.model.zero_grad()
                batch_input, batch_label = self.next_batch(self.batch_size)
                self.input.copy_(batch_input)
                self.label.copy_(batch_label)

                inputv = Variable(self.input)
                labelv = Variable(self.label)
                output = self.model(inputv)
                loss = self.criterion(output, labelv)
                mean_loss += loss.item()
                loss.backward()
                self.optimizer.step()
            acc = self.val(
                self.test_unseen_feature,
                self.test_unseen_label,
                self.unseenclasses,
            )
            if acc > best_acc:
                best_acc = acc
                # 更新best_model
                best_model = copy.deepcopy(self.model.state_dict())
        #训练完毕本地保存
		torch.save(best_model.state_dict(), f"weights/{self.nepoch}_{best_acc}.pth")
        return best_acc, best_model

http://www.kler.cn/news/157118.html

相关文章:

  • selenium中元素定位正确但是操作失败,6种解决办法全搞定
  • 六、ZooKeeper Java API操作
  • 【数据结构】——栈|队列(基本功能)
  • Python字符串模糊匹配工具:TheFuzz 库详解
  • 关于使用百度开发者平台处理语音朗读问题排查
  • Spring-Security取消验证-Get请求接口正常,Post请求报错403
  • java后端技术演变杂谈(未完结)
  • c语言笔记之小项目家庭收支记账软件
  • java synchronized详解
  • ruby安装(vscode、rubymine)
  • 「Qt Widget中文示例指南」如何创建一个计算器?(二)
  • 深度学习(五):pytorch迁移学习之resnet50
  • MySQL安装,建立,导入本地Txt文件
  • 寻找两个有序数组的中位数算法(leetcode第4题)
  • Android 7.1 点击清空全部按钮清空一切运行进程(包括后台在播音乐)
  • 【Linux】进程控制--进程创建/进程终止/进程等待/进程程序替换/简易shell实现
  • CPP-SCNUOJ-Problem P29. [算法课指针] 颜色分类,小白偏题超简单方法
  • 前端---JavaScript篇
  • 【LeeCode】链表总结
  • 大数据之Redis
  • Python按要求从多个txt文本中提取指定数据
  • 卷积神经网络(CNN):艺术作品识别
  • 【算法每日一练]-图论(保姆级教程 篇6(图上dp))#最大食物链 #游走
  • redis的缓存击穿,缓存穿透,缓存雪崩
  • 2023年抗量子加密的十件大事
  • java后端redis缓存缓存预热
  • Ubuntu开机出现Welcome to emergency mode解决办法
  • 【qml入门系列教程】:qml QtObject用法介绍
  • c++ day5
  • Windows下打包C++程序无法执行:无法定位程序输入点于动态链接库