当前位置: 首页 > article >正文

PyTorch系列教程:评估和推理模式下模型预测

使用PyTorch时,将模型从训练阶段过渡到推理阶段是至关重要的一步。在推理过程中,该模型用于对以前从未见过的新数据进行预测。这种转换的一个重要方面是使用推理模式,它通过禁用仅在训练期间需要的操作来帮助优化模型的性能。

理解推理模式

训练期间需要的某些特征(例如autograd相关操作)来加快张量计算速度。当使用推理模式时,不构建计算图,这减少了内存的使用,并加快了前传过程中的内存分配和释放。

推理模式的引入是对传统 torch.no_grad() 上下文的一种改进,专门针对那些已知不需要梯度的应用场景。虽然这两种方法都能通过不存储梯度信息来节省内存,但推理模式更进一步,进行了更多的优化。
在这里插入图片描述

设置评估模式

PyTorch提供了一种简化的方式来设置评估模式,这可以使用模型的.eval()方法来完成。该方法将模型的状态从训练模式切换到评估模式。这里有一个快速的演示:

import torch
import torch.nn as nn

# Define a simple neural network
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# Initialize the model
model = SimpleModel()

# Set model to evaluation mode
model.eval()

通过调用model.eval(),某些层(如dropout和批处理归一化)的行为将与训练期间不同,从而确保模型的预测尽可能准确。

禁用梯度计算

在推理过程中,不需要计算梯度,可以节省计算资源,加快评估速度。PyTorch提供了上下文管理器torch.no_grad()来关闭这些计算:

input_tensor = torch.randn(1, 10)  # Example input

# Disabling gradient computation
with torch.no_grad():
    output = model(input_tensor)

print(output)

使用torch.no_grad(), PyTorch在调用模型时防止跟踪历史和未来的计算,使评估过程更快,内存更高效。

设置推理模式

在PyTorch中,实现推理模式非常简单。PyTorch库提供的key方法是torch.inference_mode()。以下是如何利用此功能:

import torch

# Sample PyTorch Model
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

# Initialize model
model = SimpleModel()

# Data for making prediction
input_data = torch.tensor([[5.0]])

# Using Inference Mode
with torch.inference_mode():
    prediction = model(input_data)
    print(f"Prediction: {prediction.item()}")

在上面的代码中,我们首先定义一个简单的线性模型。通过将预测代码逻辑包装在with torch.inference_mode()块中,你可以指示PyTorch在推理模式下运行,从而提高执行速度并减少所定义操作的内存占用。

理解模型输出

PyTorch模型输出原始分数,而不是概率。为了将这些输出转换成适合解释的形式(例如,用于分类的softmax),通常需要额外的步骤:

# Using the same output from the model
# Apply Softmax to convert scores into probabilities
probabilities = torch.softmax(output, dim=1)

# Getting the predicted class
_, predicted_class = torch.max(probabilities, 1)

print('Predicted probabilities:', probabilities)
print('Predicted class:', predicted_class)

torch.softmax函数通常用于将模型输出转换为分类任务,torch.max决定了最可能类别的概率。

加载预训练模型

在探索推理时,通常使用PyTorch社区提供的预训练模型。让我们看看如何加载预训练的模型,比如ResNet:

from torchvision import models

# Load a pre-trained ResNet model
resnet_model = models.resnet18(pretrained=True)

# Set the model to eval mode
resnet_model.eval()
  • models.resnet18 :ResNet(残差网络)是由微软研究院的Kaiming He等人在2015年提出的深度卷积神经网络架构。ResNet通过引入“残差连接”解决了深层网络训练中的梯度消失问题,使得训练更深的网络成为可能。

    • resnet18 是ResNet系列中的一种,包含18层(包括卷积层和全连接层)。
  • 参数 pretrained=True

    • 当设置为 True 时,表示加载在ImageNet数据集上预训练好的模型权重。这些权重是通过在大规模ImageNet数据集上训练得到的,能够有效地提取图像特征。
    • 预训练模型对于迁移学习非常有用,可以在新的任务上快速适应和提升性能,而无需从头开始训练模型。

