当前位置: 首页 > article >正文

一文详解YOLOv8多模态目标检测(可见光+红外图像,基于Ultralytics官方代码实现),轻松入门多模态检测领域!

目录

  • 1. 文章主要内容
  • 2. 相关说明
  • 3. 基于YOLOv8的多模态目标检测
    • 3.1 启动运行YOLOv8多模态代码
    • 3.2 详解代码流程(重点)
      • 3.2.1 train.py文件(入口)
      • 3.2.2 engine\model.py文件
      • 3.2.3 engine\trainer.py文件
      • 3.2.4 models\yolo\detect\train.py文件
      • 3.2.5 nn\tasks.py文件
      • 3.2.6 再次回到engine\trainer.py文件
    • 3.3 总结

1. 文章主要内容

       本文主要是详解YOLOv8实现多模态,包括如何启动,以及详细代码部分如何改进,从而让单模态的检测支持为多模态的检测!。基于YOLO的单模态检测赛道已经非常卷,很难出好的论文,这个时候入门多模态检测是非常有必要的!所以,本篇代码分析论文则是入门基于YOLOv8的多模态目标检测的基础之一。

2. 相关说明

       本篇博文代码来源于原博客:YOLOV8多模态(可见光+红外光目标检测任务,基于Ultralytics官方代码实现)。
       原博客中使用的是基于Ultralytics的YOLOv8模型,使用的数据集是DroneVehicle,这里的DroneVehicle数据集是红外-可见光两种模态的无人机目标检测数据集。需要注意的是原数据集中的DroneVehicle数据图像有白边,需要对原数据集进行去白边进行处理。另外,我自己在这里把目标框从旋转框改为水平框,也就是仅仅进行水平框的检测,后续再考虑出一篇旋转框的检测内容。
       注意到:以下代码的相关内容分析需要前提知识:已经了解单模态的YOLOv8代码相关的知识点。

3. 基于YOLOv8的多模态目标检测

       这一块分为两个部分,第一块是启动运行部分,第二块是多模态代码的分析(中期融合,也叫做特征融合)第二块是重点,因为后续想改进代码必须搞懂如何进行模型改进和前向传播的改进等。

3.1 启动运行YOLOv8多模态代码

       从上面原博客中找到代码的地址,或者直接点击这个链接TwoStream_Yolov8源代码,进去看相关README部分,配置好相关环境,然后准备好相关的数据集DroneVehicle以及对应的格式摆放。另外在安装环境的时候一定要运行这个代码pip install -e .,这个代码的作用是将这个项目TwoStream_Yolov8的本地的Ultralytics文件夹进行编译,而不是用环境nn下对应的包。不然就会出现No module Ultralytics相关错误。另外上述编译代码是在项目的根目录进行运行的,大家别搞错。

3.2 详解代码流程(重点)

       这一块的内容主要是从train函数部分,一步步去分析如何构造多模态目标检测的(这里是中期融合,后续的前期融合我准备再出一篇分析),这里先给出一张函数的流程图,下面的内容就是根据这张图来说明,注意我不会讲所有的代码,只会讲牵扯多模态相关需要修改的代码部分。
在这里插入图片描述

3.2.1 train.py文件(入口)

       说明:作者也提供了一份train.py文件,我是用自己的train.py,大差不差。源代码加载yaml文件使用的是绝对路径,我这里通过sys.path.append(“/home/project/TwoStream_Yolov8-main/”),将根目录设定为项目的根目录路径,所以下面加载yaml文件使用相对路径即可!然后,device部分我这里是多卡训练(使用Linux环境,建议使用Linux服务器),如果你是单卡的话改成0即可!
       代码分析:首先加载YOLO模型,这个YOLO模型只是ultralytics/model/yolo/model.py文件中的一个类,这个类继承了Model基类,也就是engine/model.py(这个类是重点)。所以你可以理解第一行代码model只是将yaml文件加载到了基类engine/model.py中的cfg变量当中
       后面第二行代码,调用model.train函数,其中带了data属性,也就是数据集的yaml文件!如下面第二幅图所示:train、train_ir分别为可见光和红外图像的训练集路径。
       注意这里和源代码不同,我这里改了相关代码,主要在路径中的imgRGB和IR部分,如果你也想改为自定义的路径,需要修改ultralytics/data/base.py中的load_image代码,如第三幅图所示。其主要的作用是根据可见光的路径获取红外文件的路径,然后再加载数据。

import warnings
import sys
sys.path.append("/home/project/TwoStream_Yolov8-main/")
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':

    # 加载模型
    model = YOLO('yaml/ADDyolov8n.yaml') # .load('yolov8n.pt')  # 从YAML构建并转移权重
    # 训练模型
    results = model.train(data='data/drone2.yaml', epochs=200, batch=32, device=[0,1])

在这里插入图片描述
在这里插入图片描述

3.2.2 engine\model.py文件

       进入到train这个类中,就跳转到了engine\model.py,然后我们找到在这里self.trainer.train()代码,继续进入,这里注意到我们进入的是下面这个类。
在这里插入图片描述

