YoloV5检测配置多模型
首先自己封装yolov5检测,配置好参数
class YOlOv5:
    """Multi-model YOLOv5 detector wrapper.

    Loads one model per entry in ``config['weights']`` and stores each
    loaded model plus its class names in ``self.models_info`` (keyed by
    the index of the weights entry) for later inference.
    """

    def __init__(self, config: dict, logger):
        self.logger = logger
        self.weights = config.get('weights')  # list of .pt file paths
        self.imgsz = config.get('imgsz')
        self.conf_thres = config.get('conf_thres')
        self.iou_thres = config.get('iou_thres')
        self.max_det = config.get('max_det')
        self.device = select_device(config.get('device'))
        self.classes = config.get('classes')
        self.agnostic_nms = config.get('agnostic_nms')
        self.augment = config.get('augment')
        self.visualize = config.get('visualize')
        self.line_thickness = config.get('line_thickness')
        self.half = config.get('half')
        # FP16 inference is only supported on non-CPU devices.
        self.half &= self.device.type != 'cpu'
        self.dnn = config.get('dnn')
        self.save_result = config.get('save_result')
        self.save_detect = config.get('save_detect')
        self.save_crop = config.get('save_crop')
        self.detect_save_path = config.get('detect_save_path')
        self.result_save_path = config.get('result_save_path')
        self.labels_waiting_detection = config.get('labels_waiting_detection')

        if self.save_detect:
            self.logger.info(f'已开启保存图片检测结果,路径为:{self.detect_save_path}')
        if self.save_result:
            self.logger.info(f'已开启保存图片处理结果,路径为:{self.result_save_path}')
        # exist_ok=True replaces the racy "exists then makedirs" pattern.
        if self.save_crop:
            os.makedirs('crops', exist_ok=True)
        os.makedirs(self.detect_save_path, exist_ok=True)
        os.makedirs(self.result_save_path, exist_ok=True)

        # Load every weights file into models_info so the detection loop
        # can iterate over all of them.
        self.models_info = {}
        for i, weights in enumerate(self.weights):
            model = attempt_load(weights, map_location=self.device)
            if self.half:
                # BUG FIX: originally only models_info[0] was converted to
                # FP16, leaving the remaining models in FP32 while the
                # preprocessed input tensor is half precision.
                model.half()
            self.models_info[i] = {
                'model': model,
                # DataParallel wraps the real model (and its names) in .module.
                'names': model.module.names if hasattr(model, 'module') else model.names,
            }
            # BUG FIX: log the current weights file, not the whole list.
            self.logger.info(f'加载模型{weights}完成')

        # BUG FIX: validate imgsz against the largest stride across ALL
        # models (originally only model 0 was consulted).
        stride = max(int(info['model'].stride.max()) for info in self.models_info.values())
        self.imgsz = check_img_size(self.imgsz, s=stride)
        self.logger.info(f'模型标签:{[i["names"] for i in self.models_info.values()]}')
        self.logger.info(f'检测已配置:{list(self.labels_waiting_detection.keys())}')
主要加载模型部分:
传入的参数weights为列表,其元素是pt文件地址,全部封装到models_info里面,方便取用
格式为(注意 'model' 存放的是已加载的模型对象,而不是 pt 文件路径):
models_info = {
    0: {'model': <已加载的模型对象>, 'names': ['person', ...]},
    1: {...},
    ...
}
# Load each weights file and register it under its index so the
# detection loop can later iterate over every model.
self.models_info = {}
for index, weights_path in enumerate(self.weights):
    loaded = attempt_load(weights_path, map_location=self.device)
    # DataParallel models keep the underlying module (and its class
    # names) in .module, so unwrap when present.
    class_names = loaded.module.names if hasattr(loaded, 'module') else loaded.names
    self.models_info[index] = {'model': loaded, 'names': class_names}
其次,就是调用模型了,infer需要两个参数,一个是预处理好的图片,一个是模型
预处理图片的方法如下,返回处理后的图片
def process_image(self, image):
    """Turn a raw BGR image into a model-ready tensor.

    Letterboxes to ``self.imgsz``, converts HWC/BGR to CHW/RGB,
    moves the data to ``self.device``, scales to [0, 1] (FP16 when
    ``self.half`` is set) and adds a batch dimension when missing.
    """
    # NOTE(review): stride is hard-coded to 32 here — confirm it matches
    # the stride of the loaded models used at init time.
    padded = letterbox(image, self.imgsz, stride=32, auto=True)[0]
    # HWC BGR -> CHW RGB
    chw_rgb = padded.transpose((2, 0, 1))[::-1]
    tensor = torch.from_numpy(np.ascontiguousarray(chw_rgb)).to(self.device)
    tensor = tensor.half() if self.half else tensor.float()  # uint8 -> fp16/32
    tensor = tensor / 255.0  # 0 - 255 to 0.0 - 1.0
    if len(tensor.shape) == 3:
        tensor = tensor[None]  # expand for batch dim
    return tensor
然后infer需要将图片和模型接入
def infer(self, image, model):
    """Run one model on a preprocessed image tensor and apply NMS.

    Returns ``(detections, image)`` where ``detections`` is the
    NMS-filtered prediction list and ``image`` is the input tensor,
    passed back unchanged for the caller's convenience.
    """
    raw = model(image, augment=self.augment, visualize=self.visualize)[0]
    detections = non_max_suppression(
        raw,
        self.conf_thres,
        self.iou_thres,
        self.classes,
        self.agnostic_nms,
        max_det=self.max_det,
    )
    return detections, image
所以在处理时,需要遍历模型文件,具体操作如下,对于标签处理方法就不赘述了
def detect(self, image): detect_result_list = [] img = self.process_image(image) # image是原图,img是处理后的 for model_index, model_info in self.models_info.items(): model = model_info['model'] names = model_info['names'] pred, img = self.infer(img, model) for i, det in enumerate(pred):。。。