当前位置: 首页 > article >正文

分类模型onnx推理,并生成混淆矩阵

废话不多说直接上代码

import onnxruntime
import numpy as np
import os
import cv2
import argparse
import time
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

labels = ["0", "1", "2", "3", "4", "5", "6", "7"]

def sigmoid(x):
    """Sigmoid function for a scalar or NumPy array."""
    return 1 / (1 + np.exp(-x))

def getFileList(dir, Filelist, ext=None):
    """
    获取文件夹及其子文件夹中文件列表
    输入 dir:文件夹根目录
    输入 ext: 扩展名
    返回: 文件路径列表
    """
    newDir = dir
    if os.path.isfile(dir):
        if ext is None:
            Filelist.append(dir)
        else:
            if ext in dir[-3:]:
                Filelist.append(dir)

    elif os.path.isdir(dir):
        for s in os.listdir(dir):
            newDir = os.path.join(dir, s)
            getFileList(newDir, Filelist, ext)

    return Filelist

def read_image(image_path, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    src = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    image = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    image = image.astype(np.float32)
    image = image / 255.0
    image = image.transpose(2, 0, 1)
    mean = np.array(mean, dtype=np.float32).reshape((3,1,1))
    std = np.array(std, dtype=np.float32).reshape((3,1,1))
    # 对图像进行归一化
    normalized_image = (image - mean) / std
    normalized_image = np.expand_dims(normalized_image, axis=0)
    return normalized_image, src

def load_onnx_model(model_path):
    providers = ['CUDAExecutionProvider']  # 使用 GPU
    # providers = ['CPUExecutionProvider']
    options = onnxruntime.SessionOptions()
    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
    session = onnxruntime.InferenceSession(model_path, options, providers=providers)
    print("ONNX模型已成功加载。")
    return session

def main(image_path, session):
    image, src = read_image(image_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    pred = session.run([output_name], {input_name: image})[0]
    pred = np.squeeze(pred)
    pred = [sigmoid(x) for x in pred]
    return pred.index(max(pred)), max(pred), labels[pred.index(max(pred))]

def plot_confusion_matrix(y_true, y_pred, labels):
    """
    绘制混淆矩阵
    输入 y_true: 真实标签
    输入 y_pred: 预测标签
    输入 labels: 标签名称
    """
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images_path', type=str, default="/home/workspace/temp/test_result/kaideng", help='images_path')
    parser.add_argument('--model_path', type=str, default="/home/workspace/temp/47-val-Loss-0.0203-Acc-0.9942.onnx", help='model_path')
    args = parser.parse_args()
    img_list = []
    img_list = getFileList(args.images_path, img_list)
    count = 0
    session = load_onnx_model(args.model_path)
    start = time.time()
    y_true = []
    y_pred = []
    count_time = 0
    for img in img_list:
        #true_label = int(img.split('/')[-2].split('-')[0])
        true_label = img.split('/')[-3] #这一句代码是获取图像类别文件夹的名称,具体索引需要修改
        start_1 = time.time()
        predicted_index, score, label = main(img, session)
        count_time += time.time() - start_1
        y_true.append(true_label)
        #y_pred.append(predicted_index)
        y_pred.append(label)
        if label == true_label:
            count += 1
        # else:
        #     dst_path = img.replace('test', 'test_out')
        #     dst_dir = os.path.dirname(dst_path)
        #     if not os.path.exists(dst_dir):
        #         os.makedirs(dst_dir)
        #     shutil.copy(img, dst_path.replace('.jpg', "-" + label + '.jpg'))

    accuracy = count / len(img_list) * 100
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Correct predictions: {count}, Total images: {len(img_list)}")
    print(f"Time taken: {time.time() - start:.2f} seconds")
    print("推理", len(img_list), "张图像用时", count_time)
    # 绘制混淆矩阵
    plot_confusion_matrix(y_true, y_pred, labels)

1.要确保图像类别文件夹的名称和labels列表相对应,不然无法生成混淆矩阵

2.read_image函数中的预处理方式要和训练时的一致,请根据个人需要修改代码


http://www.kler.cn/a/380355.html

相关文章:

  • PostgreSQL (八) 创建分区
  • 解决方案 | 部署更快,自动化程度高!TOSUN同星线控底盘解决方案
  • Navicat 17 功能简介 | 转储SQL文件
  • 独孤思维:工作被骂,副业停滞,算个屁
  • Spring1(初始Spring 解耦实现 SpringIOC SpringDI Spring常见面试题)
  • 操作系统(10) (并发(2)------基于软件/硬件/操作系统层面解决两个进程之间的临界区问题/抢占式/非抢占式内核)
  • 如何在本地Linux服务器搭建WordPress网站结合内网穿透随时随地可访问
  • 使用 Python 中的 pydub实现 M4A 转 MP3 转换器
  • element-plus按需引入报错IconsResolver is not a function
  • 经纬恒润车载TSN网络测试仪TestBase-ATT全新上线!
  • C#、C和C++的主要区别
  • Python | Leetcode Python题解之第530题二叉搜索树的最小绝对差
  • 将Notepad++添加到右键菜单【一招实现】
  • Rust 力扣 - 1297. 子串的最大出现次数
  • 使用python爬取某新闻网并进行数据分析
  • 【论文阅读笔记】Wavelet Convolutions for Large Receptive Fields
  • 论文阅读(一种基于球面投影和特征提取的岩石点云快速配准算法)
  • [ DOS 命令基础 4 ] DOS 命令命令详解-端口进程相关命令
  • 【ROS2】hbm_img_msgs/msg/HbmMsg1080P 转 opencv cv::Mat
  • 江协科技STM32学习- P32 MPU6050
  • PHP不良事件上报系统源码,医院安全不良事件管理系统,基于 vue2+element+ laravel框架开发
  • 前端页面整屏滚动fullpage.js简单使用
  • 儿童安全座椅行业全面深入分析
  • 【Linux】将 bin 目录添加到环境变量 LD_LIBRARY_PATH
  • 【【简单systyem verilog 语言学习使用二--- 新adder加法器 】】
  • 【Rust中的错误处理】