当前位置: 首页 > article >正文

深度学习06 寻找与保存最优模型

目录

什么是最优模型?

什么是保存最优模型?

核心保存机制

案例:食物分类模型训练过程中的最优模型保存


本篇文章我们以食物分类的例子来介绍如何保存训练过程中的最优模型

首先我们简单了解下什么是最优模型

什么是最优模型?

定义:在特定评估标准下(如验证集准确率、损失值等),训练过程中性能表现最佳的模型版本。

特点

  • 相对性:仅在当前任务的训练数据和评价指标下为最优

  • 阶段性:可能出现在训练中期而非末尾(防止过拟合时)

  • 目的导向:分类任务看准确率,生成任务看BLEU分数,医疗任务看特异性和敏感性的平衡

什么是保存最优模型?

核心操作:在训练过程中持续监控模型性能,当发现更好的模型时,将其参数或完整结构持久化存储。

关键作用

  • 保留最佳性能模型

  • 防止训练后期性能衰退

  • 便于后期部署和性能复现

核心保存机制
# PyTorch 保存方式对比示例
import torch
​
# (推荐) 仅保存模型参数
torch.save(model.state_dict(), 'model_params.pth')
​
# 完整模型保存(可能版本不兼容)
torch.save(model, 'full_model.pth')
​
# 带优化器和epoch信息的存档
checkpoint = {
    'epoch': current_epoch+1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': best_acc
}
torch.save(checkpoint, 'checkpoint.pth')
​
# (补充) Keras自动保存最佳模型
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint('best_model.h5', 
                           monitor='val_accuracy',
                           save_best_only=True,
                           mode='max')

案例:食物分类模型训练过程中的最优模型保存

食物分类训练集与测试集

将训练集、验证集中的图片路径分别放入train.txt test.txt中

相关库的导入

import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np
from PIL  import Image
from torchvision import transforms
import torchvision.models as models
from torch import nn

数据预处理与加载

data_transform={
    'train':
        transforms.Compose([
            transforms.Resize([256,256]),
            transforms.RandomRotation(45),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ColorJitter(0.2,0.1,0.1,0.1),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([256,256]),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ]),
}
class food_dataset(Dataset):
    def __init__(self,file_path,transform=None):
        self.file_path=file_path
        self.imgs=[]
        self.labels=[]
        self.transform=transform
        with open(self.file_path) as f:
            samples=[x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, idx):
        image=Image.open(self.imgs[idx])
        if self.transform:
            image=self.transform(image)
​
        label =self.labels[idx]
        label=torch.from_numpy(np.array(label,dtype=np.int64))
        return image,label
​
training_data=food_dataset(file_path=r'./train.txt',transform=data_transform['train'])
test_data=food_dataset(file_path=r'./test.txt',transform=data_transform['valid'])
​
train_dataloader=DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=64,shuffle=True)
确定使用的设备是cpu还是GPU
device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(device)

定义CNN模型

from torch import nn
class  CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(3,16,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv3=nn.Sequential(
            nn.Conv2d(32, 128, 5, 1, 2),
            nn.ReLU(),
        )
        self.out=nn.Linear(128*64*64,20)
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=x.view(x.size(0),-1)
        output=self.out(x)
        return output
model=CNN().to(device)
print(model)
定义训练模型
def train(dataloader,model,loss_fn,optimizer):
    model.train()
    batch_size_num=1
    for x,y in dataloader:
        x,y=x.to(device),y.to(device)
        pred=model.forward(x)
        loss=loss_fn(pred,y)
​
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value=loss.item()
        if batch_size_num%1==0:
            print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')
​
        batch_size_num+=1

定义测试模型&保存最优模型

best_acc=0
def test(dataloader,model,loss_fn):
    global best_acc
    size=len(dataloader.dataset)
    num_batches=len(dataloader)
    model.eval()
    test_loss,correct=0,0
​
    with torch.no_grad():
        for x,y in dataloader:
            x,y=x.to(device),y.to(device)
            pred=model.forward(x)
            test_loss+=loss_fn(pred,y).item()
            correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
​
    test_loss/=num_batches
    correct/=size
​
    print(f'Test result:\n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')
​
    if correct>best_acc:
        best_acc=correct
        #print(model.state_dict().keys())
        
        #保存了完整模型,包括模型的架构和参数
        torch.save(model.state_dict(),'best2025_2_15.pth')

定义损失函数、优化器、调整学习率

loss_fn=nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(param_to_update, lr=0.001)  # 学习率可以根据需要调整
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)

开始训练

epochs=150
acc_s=[]
loss_s=[]
for t in range(epochs):
    print(f'EPOCH {t+1}\n-----------')
    train(train_dataloader,model,loss_fn,optimizer)
print('结束')
print(best_acc)
test(test_dataloader,model,loss_fn)


http://www.kler.cn/a/552141.html

相关文章:

  • Flink SQL与Doris实时数仓Join实战教程(理论+实例保姆级教程)
  • WPS/WORD$OffterAI
  • Vue3项目,蛋糕商城系统
  • C++ Primer 访问控制与封装
  • Android Studio:如何使用 RxBus 类进行事件发布和订阅
  • Kafka分区管理大师指南:扩容、均衡、迁移与限流全解析
  • 算法12-贪心算法
  • 前端基础——axios、fetch和xhr来封装请求
  • 用LangGraph轻松打造测试用例生成AI Agent
  • 【保姆级教程】DeepSeek R1+RAG,基于开源三件套10分钟构建本地AI知识库
  • 青少年网络安全竞赛python 青少年网络安全大赛
  • 【故障处理】- 11g迁19C数据泵报错: ORA-39083 ORA-06598 导致数据库大量对象导入不进去
  • Linux环境Docker使用代理推拉镜像
  • Postgresql的三种备份方式_postgresql备份
  • ARM中断流程思考。
  • 百度搜索融合 DeepSeek 满血版,开启智能搜索新篇
  • 微信小程序---计划时钟设计与实现
  • 欢乐力扣:旋转图像
  • redis的应用,缓存,分布式锁
  • LeetCodeBug-member access within null pointer of type ‘struct ListNode‘