当前位置: 首页 > article >正文

SCANet代码解读

论文链接:[2406.07189] RGB-Sonar Tracking Benchmark and Spatial Cross-Attention Transformer Tracker

代码链接:

GitHub - LiYunfengLYF/SCANet


 通过代码需要了解到的事情:

1. 两种模态的数据集是怎么传入模型的?
2. 模型结构代码是如何搭建


train.py

路径:tracking/train.py

该文件主要用于是参数设置。

  • --script:训练脚本的名称,用于指定要运行的训练逻辑,比如模型结构或任务。
  • --config:训练配置文件名,默认是 baseline(通常是 YAML 文件),用于指定训练的超参数、模型配置等。
  • --save_dir:指定保存训练结果(如模型权重、日志、TensorBoard 日志)的目录路径。
  • --mode:训练模式,支持以下几种:
    • single:单卡训练(单 GPU)。
    • multiple:多卡训练(多 GPU,单节点)。
    • multi_node:多节点训练(分布式训练,跨多台机器)。
  • --nproc_per_node:每个节点上使用的 GPU 数量(仅在 multiple 或 multi_node 模式下需要)。
  • --use_lmdb:是否使用 LMDB 格式的数据集(0 或 1)。
  • --use_wandb:是否使用 Weights & Biases(WandB) 工具进行训练监控(0 或 1)。
  • --env_num:指定环境编号,支持多个环境的开发(0, 1, 2 等)。

通过选择训练模式(单卡)后,os.system() 会创建一个新的子进程来执行 train_cmd。在子进程中,系统调用相应的 Python 解释器来执行指定的脚本( lib/train/run_training.py)。


run_training.py

 路径:lib/train/run_training.py

这个文件包含训练的入口程序。

功能概述:

  1. 接收用户输入的命令行参数(例如训练脚本名、配置文件名、随机种子等)。
  2. 初始化训练环境,包括设置随机种子、CUDNN 加速、分布式通信等。
  3. 调用具体的训练逻辑,基于用户指定的脚本和配置文件启动训练任务。
  4. 支持扩展功能:
    • 知识蒸馏(使用教师模型指导学生模型训练)。
    • 多环境开发(env_num 参数)。
    • 数据存储格式(如 LMDB)。
    • 集成工具(如 WandB,用于训练过程的可视化)。

 初始化随机种子 (init_seeds)

作用:设置随机种子以确保训练的可复现性。

细节

  • 使用 Python 的 random 和 NumPy 的 np.random.seed 设置全局随机数种子。
  • 使用 PyTorch 的 torch.manual_seed 和 torch.cuda.manual_seed 设置 GPU 和 CPU 的随机数种子。
  • 如果使用 CUDNN 后端(GPU 加速),设置 torch.backends.cudnn.benchmark 为 True,以加速特定的卷积操作。

 初始化训练环境

ws_settings.Settings(env_num):加载训练环境的配置,例如设备编号、超参数等。

路径设置

settings.project_path:项目路径,通常表示当前训练任务的文件夹。

  • settings.cfg_file:训练配置文件的完整路径(experiments/<script_name>/<config_name>.yaml)。
  • settings.save_dir:将用户指定的保存路径转为绝对路径。

 知识蒸馏支持

如果启用了知识蒸馏(distill=1),会加载教师模型的脚本和配置文件。

细节

  • 教师模型和学生模型的训练逻辑分别存放在不同的模块中:
    • train_script_distill:蒸馏模式。
    • train_script普通训练模式(使用)
  • 动态加载模块:使用 importlib.import_module 动态导入模块,通过反射调用相应的训练逻辑。

 执行训练

训练选择普通模式(不使用知识蒸馏)。

  • getattr(expr_module, 'run'):从动态加载的模块中获取 run 函数。
  • expr_func(settings):将训练的配置(settings)传入,并启动训练。

进入lib.train.train_script.run()执行训练。

train_script.py 

路径:lib/train/train_script.py

 加载配置文件

初始化随机种子

日志目录初始化

数据加载器构建

build_dataloaders:加载训练和验证数据集,返回两个数据加载器loader_train 和 loader_val 。

  • 训练加载器(loader_train):用于训练数据的迭代读取。
  • 验证加载器(loader_val):用于在训练期间验证模型性能。

 模型构建

根据配置文件指定的模型名称,动态从 TRACKER_REGISTRY 注册表中找到模型类

TRACKER_REGISTRY 是一个注册表,类似于 Python 的字典,存储了模型名称和模型类之间的映射关系。

