当前位置: 首页 > article >正文

DINO对比去噪训练代码分析

目标检测算法DINO(DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection)对DN-DETR的去噪训练进行扩展——对GT box进行采用较大的扰动以获得负样本(DN-DETR中对GT box添加了细微扰动只构造了正样本)。接下来结合DINO代码看一下对比去噪训练分支中正、负样本是如何生成的。

代码地址:DINO代码


用于生成对比去噪正负样本的函数在DINO类的forward方法中被调用,名为prepare_for_dn,代码在项目中的路径为:DINO/models/dino/dn_components.py

 一、prepare_for_dn的调用

以下代码是prepare_for_dn()在DINO类的forward方法中的调用,每个参数的含义如下:

if self.dn_number > 0 and targets is not None:
    input_query_label, input_query_bbox, attn_mask, dn_meta = \
        prepare_for_cdn(
        dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale), 
        training=self.training, 
        num_queries=self.num_queries, 
        num_classes=self.num_classes, 
        hidden_dim=self.hidden_dim, 
        label_enc=self.label_enc)
  • dn_args:去噪训练的相关参数,包括标图像中的标注信息targets、去噪组数量self.dn_number(默认值100)、标签翻转概率self.dn_label_noise_ratio(默认值0.5)、box扰动尺度self.dn_label_noise_ratio(默认值0.4);
  • training:是否是训练过程,只有训练过程才有去噪部分;
  • num_queries:query的数量,默认值为900,用于生成attention mask,以阻止自注意力操作中去噪部分和匹配部分的信息泄露;
  • hidden_dim:嵌入向量的维度,默认为256维;
  • label_enc:nn.Embedding(dn_labelbook_size + 1, hidden_dim)初始化的实例,用于将label编码为向量;

二、prepare_for_dn代码解析

代码逻辑主要分为训练推理两部分,由一个if-else语句构成,之前提到过training参数用于判断是训练过程还是推理过程,推理过程的代码很简单,如下所示:

device = dn_args[0][0]['boxes'].device
    # 训练阶段,在此省略
    if training:
        pass
    
    # 推理阶段
    else:

        input_query_label = None
        input_query_bbox = None
        attn_mask = None
        dn_meta = None

    return input_query_label, input_query_bbox, attn_mask, dn_meta

 推理阶段不需要对比去噪,因此相关的返回值都是None。


接下来重点看一下训练部分。由于训练部分代码较长,将按顺序分成若干部分进行讲解。

第一部分

主要逻辑:

  • 对dn_args参数解包,得到相关参数;
  • 确定去噪组数量,不了解去噪组概念的同学可以阅读DN-DETR论文或者我的博文ND-DETR:通过引入Query去噪加速DETR训练;
  • 将label和box扩展到每个去噪组中。

阅读代码时我发现其中定义的很多变量在后续并未使用,我怀疑可能是作者重构代码时没有完全清理干净,没有使用的变量我都在下面的代码中注释了,大家可以忽略相关变量。

# dn_number=100, label_noise_ratio=0.5, box_noise_scale=0.4
targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
# positive and negative dn queries
dn_number = dn_number * 2
# 每个tensor元素包含<目标框个数>个1
known = [(torch.ones_like(t['labels'])).to(device) for t in targets]
batch_size = len(known)
# 列表,每个元素表示每张图像中box的数量
known_num = [sum(k) for k in known]
if int(max(known_num)) == 0:
    dn_number = 1
else:
    if dn_number >= 100:
        # 根据batch中最大label数动态设置,决定去噪组数
        dn_number = dn_number // (int(max(known_num) * 2))  
    elif dn_number < 1:
        dn_number = 1
if dn_number == 0:
    dn_number = 1

# 后续未用到,忽略
unmask_bbox = unmask_label = torch.cat(known)

# gt labels(batch_gt_num,)整合所有图像的label
labels = torch.cat([t['labels'] for t in targets])
# gt boxes(batch_gt_num, 4)整合所有图像的box
boxes = torch.cat([t['boxes'] for t in targets])
# 每个label对应图像的索引
batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

# known_indice后续并没用到,忽略
known_indice = torch.nonzero(unmask_label + unmask_bbox)
known_indice = known_indice.view(-1)
known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)

# 将label, batch id, bbox扩展到每个去噪组
# 维度为(200,)
known_labels = labels.repeat(2 * dn_number, 1).view(-1)
known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
# 维度(200, 4)
known_bboxs = boxes.repeat(2 * dn_number, 1)

# 克隆label和box数据,用于加噪
known_labels_expaned = known_labels.clone()
known_bbox_expand = known_bboxs.clone()

细节分析:

