当前位置: 首页 > article >正文

基于CNN的FashionMNIST数据集识别3——模型验证

源码

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet



def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader


def test_model_process(model, test_dataloader):
    # 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    # 讲模型放入到训练设备中
    model = model.to(device)

    # 初始化参数
    test_corrects = 0.0
    test_num = 0

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output= model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

    # 计算测试准确率
    test_acc = test_corrects.double().item() / test_num
    print("测试的准确率为:", test_acc)




if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('best_model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

源码讲解

当模型训练完毕后,我们得到的是一组最优的参数配置。

最后要做的就是验证这组参数的表现。

数据准备

def test_data_process():
    test_data = FashionMNIST(root='./data',
                              train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)
    return test_dataloader

测试数据的准备和训练数据的准备有明显不同:

  1. 测试数据集是将MINIST里所有的数据当做验证集。并且训练模式设置为false。
  2. dataloader里面的batch大小设置为1,也就是说对每个样本进行验证,不再存在“分批”的概念。

循环验证

    # 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度
    with torch.no_grad():
        for test_data_x, test_data_y in test_dataloader:
            # 将特征放入到测试设备中
            test_data_x = test_data_x.to(device)
            # 将标签放入到测试设备中
            test_data_y = test_data_y.to(device)
            # 设置模型为评估模式
            model.eval()
            # 前向传播过程,输入为测试数据集,输出为对每个样本的预测值
            output= model(test_data_x)
            # 查找每一行中最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)
            # 如果预测正确,则准确度test_corrects加1
            test_corrects += torch.sum(pre_lab == test_data_y.data)
            # 将所有的测试样本进行累加
            test_num += test_data_x.size(0)

和之前训练模型时的验证逻辑基本相同。只进行前向传播,将预测正确的样本个数进行累加。

代码运行

if __name__=="__main__":
    # 加载模型
    model = LeNet()
    model.load_state_dict(torch.load('best_model.pth'))
    # 加载测试数据
    test_dataloader = test_data_process()
    # 加载模型测试的函数
    test_model_process(model, test_dataloader)

在代码运行时,需要将参数载入到模型里,再进行验证。


http://www.kler.cn/a/560019.html

相关文章:

  • Java多线程与高并发专题——深入synchronized
  • PythonWeb开发框架—Django之DRF框架的使用详解
  • ai-1、人工智能概念与学习方向
  • 商业化运作的“日记”
  • system运行进程以及应用场景
  • 【Python爬虫(61)】Python金融数据挖掘之旅:从爬取到预测
  • 【odoo18-文件管理】在uniapp上访问odoo系统上的图片
  • 第二个接口-分页查询
  • 网站快速收录:如何优化网站图片Alt标签?
  • 如何安装vm和centos
  • 基于 IMX6ULL 的环境监测自主调控系统
  • github如何创建空文件夹
  • 图像处理篇---图像处理中常见参数
  • 基础学科与职业教育“101计划”:推动教育创新与人才培养
  • Windows逆向工程入门之逻辑运算指令解析与应用
  • 湖北中医药大学谱度众合(武汉)生命科技有限公司研究生工作站揭牌
  • 异常(1)
  • 如何在java中用httpclient实现rpc post 请求
  • linux-多进程基础(1) 程序、进程、多道程序、并发与并行、进程相关命令,fork
  • 瑞幸咖啡×动漫IP:精选联名案例,解锁品牌营销新玩法