当前位置: 首页 > article >正文

【TPV】TPVFormer代码解析

1、前言

前面的博客学习了TPVFormer的基本原理,以及在本地将TPVFormer运行起来,而本篇将进一步学习TPVFormer的代码,通过代码,对于原理的理解会更加深刻,同时TPVFormer中使用了Deformer DETR中的可变形注意力机制,需要提前了解。

TPVFormer原理:https://blog.csdn.net/weixin_42108183/article/details/129629303?spm=1001.2014.3001.5501

Deformer DETR学习:https://blog.csdn.net/weixin_42108183/article/details/128680716?spm=1001.2014.3001.5502

2、原理简介

根据对代码以及论文的学习,论文的整体组件如下图所示:
在这里插入图片描述

环视图片通过特征提取获得多尺度特征Features Maps,然后与TPV特征与Features Maps进行cross-attm,然后hw方向的TPV特征单独进行self-attn,作者应该是认为hw方向是检测过程中最重要的特征,即BEV特征。得到TPV特征后,即可得到空间下每个体素的TPV特征,然后就可以使用预测头对每个体素进行分类了。

3、代码解析

3.1 eval.py
# predict_labels_vox:[1, 18, 100, 100, 8] 为空间中 100*100*8个体素预测类别,8为高度上的体素个数
# predict_labels_pts:[1, 18, 34688, 1, 1] 为当前帧的所有点云预测类别,在实际推理时,不需要预测点云的类别,仅在验证时使用
predict_labels_vox, predict_labels_pts = my_model(img=imgs, img_metas=img_metas, points=val_grid_float) # ->  tpvformer04/tpvformer.py
# 计算loss
loss = lovasz_softmax(...)
predict_labels_pts = torch.argmax(predict_labels_pts, dim=1) # [1, 34752] # 预测点云类别
predict_labels_vox = torch.argmax(predict_labels_vox, dim=1) #[1, 100, 100, 8] # 预测体素类别

... # 计算iou

3.2 tpvformer04/tpvformer.py

def forward(...):
    # 多尺度特征
    img_feats = self.extract_img_feat(img=img, use_grid_mask=use_grid_mask)  
    # 三个方向的BEV特征:俯视、侧视、前视
    outs = self.tpv_head(img_feats, img_metas)  #  [1, 10000, 256]、 [1, 800, 256]、[1, 800, 256]  # -> tpvformer04/tpv_head.py
    # 体素预测类别、点云预测类别
    outs = self.tpv_aggregator(outs, points) # [1, 18, 100, 100, 8]、[1, 18 ,34752, 1, 1] 
3.2 tpvformer04/tpv_head.py
class TPVFormerHead(...):
    def __init__(...):
        ...
        
    def forward(self, mlvl_feats, img_metas):
        tpv_queries_hw = ...  # hw视角的特征图   [10000, 256]
        tpv_queries_zh = ...  # zh视角的特征图   [800, 256]
        tpv_queries_wz = ...  # wz视角的特征图   [800, 256]
        tpv_pos_hw = ... # hw视角位置编码
        for lvl, feat in enumerate(mlvl_feats):
            # [6, 256] 相机参数编码
            feat = feat + self.cams_embeds[:, None, None, :].to(dtype)
            # 特征图尺度 编码 
            feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(dtype)
        
        # encoder中实现了两种注意力机制,且均有deformable 实现
        # 1、hw 方向上的TPV特征进行 self-attn
        # 2、三个方向的TPV特征与图像特征进行cross-attn
        tpv_embed = self.encoder(...)  # -> tpvformer04/modules/encoder.py
        return tpv_embed # 返回三个方向的TPV特征

3.3 tpvformer04/modules/encoder.py

class TPVFormerEncoder(...):
    def __init__(...):
        ...
    
    def forward(...):
        """
        """
        output = tpv_query  # [1, 10000, 256]、[1, 800, 256]、[1, 800, 256]  #三个方向的BEV特征
        ref_3ds = [self.ref_3d_hw, self.ref_3d_zh, self.ref_3d_wz]  # ???? 参考点 [1, 4, 100*100, 3]、 [1, 32, 100*8, 3] [1, 4, 100*100, 3]、 [1, 32, 100*8, 3]
        
        for ref_3d in ref_3ds:
            reference_points_cam, tpv_mask = self.point_sampling(ref_3d, self.pc_range, kwargs['img_metas']) # ???
        
         # [6, 1, 10000, 4, 2] 、[6, 1, 800, 32, 2] 、[6, 1, 800, 32, 2] 
         ref_2d_hw = ... # [1, 10000, 1, 2]
         hybird_ref_2d = ... # [2, 10000, 1, 2]
         
         # self.layers:[[TPVCrossViewHybridAttention,TPVImageCrossAttention],[TPVCrossViewHybridAttention,TPVImageCrossAttention],[TPVCrossViewHybridAttention,TPVImageCrossAttention]]
         # TPVCrossViewHybridAttention -> hw方向山上的特征进行 self-attn
         # TPVImageCrossAttention -> 三个方向的TPV特征与图像特征进行 cross-attn
         for lid, layer in enumerate(self.layers):
            output = layer(...)  #  ->  tpvformer04/modules/tpvformer_layer.py
            
        return output # 返回三个方向的TPV特征
