DINO对比去噪训练代码分析
目标检测算法DINO(DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection)对DN-DETR的去噪训练进行扩展——对GT box进行采用较大的扰动以获得负样本(DN-DETR中对GT box添加了细微扰动只构造了正样本)。接下来结合DINO代码看一下对比去噪训练分支中正、负样本是如何生成的。
代码地址:DINO代码
用于生成对比去噪正负样本的函数在DINO类的forward方法中被调用,名为prepare_for_dn,代码在项目中的路径为:DINO/models/dino/dn_components.py
一、prepare_for_dn的调用
以下代码是prepare_for_dn()在DINO类的forward方法中的调用,每个参数的含义如下:
if self.dn_number > 0 and targets is not None:
input_query_label, input_query_bbox, attn_mask, dn_meta = \
prepare_for_cdn(
dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale),
training=self.training,
num_queries=self.num_queries,
num_classes=self.num_classes,
hidden_dim=self.hidden_dim,
label_enc=self.label_enc)
- dn_args:去噪训练的相关参数,包括标图像中的标注信息targets、去噪组数量self.dn_number(默认值100)、标签翻转概率self.dn_label_noise_ratio(默认值0.5)、box扰动尺度self.dn_label_noise_ratio(默认值0.4);
- training:是否是训练过程,只有训练过程才有去噪部分;
- num_queries:query的数量,默认值为900,用于生成attention mask,以阻止自注意力操作中去噪部分和匹配部分的信息泄露;
- hidden_dim:嵌入向量的维度,默认为256维;
- label_enc:nn.Embedding(dn_labelbook_size + 1, hidden_dim)初始化的实例,用于将label编码为向量;
二、prepare_for_dn代码解析
代码逻辑主要分为训练和推理两部分,由一个if-else语句构成,之前提到过training参数用于判断是训练过程还是推理过程,推理过程的代码很简单,如下所示:
device = dn_args[0][0]['boxes'].device
# 训练阶段,在此省略
if training:
pass
# 推理阶段
else:
input_query_label = None
input_query_bbox = None
attn_mask = None
dn_meta = None
return input_query_label, input_query_bbox, attn_mask, dn_meta
推理阶段不需要对比去噪,因此相关的返回值都是None。
接下来重点看一下训练部分。由于训练部分代码较长,将按顺序分成若干部分进行讲解。
第一部分
主要逻辑:
- 对dn_args参数解包,得到相关参数;
- 确定去噪组数量,不了解去噪组概念的同学可以阅读DN-DETR论文或者我的博文ND-DETR:通过引入Query去噪加速DETR训练;
- 将label和box扩展到每个去噪组中。
阅读代码时我发现其中定义的很多变量在后续并未使用,我怀疑可能是作者重构代码时没有完全清理干净,没有使用的变量我都在下面的代码中注释了,大家可以忽略相关变量。
# dn_number=100, label_noise_ratio=0.5, box_noise_scale=0.4
targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
# positive and negative dn queries
dn_number = dn_number * 2
# 每个tensor元素包含<目标框个数>个1
known = [(torch.ones_like(t['labels'])).to(device) for t in targets]
batch_size = len(known)
# 列表,每个元素表示每张图像中box的数量
known_num = [sum(k) for k in known]
if int(max(known_num)) == 0:
dn_number = 1
else:
if dn_number >= 100:
# 根据batch中最大label数动态设置,决定去噪组数
dn_number = dn_number // (int(max(known_num) * 2))
elif dn_number < 1:
dn_number = 1
if dn_number == 0:
dn_number = 1
# 后续未用到,忽略
unmask_bbox = unmask_label = torch.cat(known)
# gt labels(batch_gt_num,)整合所有图像的label
labels = torch.cat([t['labels'] for t in targets])
# gt boxes(batch_gt_num, 4)整合所有图像的box
boxes = torch.cat([t['boxes'] for t in targets])
# 每个label对应图像的索引
batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
# known_indice后续并没用到,忽略
known_indice = torch.nonzero(unmask_label + unmask_bbox)
known_indice = known_indice.view(-1)
known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
# 将label, batch id, bbox扩展到每个去噪组
# 维度为(200,)
known_labels = labels.repeat(2 * dn_number, 1).view(-1)
known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
# 维度(200, 4)
known_bboxs = boxes.repeat(2 * dn_number, 1)
# 克隆label和box数据,用于加噪
known_labels_expaned = known_labels.clone()
known_bbox_expand = known_bboxs.clone()
细节分析:
dn_number含义的变化:
- dn_number初始值为100,其含义为:假设一个batch的图像中共包含100个GT box;
- dn_number = dn_number × 2 表示每个GT box都会生成一个正样本和一个负样本,此时其含义为:一个batch中去噪过程中产生的query(正样本和负样本)的总数;
- dn_number = dn_number // (int(max(known_num) * 2)),known_num中存储了每张图像中目标的个数,int(max(known_num)表示一个batch中单张图像包含的最大目标数,dn_number此时的含义表示去噪的组数。
在DN-DETR中去噪组数是固定的,根据模型的规模可以设置为5或者10。由于DETR-like模型采用mini-batch训练,每张图像去噪query的数量需要进行padding以便与batch中最大的那个保持一致。
举个例子:在DN-DETR中,假设batch size为2,去噪组数为1,图像A中有1个目标,图像B中有5个目标,由于每个GT box都会产生一个去噪query,所以图像A会产生1个去噪query,图像B会产生5个去噪query。但是由于需要进行并行操作,数据维度要保持一致,所以图像A需要padding成5个去噪query。考虑到COCO数据集中单张图像包含的目标数量从1到80不等,因此DN-DETR去噪组数量固定的这种设计效率低且会导致额外的内存消耗。
为了解决上述问题,DINO提出固定去噪query的数量(例如100*2),根据图像中目标数量动态地调整去噪组的数量。还是上面的例子,DINO中去噪组的数量就是(100*2)/max(1,5) * 2 = 20。
tensor.repeat()用法
第二部分
主要逻辑:
- 对一部分label进行随机翻转;
- 确定正样本和负样本的索引。
# noise on the label, 0.5 by config
if label_noise_ratio > 0:
# 随机生成0~1
p = torch.rand_like(known_labels_expaned.float())
# 随机对0.5*label_noise_ratio的label进行翻转,确定需要进行label翻转的索引
chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1)
# 对需要翻转label的索引位置随机翻转
new_label = torch.randint_like(chosen_indice, 0, num_classes)
# 用随机选择的label替换原来的label
known_labels_expaned.scatter_(0, chosen_indice, new_label)
# 这批图像中包含最多目标的图像,其目标的数量
single_pad = int(max(known_num))
# 其实就是固定值200,用于后面构建attention mask
pad_size = int(single_pad * 2 * dn_number)
# 确定正样本索引
# (dn_number, batch_gt_num)
positive_idx = torch.tensor(range(len(boxes))).long().to(device).unsqueeze(0).repeat(dn_number, 1)
positive_idx += (torch.tensor(range(dn_number)) * len(boxes) * 2).long().to(device).unsqueeze(1)
positive_idx = positive_idx.flatten()
# 确定负样本索引
negative_idx = positive_idx + len(boxes)
变量含义说明:
- label_noise_ratio:标签翻转的比例(实际翻转的比例是0.5*label_noise_ratio)。DN-DETR中对于GT box添加的噪声分为两部分,一部分是对box的坐标添加偏移量,另一部分是对box的label添加噪声,对label添加噪声的方式就是随机将box的label翻转成其他类别。比如label_noise_ratio=1,那就是对大概200*1*0.5=100个query的label改变,如果label_noise_ratio=0.5,那就是对大概200*0.5*0.5=50个query的label改变。
第三部分
目的:对box加噪,即分别对正负样本的box坐标添加偏移量。
代码逻辑整体比较简单:
- 把box中坐标形式从中心点+高宽的形式转化成左上角+右下角坐标的形式;
- 确定两个坐标点分别在x方向和y方向的偏移量的大小和方向(1还是-1);
- 添加偏移量;
- 把坐标转化为中心点和宽高的形式。
# noise on the box
if box_noise_scale > 0:
# 转换box坐标形式:中心点+高宽->左上角+右下角
known_bbox_ = torch.zeros_like(known_bboxs)
known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2
# 计算边界框在各维度上的变化范围,为后续边界框添加噪声作为参考,确保添加的噪声不会超出合理范围
diff = torch.zeros_like(known_bboxs)
diff[:, :2] = known_bboxs[:, 2:] / 2
diff[:, 2:] = known_bboxs[:, 2:] / 2 # [w/2, h/2, w/2, h/2]
# 每个坐标点移动的方向,1或-1
rand_sign = torch.randint_like(known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
rand_part = torch.rand_like(known_bboxs) # 获取随机偏移量0~1
rand_part[negative_idx] += 1.0 # 负样本位置添加更大的噪声1~2
rand_part *= rand_sign # 赋予偏移量方向
known_bbox_ = known_bbox_ + torch.mul(rand_part, diff).to(device) * box_noise_scale
known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
# 左上角,右下角->中心点和宽高
known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]
其中偏移量的大小受以下因素的影响:
- 基础偏移量:x方向是0.5w,y方向是0.5h,w和h分别是box的宽和高;
- box_noise_scale:控制box噪声的尺度系数,配置文件中的默认值为0.4;
- 随机偏移量:在正样本位置产生的随机偏移量为0~1,负样本位置产生的随机偏移量为1~2;
所以基于上述信息,正样本左上角和右下角坐标在x方向的偏移量大小范围为(0,0.2w),在y方向上偏移量大小范围为(0,0.2h);负样本在x方向偏移量大小范围为(0.2w,0.4w),在y方向偏移量大小范围为(0.2h,0.4h)。
第四部分
构建去噪所需的query(一共200个),其中包括类别嵌入(label embedding)和位置嵌入(box embedding)。
m = known_labels_expaned.long().to(device)
# 对添加噪声的label进行编码(2 * dn_number * batch_gt_num, 256)
input_label_embed = label_enc(m)
input_bbox_embed = inverse_sigmoid(known_bbox_expand)
# (200, 256)
padding_label = torch.zeros(pad_size, hidden_dim).to(device)
padding_bbox = torch.zeros(pad_size, 4).to(device)
# (bs, 200,256)
input_query_label = padding_label.repeat(batch_size, 1, 1)
input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)
map_known_indice = torch.tensor([]).to(device)
if len(known_num): # 每个图像中的目标数量
map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])
map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
if len(known_bid):
# 在对应batch、对应query索引位置赋值
input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed
第五部分
构建attention mask,目的是用于在decoder中计算自注意力时防止信息泄露;构建dn_meta字典存储相关信息,这部分省略,感兴趣的同学可自行学习。
tgt_size = pad_size + num_queries
attn_mask = torch.ones(tgt_size, tgt_size).to(device) < 0
# match query cannot see the reconstruct
attn_mask[pad_size:, :pad_size] = True
# reconstruct cannot see each other
for i in range(dn_number):
if i == 0:
attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
if i == dn_number - 1:
attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
else:
attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True
dn_meta = {
'pad_size': pad_size,
'num_dn_group': dn_number,
}
总结
上述内容就是对DINO中如何构建对比去噪分支中的正负样本相关代码的解析,由于本人文笔有限,可能在具体变量含义的表述上不够清晰,大家结合博客内容自行调试代码,了解代码的含义和作用,加深自己的理解。文中如有错误或歧义之处欢迎在评论区指出,大家也可以在评论区分享自己的理解。
另外像说一下阅读DINO源码的一些个人感悟。通常我们会先阅读论文,了解模型或者某个算法的大致思路,然后再结合代码查看具体的实现过程。在阅读代码时一定要保持宏观的思路,即这段代码的整体思路和作用是什么,不然就进入“不知庐山真面目,只缘身在此山中”的困境。切忌过于深入的钻研某个具体的操作导致思路陷入局部的“思维困境”;其次就是尽量搞懂代码中每个变量的含义和作用,用清楚的语言进行注释,方便理解和复盘;最后就是调试代码,查看变量的数值和维度,各个维度的含义,加深理解。