【TPV】TPVFormer代码解析
1、前言
前面的博客学习了TPVFormer的基本原理,以及在本地将TPVFormer运行起来,而本篇将进一步学习TPVFormer的代码,通过代码,对于原理的理解会更加深刻,同时TPVFormer中使用了Deformer DETR中的可变形注意力机制,需要提前了解。
TPVFormer原理:https://blog.csdn.net/weixin_42108183/article/details/129629303?spm=1001.2014.3001.5501
Deformer DETR学习:https://blog.csdn.net/weixin_42108183/article/details/128680716?spm=1001.2014.3001.5502
2、原理简介
根据对代码以及论文的学习,论文的整体组件如下图所示:
环视图片通过特征提取获得多尺度特征Features Maps,然后与TPV特征与Features Maps进行cross-attm,然后hw方向的TPV特征单独进行self-attn,作者应该是认为hw方向是检测过程中最重要的特征,即BEV特征。得到TPV特征后,即可得到空间下每个体素的TPV特征,然后就可以使用预测头对每个体素进行分类了。
3、代码解析
3.1 eval.py
# predict_labels_vox:[1, 18, 100, 100, 8] 为空间中 100*100*8个体素预测类别,8为高度上的体素个数
# predict_labels_pts:[1, 18, 34688, 1, 1] 为当前帧的所有点云预测类别,在实际推理时,不需要预测点云的类别,仅在验证时使用
predict_labels_vox, predict_labels_pts = my_model(img=imgs, img_metas=img_metas, points=val_grid_float) # -> tpvformer04/tpvformer.py
# 计算loss
loss = lovasz_softmax(...)
predict_labels_pts = torch.argmax(predict_labels_pts, dim=1) # [1, 34752] # 预测点云类别
predict_labels_vox = torch.argmax(predict_labels_vox, dim=1) #[1, 100, 100, 8] # 预测体素类别
... # 计算iou
3.2 tpvformer04/tpvformer.py
def forward(...):
# 多尺度特征
img_feats = self.extract_img_feat(img=img, use_grid_mask=use_grid_mask)
# 三个方向的BEV特征:俯视、侧视、前视
outs = self.tpv_head(img_feats, img_metas) # [1, 10000, 256]、 [1, 800, 256]、[1, 800, 256] # -> tpvformer04/tpv_head.py
# 体素预测类别、点云预测类别
outs = self.tpv_aggregator(outs, points) # [1, 18, 100, 100, 8]、[1, 18 ,34752, 1, 1]
3.2 tpvformer04/tpv_head.py
class TPVFormerHead(...):
def __init__(...):
...
def forward(self, mlvl_feats, img_metas):
tpv_queries_hw = ... # hw视角的特征图 [10000, 256]
tpv_queries_zh = ... # zh视角的特征图 [800, 256]
tpv_queries_wz = ... # wz视角的特征图 [800, 256]
tpv_pos_hw = ... # hw视角位置编码
for lvl, feat in enumerate(mlvl_feats):
# [6, 256] 相机参数编码
feat = feat + self.cams_embeds[:, None, None, :].to(dtype)
# 特征图尺度 编码
feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(dtype)
# encoder中实现了两种注意力机制,且均有deformable 实现
# 1、hw 方向上的TPV特征进行 self-attn
# 2、三个方向的TPV特征与图像特征进行cross-attn
tpv_embed = self.encoder(...) # -> tpvformer04/modules/encoder.py
return tpv_embed # 返回三个方向的TPV特征
3.3 tpvformer04/modules/encoder.py
class TPVFormerEncoder(...):
def __init__(...):
...
def forward(...):
"""
"""
output = tpv_query # [1, 10000, 256]、[1, 800, 256]、[1, 800, 256] #三个方向的BEV特征
ref_3ds = [self.ref_3d_hw, self.ref_3d_zh, self.ref_3d_wz] # ???? 参考点 [1, 4, 100*100, 3]、 [1, 32, 100*8, 3] [1, 4, 100*100, 3]、 [1, 32, 100*8, 3]
for ref_3d in ref_3ds:
reference_points_cam, tpv_mask = self.point_sampling(ref_3d, self.pc_range, kwargs['img_metas']) # ???
# [6, 1, 10000, 4, 2] 、[6, 1, 800, 32, 2] 、[6, 1, 800, 32, 2]
ref_2d_hw = ... # [1, 10000, 1, 2]
hybird_ref_2d = ... # [2, 10000, 1, 2]
# self.layers:[[TPVCrossViewHybridAttention,TPVImageCrossAttention],[TPVCrossViewHybridAttention,TPVImageCrossAttention],[TPVCrossViewHybridAttention,TPVImageCrossAttention]]
# TPVCrossViewHybridAttention -> hw方向山上的特征进行 self-attn
# TPVImageCrossAttention -> 三个方向的TPV特征与图像特征进行 cross-attn
for lid, layer in enumerate(self.layers):
output = layer(...) # -> tpvformer04/modules/tpvformer_layer.py
return output # 返回三个方向的TPV特征
3.4 tpvformer04/modules/tpvformer_layer.py
class TPVFormerLayer(...):
def __init__(...):
...
def forward(...):
# self.operation_order -> ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')
# self.attentions -> [TPVCrossViewHybridAttention, TPVImageCrossAttention]
for layer in self.operation_order:
# 首先进行 'self_attn'
if layer == 'self_attn':
# 只对 query[0] 也就是 hw方向的 BEV特征进行 self-attn
query_0 = self.attentions[attn_index](query[0],...)
# 叠加三个方向的 TPV特征
query = torch.cat([query_0, query[1], query[2]], dim=1)
# 进行 'cross_attn',即 TPV特征与 图片特征进行 cross_attn
elif layer == 'cross_attn':
query = self.attentions[attn_index](query,...) # -> tpvformer04/modules/image_cross_attention.py
...
query = torch.split(query, [tpv_h*tpv_w, tpv_z*tpv_h, tpv_w*tpv_z], dim=1) # [1,11600,256]
return query
3.5 tpvformer04/modules/cross_view_hybrid_attention.py
class TPVCrossViewHybridAttention(...):
def __init__(...):
...
def forward(...):
query = torch.cat([value[:bs], query], -1)
value = ... # [2, 10000, 8, 32] # 8头注意力机制
sampling_offsets = ... # 预测偏移量 [1, 10000, 8, 2, 1, 4, 2]
if reference_points.shape[-1] == 2:
offset_normalizer = ...
sampling_locations = ... # 参考点与偏移点相加并归一化
if torch.cuda.is_available() and value.is_cuda::
if value.dtype == torch.float16::
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
else:
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
output = MultiScaleDeformableAttnFunction.apply(...)
...
return self.dropout(output) + identity # 残差链接 -> tpvformer04/modules/tpvformer_layer.py
3.6 tpvformer04/modules/image_cross_attention.py
class TPVImageCrossAttention(...):
def __init__(...):
...
def forward(...):
...
queries = self.deformable_attention(...)
value = self.value_proj(value)
sampling_offsets, attention_weights = self.get_sampling_offsets_and_attention(query)
reference_points = self.reshape_reference_points(reference_points)
if reference_points.shape[-1] == 2:
... # 计算偏移点
else:
...
# 实现TPV特征与图片特征进行 cross_attn 操作
if torch.cuda.is_available() and value.is_cuda:
if value.dtype == torch.float16:
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
else:
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
return output
4、总结
本次学习了TPVFormer的代码流程,深入学习了其中的详细步骤,一步一步的将TPVFormer的过程进行了详细的分析,从最终的效果可以看出,TPVFormer的效果十分惊艳!