3.4 tpvformer04/modules/tpvformer_layer.py
class TPVFormerLayer(...):
    def __init__(...):
        ...
        
    def forward(...):
        # self.operation_order -> ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')
        # self.attentions -> [TPVCrossViewHybridAttention, TPVImageCrossAttention]
        for layer in self.operation_order:
            # 首先进行 'self_attn'
            if layer == 'self_attn':
                # 只对 query[0] 也就是 hw方向的 BEV特征进行 self-attn
                query_0 = self.attentions[attn_index](query[0],...) 
                
                
                # 叠加三个方向的 TPV特征
                query = torch.cat([query_0, query[1], query[2]], dim=1)
            
            # 进行 'cross_attn',即 TPV特征与 图片特征进行 cross_attn
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](query,...)  # -> tpvformer04/modules/image_cross_attention.py
                ... 
        
    query = torch.split(query, [tpv_h*tpv_w, tpv_z*tpv_h, tpv_w*tpv_z], dim=1) # [1,11600,256] 
    return query 

3.5 tpvformer04/modules/cross_view_hybrid_attention.py

class TPVCrossViewHybridAttention(...):
    def __init__(...):
        ...
        
    def forward(...):
        query = torch.cat([value[:bs], query], -1)
        value = ... # [2, 10000, 8, 32] # 8头注意力机制
        sampling_offsets = ... # 预测偏移量 [1, 10000, 8, 2, 1, 4, 2] 
        if reference_points.shape[-1] == 2:
            offset_normalizer = ...
            sampling_locations = ... # 参考点与偏移点相加并归一化
        
        if torch.cuda.is_available() and value.is_cuda::
            if value.dtype == torch.float16::
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(...)
        
        ...
        return self.dropout(output) + identity  # 残差链接  ->  tpvformer04/modules/tpvformer_layer.py

3.6 tpvformer04/modules/image_cross_attention.py

class TPVImageCrossAttention(...):
    def __init__(...):
        ...
    
    def forward(...):
        ...
        queries = self.deformable_attention(...)
        value = self.value_proj(value)
        sampling_offsets, attention_weights = self.get_sampling_offsets_and_attention(query)
        reference_points = self.reshape_reference_points(reference_points)
        if reference_points.shape[-1] == 2:
            ... # 计算偏移点
        else:
            ...
        
        # 实现TPV特征与图片特征进行 cross_attn 操作    
        if torch.cuda.is_available() and value.is_cuda:
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        return output

4、总结

本次学习了TPVFormer的代码流程,深入学习了其中的详细步骤,一步一步的将TPVFormer的过程进行了详细的分析,从最终的效果可以看出,TPVFormer的效果十分惊艳!


http://www.kler.cn/a/4805.html

相关文章:

  • 【Go】:深入解析 Go 1.24:新特性、改进与最佳实践
  • Vue-Cli
  • 牛客网刷题 ——C语言初阶(6指针)——BC106 上三角矩阵判定
  • 最近在盘gitlab.0.先review了一下docker
  • python实战应用讲解-【numpy专题篇】常见问题解惑(十五)(附python示例代码)
  • 6 Nginx常用核心模块指令
  • 华为OD机试题【剩余可用字符集】用 Java 解 | 含解题说明
  • 【Python学习笔记(八)】threading多线程模块的使用
  • python实战应用讲解-【numpy专题篇】实用小技巧(四)(附python示例代码)
  • mycat2 安装 jDK
  • Python 反射
  • 【TDengine】详解 taosAdapter 适配器
  • Html5代码实现动态三角形
  • Elasticsearch 搜索测试与集成Springboot3
  • 18005 它不是丑数
  • 算法第十九期——图论初入门
  • Java多线程
  • CSS Grid 网格布局详解
  • 【故障检测】基于 KPCA 的故障检测【T2 和 Q 统计指数的可视化】(Matlab代码实现)
  • 【华为OD机试 2023最新 】新学校选址(C++ 100%)
  • 解析springboot源码中this::selfInitialize怪异用法的含义
  • C++11右值引用
  • 华为OD机试用java实现 -【吃火锅】
  • ChatGPT辅助编程实践——常用提示词整理