当前位置: 首页 > article >正文

【VOS源码解析-2024CVPR-Cutie】1、train_wrapper结构解析

源码解析

  • 论文阅读
  • 1、数据预处理
  • 2、视频帧特征提取
    • 2.1 pixel encoder 特征提取
    • 2.2 tranformer_key
    • 2.3 特征图维度转换

论文阅读

  • 原文阅读笔记
  • github
  • arxiv地址
  • 训练框架
    • 1、train.py 概览
    • 2、trainner.py 概览
  • model主体框架
    • 1、train_wrapper

1、数据预处理

    def forward(self, data: Dict):
        out = {}
        frames = data['rgb']  #输入的RGB帧序列
        first_frame_gt = data['first_frame_gt'].float()  #第一帧的ground truth
        b, seq_length = frames.shape[:2]  #批量大小和序列长度
        # 在数据集预处理过程中,选取的对象数量数不超过train_config中num_objects参数的值
        num_filled_objects = [o.item() for o in data['info']['num_objects']]  #每个样本中实际存在的目标数量,将其转换为python列表
        max_num_objects = max(num_filled_objects)  #最大目标数量
        first_frame_gt = first_frame_gt[:, :, :max_num_objects]  #根据最大目标数量裁剪第一帧的gt标注
        selector = data['selector'][:, :max_num_objects].unsqueeze(2).unsqueeze(2)  #用于选择特定目标的张量

        num_objects = first_frame_gt.shape[2]
        out['num_filled_objects'] = num_filled_objects

        def get_ms_feat_ti(ti):
            return [f[:, ti] for f in ms_feat]  #提取ms_feat中提取第ti时刻的特征

        with torch.cuda.amp.autocast(enabled=self.use_amp):
            frames_flat = frames.view(b * seq_length, *frames.shape[2:])

在这里插入图片描述
数据预处理如代码和图所示,最开始的输入数据data是一个字典类型,它包含以下五个变量,这里只说最重要的三个变量。

  • 输入数据(data:Dict)
    • rgb:输入的原始视频帧,shape为(b, t, c, h, w)
    • first_frame_gt:第一帧视频帧的注释mask,shape为(b, num_ojects, c, h, w)。
      • 这里的num_objects是train_config中设置的参数值(具体位置是cutie/cutie/config/train_config.yaml)
      • num_objects代表model预先设置的能接受的最大对象数+1,这里的1是背景
      • 在生成第一帧gt图时,会根据num_objects设置shape大小。若原始视频的对象数小于num_objects-1,则后面几个维度设为空。若原始视频的对象数大于num_objects-1,则随机挑选num_objects-1个对象进行训练。
    • selector:shape为(1, num_objects)
      代码对输入数据的处理为:
  • 输入数据处理
    • 对rgb图像进行展平处理,并存储为变量frames_flat
    • 对first_frames_gt按照视频中出现过的最大对象数进行裁段
    • 对selector按照视频中出现过的最大对象数进行裁段,并扩充维度

2、视频帧特征提取

在这里插入图片描述

2.1 pixel encoder 特征提取

  • pixel encoder
    • resnet50的前三层,每一层输出一个特征图
    • 输出为ms_feat(resnet50的前三层特征图)以及pix_feat(对第三层特征图的通道维度进行压缩)
  with torch.cuda.amp.autocast(enabled=self.use_amp): #按照预设值决定是否启用混合精度计算
            frames_flat = frames.view(b * seq_length, *frames.shape[2:]) #将frames展平,由原来的b, t, c,h, w变为b*t, c, h, w
            ms_feat, pix_feat = self.encode_image(frames_flat)
    def encode_image(self, image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
        image = (image - self.pixel_mean) / self.pixel_std
        ms_image_feat = self.pixel_encoder(image)
        return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
   'self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)'

2.2 tranformer_key

对ms_feat的第三个特征图进行关键特征提取:

  • 先统一压缩至256维
  • key特征图:将256维度的特征图再压缩至64维
  • shrinkage:将256维度的特征图压缩维1维
  • selection:将256维的特征图压缩至64维后,再输入进sigmoid函数
with torch.cuda.amp.autocast(enabled=False):  #禁止混合精度计算,确保transformer_key中的所有操作全部以精度运行
     keys, shrinkages, selections = self.transform_key(ms_feat[0].float())   
class KeyProjection(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        in_dim = model_cfg.pixel_encoder.ms_dims[0] #1024
        mid_dim = model_cfg.pixel_dim #256
        key_dim = model_cfg.key_dim #64

        self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
        self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
        # shrinkage
        self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
        # selection
        self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)

        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x: torch.Tensor, *, need_s: bool,
                need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        x = self.pix_feat_proj(x)
        shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
        selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None

        return self.key_proj(x), shrinkage, selection

2.3 特征图维度转换

将上述得到的所有特征图进行维度转换,将原来展平的时间维度重新提取出来。

h, w = keys.shape[-2:]
keys = self.move_t_from_batch_to_volume(keys)
shrinkages = self.move_t_from_batch_to_volume(shrinkages)
selections = self.move_t_from_batch_to_volume(selections)
ms_feat = [self.move_t_out_of_batch(f) for f in ms_feat]
pix_feat = self.move_t_out_of_batch(pix_feat)
'self.move_t_out_of_batch = Rearrange((b t) c h w -> b t c h w, t=self.seq_length)'

http://www.kler.cn/a/513183.html

相关文章:

  • PyTorch使用教程(10)-torchinfo.summary网络结构可视化详细说明
  • 3D Vision--计算点到平面的距离
  • 第9章:Python TDD解决货币对象相等性比较难题
  • 消息队列篇--原理篇--RocketMQ(NameServer,Broker,单机上每秒处理数百万条消息性能)
  • 基于32QAM的载波同步和定时同步性能仿真,包括Costas环的gardner环
  • C++/QT环境下图像在窗口下等比例渲染绘制
  • sqlmap 自动注入 -01
  • 【Linux】华为服务器使用U盘安装统信操作系统
  • 跨境电商之小程序shinecrys水晶国度小程序数据分析
  • 【HF设计模式】06-命令模式
  • Flink底层架构与运行流程
  • 2.4 kubectl命令行设置7大命令分组
  • 三轴云台之跟随模式篇
  • JAVA:策略模式(Strategy Pattern)的技术指南
  • Java泛型方法所受的限制是什么?
  • JDBC实验测试
  • 软通动力携鸿湖万联与微展世签署战略合作协议,以开源鸿蒙赋能工业创新升级
  • 【深度学习基础】多层感知机 | 多层感知机的实现
  • K8S如何让worker使用kubectl命令(RBAC方法)
  • 机器学习-核函数(Kernel Function)
  • 使用xorriso v1.5.2和grub4dos 0.4.6a -2024-02-26制作可启动ISO文件
  • 《Keras 3 使用 Reptile 进行 Few-Shot 学习》
  • SSL证书的颁发格式和制作过
  • 第四天 安装DevEco Studio,配置HarmonyOS开发环境
  • 【集合】单列集合和双列集合
  • OpenCV简介、OpenCV安装