在代码中,TRACKER_REGISTRY.get(cfg.MODEL.NETWORK) 用于通过模型名称(cfg.MODEL.NETWORK)从注册表中查找对应的模型类。

在代码的其他部分,模型类会通过装饰器注册到 TRACKER_REGISTRY,如下:

分布式训练包裹

损失函数和 Actor

损失函数

  • 根据配置文件中指定的 IoU 类型(例如 giou 或 wiou),选择相应的 IoU 损失函数。
  • 其他损失函数包括:
    • l1_loss:用于回归目标框。
    • focal_loss:用于处理类别不平衡。
    • BCEWithLogitsLoss:用于分类精度。
  • loss_weight:定义每个损失函数的权重,用于加权损失求和。

Actor

  • 从 ACTOR_Registry 中获取 Actor 实例(类似于控制训练逻辑的管理者)。
  • Actor 将模型(net)、损失函数(objective)和训练设置(settingscfg)结合在一起。

 优化器和学习率调度器

Trainer 实例化

LTRTrainer:训练流程管理器,负责实际的训练逻辑。

参数

  • actor:负责模型前向传播、损失计算和后向传播的组件。
  • 数据加载器 [loader_train, loader_val]:训练和验证数据源。
  • optimizer 和 lr_scheduler:优化器和学习率调度器。
  • use_amp:是否使用混合精度训练(AMP)。
  • rgb_mode:是否启用 RGB 模式(配置文件中定义)

启动训练过程

trainer.train:启动训练流程。

  • 参数
    • cfg.TRAIN.LEARN.EPOCH:训练的总 epoch 数。
    • load_latest=True:是否加载最近保存的检查点(如果存在)。
    • fail_safe=True:是否启用故障恢复机制(例如断点续训)。

 trackerModel.py——重点

路径:lib/models/scanet/trackerModel.py

在train_script.py中,通过TRACKER_REGISTRY.get() 动态查找模型类。lib/models/scanet/trackerModel.py中的SCANet_network类被装饰器注册到TRACKER_REGISTRY中。

 模型基于 OSTrack 构建,目标是实现RGB-T(可见光 & 声呐)目标跟踪任务。


 SCANet_network类