dn_number含义的变化:

  • dn_number初始值为100,其含义为:假设一个batch的图像中共包含100个GT box;
  • dn_number = dn_number × 2 表示每个GT box都会生成一个正样本和一个负样本,此时其含义为:一个batch中去噪过程中产生的query(正样本和负样本)的总数;
  • dn_number = dn_number // (int(max(known_num) * 2)),known_num中存储了每张图像中目标的个数,int(max(known_num)表示一个batch中单张图像包含的最大目标数,dn_number此时的含义表示去噪的组数。

在DN-DETR中去噪组数是固定的,根据模型的规模可以设置为5或者10。由于DETR-like模型采用mini-batch训练,每张图像去噪query的数量需要进行padding以便与batch中最大的那个保持一致。

举个例子:在DN-DETR中,假设batch size为2,去噪组数为1,图像A中有1个目标,图像B中有5个目标,由于每个GT box都会产生一个去噪query,所以图像A会产生1个去噪query,图像B会产生5个去噪query。但是由于需要进行并行操作,数据维度要保持一致,所以图像A需要padding成5个去噪query。考虑到COCO数据集中单张图像包含的目标数量从1到80不等,因此DN-DETR去噪组数量固定的这种设计效率低且会导致额外的内存消耗。

为了解决上述问题,DINO提出固定去噪query的数量(例如100*2),根据图像中目标数量动态地调整去噪组的数量。还是上面的例子,DINO中去噪组的数量就是(100*2)/max(1,5) * 2 = 20。

tensor.repeat()用法

第二部分

主要逻辑:

  • 对一部分label进行随机翻转;
  • 确定正样本和负样本的索引。
# noise on the label, 0.5 by config
if label_noise_ratio > 0:
    # 随机生成0~1
    p = torch.rand_like(known_labels_expaned.float())
    # 随机对0.5*label_noise_ratio的label进行翻转,确定需要进行label翻转的索引
    chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1)
    # 对需要翻转label的索引位置随机翻转
    new_label = torch.randint_like(chosen_indice, 0, num_classes)  
    # 用随机选择的label替换原来的label
    known_labels_expaned.scatter_(0, chosen_indice, new_label)
 
# 这批图像中包含最多目标的图像,其目标的数量 
single_pad = int(max(known_num))  

# 其实就是固定值200,用于后面构建attention mask
pad_size = int(single_pad * 2 * dn_number) 

# 确定正样本索引
# (dn_number, batch_gt_num)
positive_idx = torch.tensor(range(len(boxes))).long().to(device).unsqueeze(0).repeat(dn_number, 1)
positive_idx += (torch.tensor(range(dn_number)) * len(boxes) * 2).long().to(device).unsqueeze(1)
positive_idx = positive_idx.flatten()
# 确定负样本索引
negative_idx = positive_idx + len(boxes)

变量含义说明:

  • label_noise_ratio:标签翻转的比例(实际翻转的比例是0.5*label_noise_ratio)。DN-DETR中对于GT box添加的噪声分为两部分,一部分是对box的坐标添加偏移量,另一部分是对box的label添加噪声,对label添加噪声的方式就是随机将box的label翻转成其他类别。比如label_noise_ratio=1,那就是对大概200*1*0.5=100个query的label改变,如果label_noise_ratio=0.5,那就是对大概200*0.5*0.5=50个query的label改变。

第三部分

目的:对box加噪,即分别对正负样本的box坐标添加偏移量。

代码逻辑整体比较简单:

  • 把box中坐标形式从中心点+高宽的形式转化成左上角+右下角坐标的形式;
  • 确定两个坐标点分别在x方向和y方向的偏移量的大小和方向(1还是-1);
  • 添加偏移量;
  • 把坐标转化为中心点和宽高的形式。
# noise on the box
if box_noise_scale > 0:
    # 转换box坐标形式:中心点+高宽->左上角+右下角
    known_bbox_ = torch.zeros_like(known_bboxs)
    known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
    known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2
    # 计算边界框在各维度上的变化范围,为后续边界框添加噪声作为参考,确保添加的噪声不会超出合理范围
    diff = torch.zeros_like(known_bboxs)
    diff[:, :2] = known_bboxs[:, 2:] / 2
    diff[:, 2:] = known_bboxs[:, 2:] / 2  # [w/2, h/2, w/2, h/2]
    # 每个坐标点移动的方向,1或-1
    rand_sign = torch.randint_like(known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
    rand_part = torch.rand_like(known_bboxs)  # 获取随机偏移量0~1
    rand_part[negative_idx] += 1.0  # 负样本位置添加更大的噪声1~2
    rand_part *= rand_sign  # 赋予偏移量方向
    known_bbox_ = known_bbox_ + torch.mul(rand_part, diff).to(device) * box_noise_scale
    known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
    # 左上角,右下角->中心点和宽高
    known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
    known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]

