
Configuring Multiple Models for YOLOv5 Detection

First, wrap YOLOv5 detection in your own class and set up its parameters.

import os

import numpy as np
import torch

from models.experimental import attempt_load
from utils.augmentations import letterbox  # utils.datasets.letterbox in older YOLOv5 releases
from utils.general import check_img_size, non_max_suppression
from utils.torch_utils import select_device


class YOLOv5:
    def __init__(self, config: dict, logger):
        self.logger = logger
        self.weights = config.get('weights')
        self.imgsz = config.get('imgsz')
        self.conf_thres = config.get('conf_thres')
        self.iou_thres = config.get('iou_thres')
        self.max_det = config.get('max_det')
        self.device = select_device(config.get('device'))
        self.classes = config.get('classes')
        self.agnostic_nms = config.get('agnostic_nms')
        self.augment = config.get('augment')
        self.visualize = config.get('visualize')
        self.line_thickness = config.get('line_thickness')
        self.half = config.get('half')
        self.half &= self.device.type != 'cpu'  # half precision only on CUDA
        self.dnn = config.get('dnn')
        self.save_result = config.get('save_result')
        self.save_detect = config.get('save_detect')
        self.save_crop = config.get('save_crop')
        self.detect_save_path = config.get('detect_save_path')
        self.result_save_path = config.get('result_save_path')
        self.labels_waiting_detection = config.get('labels_waiting_detection')
        if self.save_detect:
            self.logger.info(f'Saving detection images enabled, path: {self.detect_save_path}')
        if self.save_result:
            self.logger.info(f'Saving processed results enabled, path: {self.result_save_path}')
        if self.save_crop:
            if not os.path.exists('crops'):
                os.makedirs('crops')
        if not os.path.exists(self.detect_save_path):
            os.makedirs(self.detect_save_path)
        if not os.path.exists(self.result_save_path):
            os.makedirs(self.result_save_path)
        self.models_info = {}
        for i, weights in enumerate(self.weights):
            model = attempt_load(weights, map_location=self.device)
            if self.half:
                model.half()  # to FP16 (every model, so each matches the half-precision input)
            self.models_info[i] = {
                'model': model,
                'names': model.module.names if hasattr(model, 'module') else model.names,  # get class names
            }

        self.logger.info(f'Finished loading models {self.weights}')

        # use the largest stride among the loaded models when checking the image size
        stride = int(max(info['model'].stride.max() for info in self.models_info.values()))
        self.imgsz = check_img_size(self.imgsz, s=stride)  # check image size

        self.logger.info(f'Model labels: {[i["names"] for i in self.models_info.values()]}')
        self.logger.info(f'Detection configured for: {list(self.labels_waiting_detection.keys())}')
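For reference, here is a minimal sketch of the config dict this class expects, based only on the keys read above; every concrete value, path, and the contents of labels_waiting_detection are illustrative assumptions, not the original setup.

import logging

# Hypothetical configuration; all values below are examples.
config = {
    'weights': ['weights/person.pt', 'weights/helmet.pt'],  # one .pt path per model
    'imgsz': 640,
    'conf_thres': 0.25,
    'iou_thres': 0.45,
    'max_det': 1000,
    'device': '0',            # passed to select_device: '', 'cpu', '0', ...
    'classes': None,          # optional class-index filter for NMS
    'agnostic_nms': False,
    'augment': False,
    'visualize': False,
    'line_thickness': 2,
    'half': True,             # only takes effect on GPU (see the check above)
    'dnn': False,
    'save_result': True,
    'save_detect': True,
    'save_crop': False,
    'detect_save_path': 'runs/detect_imgs',
    'result_save_path': 'runs/result_imgs',
    'labels_waiting_detection': {'person': {}, 'helmet': {}},  # keys are the labels to act on
}

logger = logging.getLogger('yolov5_multi')
detector = YOLOv5(config, logger)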

The key part is loading the models:

The weights argument passed in is a list whose elements are paths to .pt files; everything is packed into models_info so it is easy to retrieve later.

The format is:

models_info = {
    0: {'model': <model loaded from best.pt>, 'names': ['person', ...]},
    ...
}

self.models_info = {}
for i, weights in enumerate(self.weights):
    model = attempt_load(weights, map_location=self.device)
    if self.half:
        model.half()  # to FP16
    self.models_info[i] = {
        'model': model,
        'names': model.module.names if hasattr(model, 'module') else model.names,  # get class names
    }
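Once loaded, a model and its class-name list can be fetched back by index, for example (illustrative only):

# Illustrative access: grab the first loaded model and its class names.
first = self.models_info[0]
model, names = first['model'], first['names']
self.logger.info(f'model 0 labels: {names}')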

Next comes calling the models. infer needs two arguments: a preprocessed image and a model.

The image preprocessing method is as follows; it returns the processed image.

def process_image(self, image):
    im0s = image
    img = letterbox(im0s, self.imgsz, stride=32, auto=True)[0]
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(self.device)
    img = img.half() if self.half else img.float()  # uint8 to fp16/32
    img = img / 255.0  # 0 - 255 to 0.0 - 1.0
    if len(img.shape) == 3:
        img = img[None]  # expand for batch dim
    return img
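A quick usage sketch, assuming OpenCV is used to read the original frame and reusing the hypothetical detector from the config example above (the image path is made up):

import cv2

image = cv2.imread('test.jpg')           # original HWC, BGR frame
img = detector.process_image(image)      # 1xCxHxW tensor on the configured device
print(img.shape, img.dtype)              # fp16 if half is enabled, else fp32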

Then infer takes the image and the model:

def infer(self, image, model):
    pred = model(image, augment=self.augment, visualize=self.visualize)[0]
    pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms,
                               max_det=self.max_det)
    return pred, image
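After non_max_suppression, pred is a list with one tensor per image in the batch, and each row is [x1, y1, x2, y2, conf, cls] in the letterboxed image's coordinates. A small inspection sketch, reusing the hypothetical detector and img from the earlier examples:

# Run a single model from models_info and look at its raw detections.
model_info = detector.models_info[0]
pred, _ = detector.infer(img, model_info['model'])
for det in pred:                          # one tensor per image in the batch
    print(det.shape)                      # (num_boxes, 6)
    for *xyxy, conf, cls in det:
        print(model_info['names'][int(cls)], float(conf))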

So during processing you need to iterate over the loaded models; the concrete steps are below. The label-handling logic is not covered here.

def detect(self, image):
    detect_result_list = []
    img = self.process_image(image)  # image is the original frame, img is the preprocessed tensor
    for model_index, model_info in self.models_info.items():
        model = model_info['model']
        names = model_info['names']
        pred, img = self.infer(img, model)
        for i, det in enumerate(pred):
            ...
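The body of that inner loop is omitted above. As a rough, generic YOLOv5-style sketch of what usually goes there (not the author's implementation): rescale the boxes from the letterboxed tensor back to the original image with scale_coords (renamed scale_boxes in newer YOLOv5 releases) and keep only the labels listed in labels_waiting_detection. A hypothetical helper:

from utils.general import scale_coords  # scale_boxes in newer YOLOv5 releases


def collect_detections(det, names, img, image, wanted_labels):
    # Rescale one model's detections to the original image and keep only wanted labels.
    # Generic sketch, not the omitted original code.
    results = []
    if len(det):
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], image.shape).round()
        for *xyxy, conf, cls in det:
            label = names[int(cls)]
            if label in wanted_labels:
                results.append({'label': label,
                                'conf': float(conf),
                                'box': [int(v) for v in xyxy]})
    return results

Inside detect's inner loop this could then be used as detect_result_list += collect_detections(det, names, img, image, self.labels_waiting_detection).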

