当前位置: 首页 > article >正文

【深度学习】(7)--神经网络之保存最优模型

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset) # 总数据大小
    num_batches = len(dataloader) # 划分的小批次数量
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

    # 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)
        print(model.state_dict().keys()) # 输出模型参数名称cnn
        torch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 2. 保存完整模型(w,b,模型cnn)
        torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)

epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):
    print(f"Epoch {t+1} \n-------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

http://www.kler.cn/a/320533.html

相关文章:

  • 【Nginx】反向代理Https时相关参数:
  • 深度解读混合专家模型(MoE):算法、演变与原理
  • 机器学习(1)
  • 【GPTs】Gif-PT:DALL·E制作创意动图与精灵动画
  • 基于 CentOS7.6 的 Docker 下载常用的容器(MySQLRedisMongoDB),解决拉取容器镜像失败问题
  • STM32 Option Bytes(选项字节)
  • 自动驾驶,被逼着上市?
  • 【Python机器学习】NLP信息提取——提取人物/事物关系
  • WPS文字 分栏注意项
  • Java项目实战II基于Java+Spring Boot+MySQL的汽车销售网站(文档+源码+数据库)
  • PyTorch开源的深度学习框架
  • 2、electron vue3 怎么创建子窗口,并给子窗口路由传参
  • 【Linux系统编程】第二十二弹---操作系统核心概念:进程创建与终止机制详解
  • LInux操作系统安装Jenkins
  • MFC-基础架构
  • 实验二十:ds1302时钟实验
  • 【MYSQL】聚合查询、分组查询、联合查询
  • CSS开发全攻略
  • 后端开发面试题7(附答案)
  • 概率论与数理统计复习笔记
  • 本地电脑基于nginx的https单向认证和双向认证(自制证书+nginx配置)保姆级
  • 一天认识一个硬件之鼠标
  • web前端(本地存储问题超过5MB不继续保存解决办法)
  • Leetcode 378. 有序矩阵中第 K 小的元素
  • TypeScript 设计模式之【建造者模式】
  • 基于python+spark的外卖餐饮数据分析系统设计与实现(含论文)-Spark毕业设计选题推荐