其中偏移量的大小受以下因素的影响:

  • 基础偏移量:x方向是0.5w,y方向是0.5h,w和h分别是box的宽和高;
  • box_noise_scale:控制box噪声的尺度系数,配置文件中的默认值为0.4;
  • 随机偏移量:在正样本位置产生的随机偏移量为0~1,负样本位置产生的随机偏移量为1~2;

所以基于上述信息,正样本左上角和右下角坐标在x方向的偏移量大小范围为(0,0.2w),在y方向上偏移量大小范围为(0,0.2h);负样本在x方向偏移量大小范围为(0.2w,0.4w),在y方向偏移量大小范围为(0.2h,0.4h)。

第四部分

构建去噪所需的query(一共200个),其中包括类别嵌入(label embedding)和位置嵌入(box embedding)。

m = known_labels_expaned.long().to(device)
# 对添加噪声的label进行编码(2 * dn_number * batch_gt_num, 256)
input_label_embed = label_enc(m)
input_bbox_embed = inverse_sigmoid(known_bbox_expand)
# (200, 256)
padding_label = torch.zeros(pad_size, hidden_dim).to(device)
padding_bbox = torch.zeros(pad_size, 4).to(device)
# (bs, 200,256)
input_query_label = padding_label.repeat(batch_size, 1, 1)
input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

map_known_indice = torch.tensor([]).to(device)
if len(known_num):  # 每个图像中的目标数量
    map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])
    map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
if len(known_bid):
    # 在对应batch、对应query索引位置赋值
    input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
    input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed

第五部分

构建attention mask,目的是用于在decoder中计算自注意力时防止信息泄露;构建dn_meta字典存储相关信息,这部分省略,感兴趣的同学可自行学习。

tgt_size = pad_size + num_queries
attn_mask = torch.ones(tgt_size, tgt_size).to(device) < 0
# match query cannot see the reconstruct
attn_mask[pad_size:, :pad_size] = True
# reconstruct cannot see each other
for i in range(dn_number):
    if i == 0:
        attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
    if i == dn_number - 1:
        attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
    else:
        attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
        attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True

dn_meta = {
    'pad_size': pad_size,
    'num_dn_group': dn_number,
}

总结

上述内容就是对DINO中如何构建对比去噪分支中的正负样本相关代码的解析,由于本人文笔有限,可能在具体变量含义的表述上不够清晰,大家结合博客内容自行调试代码,了解代码的含义和作用,加深自己的理解。文中如有错误或歧义之处欢迎在评论区指出,大家也可以在评论区分享自己的理解。

另外像说一下阅读DINO源码的一些个人感悟。通常我们会先阅读论文,了解模型或者某个算法的大致思路,然后再结合代码查看具体的实现过程。在阅读代码时一定要保持宏观的思路,即这段代码的整体思路和作用是什么,不然就进入“不知庐山真面目,只缘身在此山中”的困境。切忌过于深入的钻研某个具体的操作导致思路陷入局部的“思维困境”;其次就是尽量搞懂代码中每个变量的含义和作用,用清楚的语言进行注释,方便理解和复盘;最后就是调试代码,查看变量的数值和维度,各个维度的含义,加深理解。


http://www.kler.cn/a/449168.html

相关文章:

  • ARM异常处理 M33
  • Docker Compose 安装 Harbor
  • Java重要面试名词整理(一):性能调优
  • 微服务——技术选型与框架
  • linux springboot项目启动端口被占用 Port 8901 was already in use.
  • ASP.NET|日常开发中数据集合详解
  • 范德蒙矩阵(Vandermonde 矩阵)简介:意义、用途及编程应用
  • 图学习新突破:一个统一框架连接空域和频域
  • 《开启微服务之旅:Spring Boot 从入门到实践》(一)
  • 短视频矩阵源码开发部署全解析
  • CentOS修改hostname,导致无法连接(网络不工作)
  • 动手学深度学习-深度学习计算-1层和块
  • 如何实现圆形头像功能
  • 【IC】TSMC先进工艺发展历程--从N5到A16
  • 统信UOS(1070)系统如何进入root用户模式下操作
  • Java 实现日志文件大小限制及管理——以 Python Logging 为启示
  • redis编译安装(版本6.2.6)
  • 练14:DFS基础
  • [python SQLAlchemy数据库操作入门]-03.为行情设计数据库模型
  • 华为云语音交互SIS的使用案例(文字转语音-详细教程)
  • 【多线程进阶】重要!!!
  • 音视频学习(二十四):hls协议
  • 如何理解TCP/IP协议?如何理解TCP/IP协议是什么?
  • Unable to create data directory /var/lib/zookeeper/log/version-2
  • java 对mongodb操作封装工具类
  • Tomcat负载均衡全解析