当前位置: 首页 > article >正文

pytorch 模型测试

在使用 PyTorch 进行模型测试时,一般包含加载测试数据、加载训练好的模型、进行推理以及评估模型性能等步骤。以下为你详细介绍每个步骤及对应的代码示例。

1. 导入必要的库

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

2. 加载测试数据

假设我们使用的是 CIFAR - 10 数据集作为示例,你需要定义数据预处理的转换操作,然后加载测试数据集。

# 定义数据预处理的转换操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# 类别标签
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3. 定义模型结构

如果你已经有训练好的模型,这一步可以跳过。但为了完整性,这里给出一个简单的卷积神经网络(CNN)示例。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

4. 加载训练好的模型

假设你已经将训练好的模型保存为 cifar_net.pth 文件,现在可以加载它。

# 加载模型
net.load_state_dict(torch.load('cifar_net.pth'))

5. 进行推理和评估

在测试阶段,我们需要将模型设置为评估模式,然后遍历测试数据集,对每个样本进行推理,并计算模型的准确率。

# 将模型设置为评估模式
net.eval()

correct = 

http://www.kler.cn/a/571168.html

相关文章:

  • 一周学会Flask3 Python Web开发-WTForms表单验证
  • LeeCode题库第四十三题
  • 20道Redis面试题
  • AI 编译器学习笔记之十六 -- TVM
  • 2025年Linux主力系统选择指南:基于最新生态的深度解析(附2025年发行版对比速查表)
  • 代码贴——堆(二叉树)数据结构
  • 【华为OD机试真题29.9¥】(E卷,100分) - TLV解码(Java Python JS C++ C )
  • 【一个月备战蓝桥算法】递归与递推
  • IEEE33节点搭建matlab模型
  • TMS320F28P550SJ9学习笔记2:Sysconfig 配置与点亮LED
  • Python自动化测试-使用Pandas来高效处理测试数据
  • MYOJ_1102:(洛谷P1540)[NOIP 2010 提高组]机器翻译
  • 责任链模式:优雅处理复杂流程的设计艺术
  • Week3_250303~250309_OI日志(待完善)
  • 利用Git和wget批量下载网页数据
  • 维度建模基础篇:从理论到核心组件解析
  • 国内如何快速拿下微软AI-900!?
  • RabbitMQ 最新版:安装、配置 与Java 接入详细教程
  • 如何利用客户端双向TLS认证保护云上应用安全
  • 认识时钟树