当前位置: 首页 > article >正文

Python加载 TorchScript 格式的 ResNet18 模型分类该模型进行预测并输出预测的类别和置信度

  • 首先加载预训练的 ResNet18 模型。
  • 将模型设置为评估模式,以确保特定层(如 Dropout 和 BatchNorm)在评估时具有确定性的行为。
  • 创建一个形状为 (1, 3, 224, 224) 的随机张量作为示例输入。
  • 使用 torch.jit.trace 函数追踪模型在给定示例输入上的行为,将模型转换为 TorchScript 格式。
  • 保存 TorchScript 格式的模型为 resnet18_torchscript.pt 文件,并打印转换成功的消息。
import torch
import torchvision.models as models

# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)

# 将模型设置为评估模式
model.eval()

# 创建示例输入张量
example_input = torch.rand(1, 3, 224, 224)

# 使用 torch.jit.trace 追踪模型
traced_model = torch.jit.trace(model, example_input)

# 保存 TorchScript 模型
traced_model.save('resnet18_torchscript.pt')

print("ResNet18 模型已成功转换为 TorchScript 格式并保存。")

定义图像处理函数 process_img

    • process_img 函数接受一个图像路径作为参数。

    • 使用 cv2.imread 读取图像,将图像从 BGR 颜色空间转换为 RGB 颜色空间(因为很多深度学习模型期望输入为 RGB 格式)。

    • 将图像的像素值归一化到 [0, 1] 范围。

    • 使用 cv2.resize 将图像调整为 (224, 224) 的尺寸,这通常是 ResNet18 模型期望的输入尺寸。

    • 使用 np.transpose 将图像的维度顺序从 HWC(Height-Width-Channel)转换为 CWH(Channel-Height-Width),以符合 PyTorch 的输入要求。

    • 使用 np.expand_dims 在批量维度上扩展图像,使其形状变为 (1, C, H, W)

    • 最后将处理后的图像转换为 PyTorch 张量,并指定数据类型为 torch.float32,然后返回该张量。

def process_img(img_path):
    img=cv2.imread(img_path)

    img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = img / 255
    img=cv2.resize(img,dsize=(224,224))
    img=np.transpose(img,(2,0,1))#HWC-->CWH
    img=np.expand_dims(img,axis=0)
    img=torch.tensor(img,dtype=torch.float32)
    return img

图像预测部分

  1. 定义一个图像路径 img_path,并将其传入 process_img 函数,得到处理后的图像张量 img
    • 使用 torch.jit.load 加载之前保存的 TorchScript 格式的模型。
    • 将处理后的图像张量传入模型进行前向传播,得到输出张量 output
    • 使用 torch.argmax 在输出张量的维度 1 上找到具有最大值的索引,即预测的类别。
    • 最后打印出预测的类别和对应类别的置信度(输出张量中对应类别的值)。
img_path='dog.jpg'
img=process_img(img_path)
model=torch.jit.load('resnet18_torchscript.pt')#还是torchscript格式的
output=model.forward(img)
cls=torch.argmax(output,axis=1)
print('预测的类别是:',cls.item(),'置信度是',output[0][cls].item())

预测图片 

结果如下:

可以去Imgnet官网找对应的网站来查看类别 


http://www.kler.cn/a/289157.html

相关文章:

  • 速盾:高防 CDN 和 CDN 的缓存机制都一样吗?
  • 比ChatGPT更酷的AI工具
  • 力扣515:在每个树行中找最大值
  • 前端神经网络入门(三):深度学习与机器学习的关系、区别及核心理论支撑 - 以Brain.js示例
  • Springboot 启动端口占用如何解决
  • vxe-table 3.10+ 进阶高级用法(一),根据业务需求自定义实现筛选功能
  • 【运维监控】prometheus+node exporter+grafana 监控linux机器运行情况(2)
  • 【wsl2】从C盘迁移到G盘
  • redroid搭建云手机学习笔记(一)
  • C++ ─── List的模拟实现
  • django orm的Q和~Q的数据相加并不一定等于总数
  • Golang | Leetcode Golang题解之第380题O(1)时间插入、删除和获取随机元素
  • [SDK]-按钮静态文本与编辑框控件
  • Vue-cli的使用
  • MySQL三大日志详解
  • 【区块链 + 房产建筑】透明建造系统 | FISCO BCOS应用案例
  • Windows安装docker,启动ollama运行open-webui使用AIGC大模型写周杰伦歌词
  • Unity实战案例 2D小游戏HappyGlass(模拟水珠)
  • 解剖学上合理的分割:通过先验变形显式保持拓扑结构|文献速递--基于深度学习的医学影像病灶分割
  • 域与活动目录
  • Mudo03 vscode配置相应的文件的搜索路径,库文件的搜索路径以及想要的链接库
  • 【Redis之一:下载安装Redis】
  • Java 入门指南:Java 并发编程 —— 并发容器 ConcurrentSkipListMap
  • P7492 [传智杯 #3 决赛] 序列
  • 【MATLAB源码-第157期】基于matlab的海马优化算法(SHO)机器人栅格路径规划,输出做短路径图和适应度曲线。
  • 【安卓13】解决HDMI OUT和耳机等设备接入时会解除静音问题