构造函数

  • self.backbone:构建骨干网络,从输入的输入数据中提取特征,参数cfg.MODEL.BACKBONE.LOAD_MODE指定了加载模式

  • 加载预训练权重
  • 调用骨干网络的微调函数
  • 构建RGB和Sonar的预测头模块,头部模块的类型和参数由配置文件指定(cfg.MODEL.HEAD 和 cfg.MODEL.RGBS_HEAD

forward前向传播

  • 输入:模板帧(template frame)搜索帧(search frame)。模板帧用于定义要跟踪的目标,而搜索帧用于定位目标。
  • 骨干网络特征提取
  • 调用 forward_head 方法,将骨干网络的输出传递给头部模块,生成 RGB 和声呐的预测结果。

 forward_head预测头模块函数

  • 输入:骨干网络的输出 cat_feature,以及可选参数 gt_score_map
  • 特征切片
  • 特征变换
  • RGB 和声呐头部的前向传播

    • 将 RGB 和声呐特征分别传递给对应的头部模块 box_head 和 sonar_head
    • 生成的输出包含预测的边界框(pred_boxes)和其他特征图(如 score_mapsize_map)。
    • 上述的具体参数是来自配置文件cfg的,具体是通过找到 settings.cfg_file 中指定的 YAML 文件路径,查看 MODEL.HEADMODEL.RGBS_HEAD 的具体配置

baseline.yaml

路径:experiments/scanet/baseline.yaml

这份 YAML 配置文件包含了模型架构、数据集、训练、验证和测试的全流程参数:

  1. 数据配置(DATA):定义了输入数据的模态、归一化参数,以及训练/验证数据集。
  2. 模型配置(MODEL):包括骨干网络、头部模块等配置,支持多模态(RGB + 声呐)。
  3. 训练配置(TRAIN):详细定义了优化器、损失函数、混合精度等训练参数。
  4. 测试配置(TEST):与训练类似,定义了搜索和模板区域的参数。

数据配置DATA

模态设置

  • RGB_MODE 和 RGBS_MODE:指示模型是否同时处理 RGB 和 RGBS(声呐)数据模态。
  • RGBS_ROTATE:是否对 RGBS 数据增加旋转增强。
  • MAX_SAMPLE_INTERVAL:最大采样间隔,用于决定从视频中采样帧的时间跨度。

数据归一化参数

  • MEAN 和 STD:输入图像的归一化参数,通常是 RGB 图像的标准均值和标准差,用于将像素值归一化到标准分布。

搜索区域 (SEARCH) 和模板区域 (TEMPLATE) 配置

  • 搜索区域(SEARCH)
    • 定义了模型在当前帧中搜索目标的参数。
    • 例如,FACTOR=4.0 表示搜索区域的宽度/高度是目标框的 4 倍。
    • SIZE=256 表示将搜索区域缩放到 256x256 的固定尺寸。
  • 模板区域(TEMPLATE)
    • 定义了模板帧(通常是目标的初始位置)的参数。
    • SIZE=128 表示将模板区域缩放到 128x128。
    • 通常模板帧会被固定且不抖动(CENTER_JITTER=0 和 SCALE_JITTER=0)。

数据集


模型配置 MODEL

网络整体信息 

  • NETWORK:模型的主网络名称,这里是 SCANet_network
  • RETURN_STAGES:骨干网络中需要返回的特征层索引(例如第 2、5、8 和 11 层)。

 骨干网络 (BACKBONE)


  • TYPE:骨干网络类型,这里采用 vit_base_patch16_224_midlayer(Vision Transformer)。
  • PARAMS
    • ffm: SCAM:特征融合方式为 SCAM。
    • rgbs_loc: [3, 6, 9]:指定在第 3、6、9 层使用 RGB-S 融合。

 头部模块

  • HEAD(RGB 头) 和 RGBS_HEAD(声呐头)
    • 类型为 center_head,参数为:
      • inplanes:输入通道数。
      • channel:中间特征通道数。
      • feat_sz:特征图大小。
      • stride:卷积步幅。


 训练配置 TRAIN

  • 优化目标
    • 使用 giou(广义 IoU)作为 IoU 损失。
    • GIOU_WEIGHT 和 L1_WEIGHT:分别为 GIoU 和 L1 损失的权重。
  • 优化器参数
    • LR:初始学习率为 0.00001
    • WEIGHT_DECAY:权重衰减值为 0.0001

 优化器和学习率调度器

  • 优化器:使用 ADAMW 优化器。
  • 学习率调度器
    • 使用 step 调度器,每隔 30 个 epoch 降低学习率。

AMP混合精度训练 

  • 是否启用 AMP(混合精度训练):USED: False
  • 梯度裁剪的最大范数:GRAD_CLIP_NORM: 0.1

测试配置 TEST

 

  • 测试的搜索区域和模板区域与训练保持一致,分别为 256x256 和 128x128

 vit_rgbs.py

路径:lib/models/scanet/vit_rgbs.py

在trackerModel.py中

 在YAML文件中,backbone的配置如下:

 backbone的type为vit_base_patch16_224_midlayer,在vit_rgbs.py文件中定义


vit_rgbs.py文件代码结构概览

  1. Attention 和 Block 模块

    • 实现了 Transformer 中的注意力机制(Attention)及基本的 Transformer 块(Block)。
  2. VisionTransformer_midlayer

    • 核心 Transformer 模型,支持 ViT 结构,包含可见光(RGB)和 T 模态的融合能力。
    • 支持冻结部分 Transformer 层(freeze_layer)和插入中间融合层(rgbs_layers)。
  3. 辅助函数

    • 包括权重初始化、加载预训练权重、调整位置嵌入(pos_embed)等。
  4. 模型注册

    • 定义了 vit_basevit_small, 和 vit_tiny 三种 ViT 模型,并通过注册表(MODEL_REGISTRY)动态加载。

 vit_base_patch16_224_midlayer --------> _create_vision_transformer

_create_vision_transformer ---------> VisionTransformer_midlayer


VisionTransformer_midlayer——重点

这部分是这个py文件的核心,它实现了标准的 Vision Transformer,同时支持红外(T)与 RGB 的融合。


模型结构

整体模型架构可以拆解为以下几个主要部分:

  1. 输入处理模块

    • 图像被切分为 Patch,并通过 PatchEmbed 层进行线性嵌入,生成低维特征。
    • 位置编码(pos_embed)被添加到特征中,使得 Transformer 感知输入的空间位置。
  2. Transformer 主干网络

    • 若干个堆叠的 Transformer Block,每个 Block 包含:
      • 多头注意力机制(Attention)。
      • 前馈网络(MLP)。
    • 这些 Block 用于提取输入的高层次特征。
  3. 融合模块rgbs_layers):

    • 在指定的 Transformer Block 层(通过 rgbs_loc 参数定义)插入特征融合模块(ffm)。
    • 融合模块用于交互和融合 RGB 和 T 模态的特征。
  4. 输出层

    • Transformer 的输出经过归一化(norm)后,分别恢复为 RGB 和 T 模态的各自特征。
    • 最终将两种模态的特征拼接,作为模型的输出。