3.2.3 engine\trainer.py文件

       进入到engine\trainer.py类,找到这一行代码:self._do_train(world_size),进入这个方法,需要注意到world_size是判断训练是几张卡,如果有两张,那么假设你batch_size设置为16,每张卡就是batch_size为8.
       然后再找到这一行代码:self._setup_train(world_size),再次进入找到 ckpt = self.setup_model()方法,然后再次进入可以看到self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) 代码,这里的cfg就是我们的yaml/ADDyolov8n.yaml

3.2.4 models\yolo\detect\train.py文件

       然后再次进入跳转到models\yolo\detect\train.py文件,可以看到model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)这一行代码,说明我们的model使用的是DetectionModel类。

3.2.5 nn\tasks.py文件

       通过DetectionModel进入到这个 nn\tasks.py文件,然后找到这行代码self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose),再进入到parse_model这个函数中,这个parse_model就是通过yaml文件去构造model的结构。
       第一:我们可以看到这两行代码:tx的list表里存放的是输入的通道数,接近着会讲,ty为索引,初始值为0.

  tx=[3,256,256,512,512,max_channels,max_channels] # TODO....
  ty=0

       第二:然后可以看到如下几行的代码:结合上面tx和ty的定义可以知道,这是对输入通道数c1的改变。当yaml文中from也就是f值为-4的时候,输入的通道数要从tx当中去取出,并且当ty不等于0的时候需要乘于width因子。

c1, c2 = ch[f], args[0]
            if f==-4:
                c1=tx[ty]
                if ty!=0:
                    c1=c1*width
            
                c1=int(c1)
                ty+=1

       第三:来结合ADDyolov8n.yaml文件中的代码结合来看,如下所示。可以看出,第一次出现f=-4的时候,是在IR也就是红外分支Conv的时候,如果这个时候我们不用特殊分支进行判断,按照YOLOv8的原逻辑就会从上一层也就是f=-1处理,此时的c1=256,也就是RGB第四层Conv的输出,明显不对。因为IR分支的第一层输入应该也是3,所以我们就搞懂了上述逻辑的代码,这里是一个重点!,这样我们就构造了通过yaml新建model的逻辑。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 5 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  # RGB
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]] #2
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8

  # IR
  - [-4, 1, Conv, [64, 3, 2]] # 4 3
  - [-1, 1, Conv, [128, 3, 2]] # 5
  - [-1, 3, C2f, [128, True]] # 6
  - [-1, 1, Conv, [256, 3, 2]] # 7
 
  # Fusion1 使用时记得修改文件block文件将RIFusion置为空
  - [-3,1,RIFusion,[64]] #8

  # RGB
  - [-4, 6, C2f, [256, True]] #9 256
  - [-1, 1, Conv, [512, 3, 2]] # 10
 
  # IR
  - [-4, 6, C2f, [256, True]] # 11
  - [-1, 1, Conv, [512, 3, 2]] # 12
 
  # Fusion2
  - [-3,1,RIFusion,[128]]  #13

  # RGB
  - [-4, 6, C2f, [512, True]] #14
  - [-1, 1, Conv, [1024, 3, 2]] # 15

  # IR
  - [-4, 6, C2f, [512, True]] #16
  - [-1, 1, Conv, [1024, 3, 2]] # 17

  # Fusion3
  - [-3,1,RIFusion,[256]] #18

  # RGB
  - [-4, 3, C2f, [1024, True]] #19
  - [-1, 1, SPPF, [1024, 5]] # 20

  # IR
  - [-4, 3, C2f, [1024, True]] #21
  - [-1, 1, SPPF, [1024, 5]] # 22

  
  - [[9,11], 1, ADD, [1]] # 23
  - [[14,16], 1, ADD, [1]] # 24
  - [[20,22],1, ADD, [1]] # 25

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]] #26
  - [[-1, 24], 1, Concat, [1]] # 27
  - [-1, 3, C2f, [512]] # 28

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]] #29
  - [[-1, 23], 1, Concat, [1]] # 30
  - [-1, 3, C2f, [256]] # 31

  - [-1, 1, Conv, [256, 3, 2]] #32
  - [[-1, 28], 1, Concat, [1]] # 33
  - [-1, 3, C2f, [512]] # 34

  - [-1, 1, Conv, [512, 3, 2]] # 35
  - [[-1, 25], 1, Concat, [1]] # 36
  - [-1, 3, C2f, [1024]] # 37

  - [[31, 34, 37], 1, Detect, [nc]] # 38

3.2.6 再次回到engine\trainer.py文件

       新建完模型之后,回到trainer.py的代码中,我们看到这行代码:这是读取数据集loader的代码,我们先去看看self.trainset, self.trainirset怎么获取的,可以看到在同一个文件中有这样一行代码:self.trainset, self.testset,self.trainirset,self.testirset = self.get_dataset(),再去看看get_dataset()这个代码,可以看到这两行代码:其中的data就是我们最开始train.py中传入的data='data/drone2.yaml',然后获取其中的相关数据集路径,并返回。