这段代码的主要目的是加载一个在ImageNet上预训练好的ResNet-18模型,并将其设置为评估模式,以便在后续步骤中进行推理(如图像分类)而不进行参数更新。这在迁移学习、特征提取或任何需要使用预训练模型进行预测的任务中非常常见。这些模型加载完成后即可用于推理,这极大地加快了开发周期,因为它们允许在无需从头构建模型的情况下对最先进的架构进行实验。

比较推理模式和无梯度模式

虽然torch.no_grad()torch.inference_mode() 都通过防止梯度跟踪来优化模型的推理过程,但它们在用例和效率上存在差异。考虑一下这个简单的时间比较:

import time

# Experiment setup
iterations = 1000

# Measure with no_grad
start_time = time.time()
with torch.no_grad():
    for _ in range(iterations):
        _ = model(input_data)
end_time = time.time()
print(f'Using no_grad: {end_time - start_time:.5f} seconds')

# Measure with inference_mode
start_time = time.time()
with torch.inference_mode():
    for _ in range(iterations):
        _ = model(input_data)
end_time = time.time()
print(f'Using inference_mode: {end_time - start_time:.5f} seconds')

上面的脚本说明了如何比较两种模式下多次运行的时间,以查看执行时间的差异。通常,期望torch.inference_mode()通过比torch.no_grad()更有效地执行来提供更好的性能。

  • 用例和注意事项

在大多数生产场景中,应该使用Inference Mode,特别是在部署预测速度至关重要的模型时。这包括实时数据处理应用程序、嵌入式设备和在资源受限环境中运行的应用程序。但是,请记住,这不会阻止自动梯度,因此不应该在需要计算梯度的情况下使用它,例如在训练期间或当您仍然需要进一步操作的梯度信息时。

最后总结

评估模式是在实际应用程序中有效部署和利用PyTorch模型的一个关键方面。通过利用model.eval()、torch.no_grad(),并了解如何处理模型输出,你可以显著地最大化机器学习应用程序的性能。

将推理模式集成到PyTorch应用程序中可以显著提高性能。随着模型变得越来越复杂,数据集越来越大,优化预测时间的需求变得越来越重要。凭借PyTorch API的简单性,利用这些优化有效地与未来高效,高性能的机器学习模型保持一致。


http://www.kler.cn/a/571774.html

相关文章:

  • post get 给后端传参数
  • 爬虫系列之发送请求与响应《一》
  • 通俗版解释:分布式和微服务就像开餐厅
  • sa-token全局过滤器之写法优化(包含设置Order属性)
  • HiRT:利用分层机器人Transformer 增强机器人控制
  • 企业级Python后端数据库使用指南(简略版)
  • 计算机视觉算法实战——医学影像分割(主页有源码)
  • 深入解析Tiktokenizer:大语言模型中核心分词技术的原理与架构
  • Spring线程池学习笔记
  • [LeetCode]day33 150.逆波兰式求表达值 + 239.滑动窗口最大值
  • STM32MP1xx的启动流程
  • 【数据结构与算法】常见数据结构与算法在JDK和Spring中的实现:源码解析与实战代码!
  • Arm64架构的Linux服务器安装jdk8
  • 珈和科技亮相CCTV-13《新闻直播间》,AI多模态农业大模型引领智慧农业新变革
  • 【蓝桥杯集训·每日一题2025】 AcWing 5526. 平衡细菌 python
  • 最新Flutter导航拦截PopScope使用
  • 国家网络安全通报中心:大模型工具Ollama存在安全风险
  • Ubuntu的tmux配置
  • Delphi连接MySql数据库房
  • 高效玩转 PDF:实用的分割、合并操作详解