a. Patch 嵌入

  • 将输入图像切分成固定大小的 Patch,并使用线性投影将每个 Patch 映射到 embed_dim 维度。
 b. 位置嵌入

  • 提供位置编码,使得 Transformer 能够感知输入的空间位置信息。
c. Transformer 块

  • 核心组件是多个堆叠的 Block 模块,数量由 depth 参数指定。
 d. 融合层——SCAM的插入

  • 插入融合层( ffm——SCAM)以实现 RGB 和 T 特征的融合。
  • 通过 rgbs_loc 指定融合层插入的位置。

前向传播

a. 输入处理

b. 特征融合

  • 逐层通过 Transformer 块。
  • 在指定的层上插入 RGB 和 T 模态特征的融合。
 c. 模态恢复

  • 从融合后的特征中恢复 RGB 和 T 的独立特征。
  • 最后将两种模态的特征拼接在一起。

scam.py

路径:lib/models/scanet/mid_rgbs_layer/scam.py

在vit_rgbs.py文件中,搭建了backbone结构,其中融合层ffm 在YAML文件中被定义为SCAM,其在scam.py文件中被定义。

SCAM模块的作用是接收两个输入特征(x1x2),通过交叉注意力机制进行交互,然后在特征中加入残差连接和前馈网络(FFN)。

原论文对SCAM的解释:

SCAM旨在实现空间未对齐的RGB特征与声呐特征的有效跨模态交互。

SCAM由一个空间交叉注意力层和两个独立的全局整合模块(GIM)组成。我们的SCAM模块的操作如下所述:

  • SCAM的输入是Hr和Hs。
  • 首先,计算两个模态的QKV矩阵。RGB模态的查询矩阵表示为Qr = Hr ∈ RN ×C,其中N = Nz + Nx。
  • 键矩阵和值矩阵通过Kr, Vr = Split(Linear(Hr))获得,其中Linear表示没有偏置的线性层,它将特征从RC映射到R2C。
  • Split将特征从R2C映射到两个RC。声呐的Qs, Ks, Vs以相同的方式获得。

 


 center_head.py

路径:lib/models/scanet/head/center_head.py

在YAML配置文件中,Head类型为center_head,它的定义在center_head.py文件中。

  • conv 函数: 构建一个卷积层、BatchNorm 和 ReLU 激活函数的组合。
  • center_head 类
    • 继承自 torch.nn.Module,这是 PyTorch 中用于构建神经网络的基类。
    • 包含:
      1. 构造函数 (__init__):定义网络结构。
      2. forward 函数:定义前向传播逻辑。
      3. 分支模块
        • 中心点分支(预测目标中心点概率图)。
        • 偏移分支(预测目标中心点的偏移量)。
        • 尺寸分支(预测目标的宽度和高度)。
      4. 辅助函数
        • cal_bbox:根据预测结果计算目标边界框。
        • get_score_map:计算三个分支的输出(中心点、尺寸和偏移量)。

原文解释:

 


http://www.kler.cn/a/552880.html

相关文章:

  • 爬取网站内容转为markdown 和 html(通常模式)
  • kotlin Java 使用ArrayList.add() ,set()前面所有值被 覆盖 的问题
  • 上证50股指期货持仓量查询的方式在哪里?
  • STL之string类的模拟实现
  • Pilz安全继电器介绍(PNOZ X2.8P,Pilz MB0)
  • DeepSeek:情智机器人的“情感引擎”与未来变革者
  • Zookeeper 和 Redis 哪种更好?
  • 一键部署开源DeepSeek并集成到钉钉
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_get_full_name 函数
  • C++核心指导原则: 函数部分
  • C++字符串处理指南:从基础操作到性能优化——基于std::string的全面解析
  • 【QT常用技术讲解】国产Linux桌面系统+window系统通过窗口句柄对窗口进行操作
  • Jtti.cc:CentOS下PyTorch运行出错怎么办
  • Java集合之ArrayList(含源码解析 超详细)
  • 测试。。。
  • 在高流量下保持WordPress网站的稳定和高效运行
  • C++中为什么有了tuple还需要pair?
  • DeepSeek和ChatGPT的全面对比
  • No.38 蓝队 | 网络安全学习笔记:等级保护与法律法规
  • 华为昇腾服务器部署DeepSeek模型实战