self.data = data
        return data["train"], data.get("val"),data["train_ir"],data.get("val_ir") or data.get("test")
   # 读取数据集
   self.train_loader = self.get_dataloader(self.trainset, self.trainirset,batch_size=batch_size, rank=RANK, mode="train")

       再回到self.get_dataloader代码部分,我们进去这个方法,注意是进去第一个get_dataloader。然后可以看到这行代码:dataset = self.build_dataset(dataset_path, datasetir_path,mode, batch_size),再进去这个函数可以看到这行代码, return build_yolo_dataset(self.args, img_path, imgir_path,batch, self.data, mode=mode, rect=mode == "val", stride=gs)。然后再进去build_yolo_dataset这个函数,可以看到我们使用的是YOLODataset这个函数,它需要支持 img_path=img_path和 imgir_path=imgir_path两个路径的输入。然后YOLODataset又是继承了BaseDataset,我们可以看到这行self.cache_images()这行代码,继续进入到这个函数。可以看到一行代码:fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM")。**我们再次进入到load_image这个函数中,这个函数通过imir=cv2.imread(f)来从路径中读取数据集。**另外要注意到load_image中的这行代码:im = np.dstack((im, imir)) ,说明输入的数据是以六通道的数据存在的,后续要进行分开处理!当然这里面还有数据增强部分,是牵扯self.transforms = self.build_transforms(hyp=hyp)这一块的代码,大家可以自己看看!
在这里插入图片描述
       好,我们回到get_dataloader的build_dataloader方法,刚刚我们得到了dataset,然后传到build_dataloader中即可得到加载器。

    def get_dataloader(self, dataset_path, datasetir_path,batch_size=16, rank=0, mode="train"):
        """Construct and return dataloader."""
        assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, datasetir_path,mode, batch_size)
        shuffle = mode == "train"
        if getattr(dataset, "rect", False) and shuffle:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
            shuffle = False
        workers = self.args.workers if mode == "train" else self.args.workers * 2
        return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader

       在trainer.py文件中可以看到 self.loss, self.loss_items = self.model(batch),意思就是将数据以batch批次传进去,我们这里还有一个问题,model的结构构造好了,数据怎么传进去的?我们现在的数据从load_image中是以六通道的形式存在,所以我们得去看BaseModel类的前向传播算法。这个前向传播算法大家可以理解为数据进入到BaseModel中必须执行的一个函数,也就是在task.py中,可以看到def forward中 return self.predict(x, *args, **kwargs)这行代码,然后进去可以看到这行代码: return self._predict_once(x, profile, visualize, embed)。这个_predict_once函数才是处理六通道输入的函数,一步一步来:

       第一:首先将六通道的输入切分开,分别获得rgb和ir的输入,随后将rgb先赋值给x,这是因为在yaml文件中RGB的网络结构在IR的前面。

 y, dt, embeddings = [], [], []  # outputs
        rgb,ir=torch.chunk(x,chunks=2,dim=1) # 红外
        # rgb=x[:, :3, :, :] # 可见光
        x=rgb

       第二:当f==-4的时候,也就是要切换输入了,第一次这个时候将rgb切换为ir。

 if m.f==-4:
                # 跳转另外一个分支
                if isR:
                    x= m(ir)
                    ir=x
                    isR=False
                else :
                    x = m(rgb)  # run
                    rgb=x
                    isR=True

       还有一段代码,这段代码就是正常的执行相关yaml文件,也就是在backbone的ADD融合之前:

            elif m.i<23:
                if isR:
                    x= m(rgb)
                    rgb=x
                else :
                    x = m(ir)  # run
                    ir=x

3.3 总结

       大概就是这些代码,可能有一些细节没讲解,这是属于中期融合,也就是特征级的融合,也是最常见的融合,希望大家能有收获,如果有任何疑问,可以评论区交流!如果可以的话,希望大家多多点赞,收藏,后续会更新相关代码和论文的解读!


http://www.kler.cn/a/500575.html

相关文章:

  • 根据浏览器的不同类型动态加载不同的 CSS 文件
  • 金融项目实战 03|JMeter脚本实现手工接口测试
  • 【微信小程序】回到顶部图标-页面滚动事件 | 漫画-综合实训
  • 深度学习中的EMA技术:原理、实现与实验分析
  • 计算机网络之---传输层的功能
  • conntrack iptables 安全组
  • Oracle 使用dbms_stats.gather_table_stats来进行表analyse,收集表统计信息
  • 《零基础Go语言算法实战》【题目 2-7】defer 关键字特性
  • spring boot 支持jsonp请求
  • 阿里云发现后门webshell,怎么处理,怎么解决?
  • React - router的使用 结合react-redux的路由守卫
  • 依赖网络系统混合级联故障下系统可靠性提高与弹性的组合优化
  • 网络安全 | Web安全常见漏洞和防护经验策略
  • 苍穹外卖及软件开发介绍
  • 基于 B2C 的网上拍卖系统:秒杀与竞价功能的实现
  • ip归属地和手机号是一个地址吗
  • 【微服务】面试 2、负载均衡
  • matlab专栏-模拟滤波器设计
  • Spring——几个常用注解
  • linux服务器安装mysql数据库和nginx
  • 多线程面试相关