当前位置: 首页 > article >正文

Vision Mamba在AMD GPU上使用ROCm

Vision Mamba on AMD GPU with ROCm — ROCm Blogs

2025年1月24日,作者:Sean Song

状态空间模型 (SSMs),如Mamba,已经成为Transformer模型的潜在替代方案。仅使用SSM的视觉骨干网络已显示出令人鼓舞的结果。有关SSMs和Mamba在AMD硬件上的性能的更多信息,请参阅<Mamba在AMD GPU上使用ROCm的相关信息>。本篇博客探讨了Vision Mamba (Vim),一种创新且高效的视觉任务骨干网络,并评估其在使用ROCm的AMD GPU上的性能。我们将从对Vision Mamba的简要介绍开始,接下来是使用Vision Mamba在AMD GPU上进行训练和推理的分步指南。

Vision Mamba

Vision Mamba (Vim)的灵感来自于语言模型中的Mamba,并将其原理扩展到视觉任务中。然而,由于语言和视觉任务之间的固有差异,直接将Mamba应用于视觉任务时效果并不理想。这是因为Mamba适用于顺序数据的单向建模缺乏对视觉任务至关重要的位置感知。为了解决这个问题,Vim引入了一个双向选择性状态空间模型(SSM)用于全局视觉上下文建模,并结合位置嵌入进行位置感知的视觉识别。

Vim将输入图像划分为补丁,并线性投影为tokens。这些补丁作为一个token序列传递给Vim块,它对token序列进行归一化并将其线性投影到x和z。x序列从前向和后向两个方向进行处理。前向和后向传递的输出通过z进行门控并组合,生成最终的输出token序列。位置嵌入提供空间感知,使Vim在进行密集预测任务时更加鲁棒。有关Vim的更多信息,请参阅Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model。

png

图像来源: Vision Mamba paper

准备与设置

Vision Mamba (Vim) 的源代码可以在 Vision Mamba 仓库 中找到。`causal-conv1d` 和 mamba-1p1p1 文件夹包含了用于双向 Mamba 的硬件自适应优化的 CUDA 源代码。如果要在 ROCm 上运行此代码,则需要将 CUDA 源代码转换为 HIP C++。HIP 是一个 C++ 运行时 API 和内核语言,使开发人员可以使用同一源代码为 AMD 和 NVIDIA GPU 创建可移植应用程序。PyTorch 使用一个名为 Hipify_torch 的工具将源代码从 CUDA 转换为 HIP,从而生成能够在 ROCm 上运行的自定义内核。在构建 CUDA 扩展时,转换将在 PyTorch 内部完成,从而在使用自定义内核时确保无缝体验。有关 HIPIFY 的更多信息,请参见 HIPIFY 文档。

有关设置的全面支持详细信息,请参阅 ROCm 文档。本文使用了以下配置创建。

  • 硬件与操作系统:

    • AMD Instinct GPU

    • Ubuntu 22.04.3 LTS

  • 软件:

    • ROCm 6.1+

    • PyTorch 2.1+ for ROCm

本文在一台配备 MI210 GPU 且安装了 AMD GPU 驱动程序版本 6.7.0 的 Linux 机器上使用了 rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2 docker 镜像。

开始使用

使用 rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2 Docker 镜像并在容器中构建 Vision Mamba。

docker pull rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2
docker run -it --name vision_mamba --rm --ipc=host \
            --device=/dev/kfd --device=/dev/dri/ \
            --group-add=video --shm-size 8G \
            rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2

在 AMD GPU 上通过 ROCm 构建并安装 Vision Mamba。

git clone https://github.com/AMD-AI/Vim.git
cd Vim 
git checkout add_rocm_support

pip install -r vim/vim_requirements.txt
# Install hipified packages required by Vision Mamba
pip install -e ./causal-conv1d
pip install -e ./mamba-1p1p1

Vim 有三个版本的模型:

模型

#param.

Top-1 Acc.

Top-5 Acc.

Vim-tiny

7M

76.1

93.0

Vim-tiny+

7M

78.3

94.2

Vim-small

26M

80.5

95.1

Vim-small+

26M

81.6

95.4

Vim-base

98M

81.9

95.8

本文在后续测试中使用 Vim-small+。

 + 表示模型在较短时间内进行了细化调优。请使用以下命令下载 Vim-small 的权重。
wget https://huggingface.co/hustvl/Vim-small-midclstok/resolve/main/vim_s_midclstok_ft_81p6acc.pth

完成这些步骤后,您将获得 Vim-small 的权重文件(例如,vim_s_midclstok_ft_81p6acc.pth)。

注意: 如果您只需进行推理,则可以跳过以下数据集下载步骤。数据集仅在训练和准确性测试中需要。

使用 ImageNet 数据集 进行训练和测试。ImageNet 是视觉模型的一个流行基准。使用以下命令下载它。根据您的网络速度,此过程可能需要数小时。

mkdir image_dataset
cd image_dataset
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
# Extract each .tar file into its own directory
find . -name "*.tar" | while read -r NAME; do
    mkdir -p "${NAME%.tar}"
    tar -xvf "${NAME}" -C "${NAME%.tar}" && rm -f "${NAME}"
done

rm train/n04266014/n04266014_10835.JPEG
cd ..
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash

一切都已经完成,可以在使用ROCm的AMD GPU上进行Vision Mamba的准确性测试、训练和推理。

准确性测试

在这一部分,我们将评估Vision Mamba(小型)在ImageNet数据集上的表现。准确性测试将帮助验证模型在AMD GPU上与ROCm正常运行。

cd Vim
python ./vim/main.py --eval --resume ./vim_s_midclstok_ft_81p6acc.pth \
    --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2  \
    --data-path ./image_dataset

输出应该类似于下面的内容。完整的训练日志可以在ROCm博客仓库中找到。

Namespace(batch_size=64, epochs=300, bce_loss=False, unscale_lr=False, model='vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, cosub=False, finetune='', attn_only=False, data_path='./image1k2012/tarfile', data_set='IMNET', inat_category='name', output_dir='', device='cuda', seed=0, resume='./vim_s_midclstok_ft_81p6acc.pth', start_epoch=0, eval=True, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, distributed=False, world_size=1, dist_url='env://', if_amp=True, if_continue_inf=False, if_nan2num=False, if_random_cls_token_position=False, if_random_token_rank=False, local_rank=0, gpu=None)
Creating model: vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
number of params: 26001256
Test:  [ 0/27]  eta: 0:22:19  loss: 0.5674 (0.5674)  acc1: 88.4896 (88.4896)  acc5: 98.3333 (98.3333)  time: 49.6074  data: 18.4177  max mem: 49186
Test:  [10/27]  eta: 0:03:52  loss: 0.7132 (0.6941)  acc1: 83.8021 (85.2510)  acc5: 97.3438 (97.2064)  time: 13.6715  data: 1.6747  max mem: 49193
Test:  [20/27]  eta: 0:01:23  loss: 0.8571 (0.8317)  acc1: 81.5625 (82.3289)  acc5: 95.4167 (95.6324)  time: 10.0834  data: 0.0003  max mem: 49193
Test:  [26/27]  eta: 0:00:11  loss: 0.9202 (0.8809)  acc1: 80.2604 (81.5600)  acc5: 94.2188 (95.4420)  time: 9.6160  data: 0.0002  max mem: 49193
Test: Total time: 0:05:02 (11.2071 s / it)
* Acc@1 81.560 Acc@5 95.442 loss 0.881
Accuracy of the network on the 50000 test images: 81.6%

这个结果与作者报告的结果几乎一致,确认了启用了ROCm的配置和使用Hipify_torch进行的修改工作正常。

在使用ROCm的AMD GPU上进行Vision Mamba分布式数据并行训练

训练可能会花费很长时间(即几天),如果你使用的是小型GPU。为了加速训练过程,可以使用DistributedDataParallel(分布式数据并行)来跨所有8块AMD Instinct MI210 GPU运行多个进程。根据你的GPU类型和内存容量,你需要调整batch_size和num_workers的值——要么增加它们以最大化资源利用率,要么减少它们以防止内存不足(OOM)问题。根据我们的测试,使用8块AMD Instinct MI210 GPU训练大约需要5个小时,使用8块AMD Instinct MI300X GPU训练大约需要2个小时。请注意,这些时间仅供参考,并未进行最快训练速度的优化。实际性能可能会根据你的具体设置而有所不同。

cd Vim
torchrun --nnodes 1 --nproc_per_node 8 ./vim/main.py \
    --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \
    --batch-size 128 --lr 5e-6 --min-lr 1e-5 --warmup-lr 1e-5 --drop-path 0.0 --weight-decay 1e-8 --num_workers 8 \
    --data-path  ./image_dataset --output_dir ./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2\
    --epochs 1 --finetune  ./vim_s_midclstok_ft_81p6acc.pth --no_amp
Outputs:
```text
[2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] 
[2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] *****************************************
[2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] 设置每个进程的OMP_NUM_THREADS环境变量为1以避免系统过载,请根据需要进一步调优该变量以获得最佳性能。
[2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] *****************************************
Namespace(batch_size=128, epochs=1, bce_loss=False, unscale_lr=False, model='vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', input_size=224, drop=0.0, drop_path=0.0, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=1e-08, sched='cosine', lr=5e-06, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-05, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, cosub=False, finetune='./vim_s_midclstok_ft_81p6acc.pth', attn_only=False, data_path='./image_dataset', data_set='IMNET', inat_category='name', output_dir='./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=8, pin_mem=True, distributed=True, world_size=8, dist_url='env://', if_amp=False, if_continue_inf=False, if_nan2num=False, if_random_cls_token_position=False, if_random_token_rank=False, local_rank=0, gpu=0, rank=0, dist_backend='nccl')
创建模型: vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
参数数量: 26001256
开始训练,时长为1个epoch
Epoch: [0]  [   0/1251]  预计用时: 4:05:42  学习率: 0.000010  损失: 2.8620 (2.8620)  时间: 11.7846  数据: 1.4589   最大内存: 58000
Epoch: [0]  [  10/1251]  eta: 1:21:45  lr: 0.000010  loss: 2.8620 (2.7414)  time: 3.9532  data: 0.1329  max mem: 58204
Epoch: [0]  [  20/1251]  eta: 1:13:30  lr: 0.000010  loss: 2.7010 (2.7007)  time: 3.1725  data: 0.0003  max mem: 58204
Epoch: [0]  [  30/1251]  eta: 1:10:40  lr: 0.000010  loss: 2.7748 (2.7117)  time: 3.2089  data: 0.0003  max mem: 58205
Epoch: [0]  [  40/1251]  eta: 1:08:46  lr: 0.000010  loss: 2.7195 (2.6884)  time: 3.2240  data: 0.0003  max mem: 58205
Epoch: [0]  [  50/1251]  eta: 1:07:18  lr: 0.000010  loss: 2.6703 (2.6825)  time: 3.1924  data: 0.0003  max mem: 58206
...
Epoch: [0]  [1210/1251]  eta: 0:02:13  lr: 0.000010  loss: 2.6030 (2.5552)  time: 3.1808  data: 0.0004  max mem: 58206
Epoch: [0]  [1220/1251]  eta: 0:01:41  lr: 0.000010  loss: 2.5304 (2.5541)  time: 3.1812  data: 0.0004  max mem: 58206
Epoch: [0]  [1230/1251]  eta: 0:01:08  lr: 0.000010  loss: 2.5217 (2.5542)  time: 3.1813  data: 0.0004  max mem: 58206
Epoch: [0]  [1240/1251]  eta: 0:00:35  lr: 0.000010  loss: 2.6710 (2.5543)  time: 3.1810  data: 0.0006  max mem: 58206
Epoch: [0]  [1250/1251]  eta: 0:00:03  lr: 0.000010  loss: 2.5232 (2.5541)  time: 3.1798  data: 0.0005  max mem: 58206
Epoch: [0] Total time: 1:08:01 (3.2629 s / it)
Averaged stats: lr: 0.000010  loss: 2.5232 (2.5604)
Test:  [ 0/14]  eta: 1:01:08  loss: 0.7193 (0.7193)  acc1: 83.8802 (83.8802)  acc5: 97.0313 (97.0313)  time: 262.0067  data: 77.2542  max mem: 99122
Test:  [10/14]  eta: 0:02:53  loss: 0.8385 (0.8442)  acc1: 82.8906 (81.9058)  acc5: 95.1042 (95.4380)  time: 43.4364  data: 7.0233  max mem: 99137
Test:  [13/14]  eta: 0:00:37  loss: 0.8385 (0.9020)  acc1: 81.4583 (81.5280)  acc5: 95.0000 (95.3680)  time: 37.1473  data: 5.5183  max mem: 99137
Test: Total time: 0:08:40 (37.1782 s / it)
* Acc@1 81.469 Acc@5 95.298 loss 0.906
Accuracy of the network on the 50000 test images: 81.5%
Max accuracy: 81.47%
Training time 1:16:43

根据论文,`vim_s_midclstok_ft_81p6acc.pth` 检查点是在 ImageNet 数据集上进行微调的。在本博客中进行训练的目的是为了验证我们使用 “Hipify_torch” 进行修改后的分布式数据并行(DDP)训练在搭载 ROCm 的 AMD GPU 上能否正常工作,而不是通过进一步调整设置来超越现有结果。

训练完成后,检查点将存放在 ./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 目录中。 

在AMD GPU上使用ROCm进行Vision Mamba推理

Vim可以用于推理任务,如图像分类、分割和检测。为了展示Vim在使用AMD GPU和ROCm时的推理能力,我们将使用前一步生成的模型进行图像分类任务。测试过程中使用的图片(`cab.png` 和 cat.jpeg)和文件(`imagenet_class_index.json`)可以从 ROCm博客库获取。

pip install rope
import torch
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from vim.models_mamba import VisionMamba
from vim.models_mamba import (
    vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2)

# 创建Vim模型并加载权重
model = vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2()
sd = torch.load("./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2/best_checkpoint.pth")
model.load_state_dict(sd["model"])
model.eval()
model.to("cuda")


def inference(model, image):
    ## 预处理图像
    test_image = Image.open(image).convert('RGB')
    test_image.show()
    test_image = test_image.resize((224, 224))
    image_as_tensor = transforms.ToTensor()(test_image)
    normalized_tensor = transforms.Normalize(
        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    )(image_as_tensor)

    ## 使用Vision Mamba进行推理
    x = normalized_tensor.unsqueeze(0).cuda()
    pred = model(x)

    ## 解码输出并打印类别
    import json
    f = open('./src/imagenet_class_index.json')
    class_idx = json.load(f)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
    print(f"label: class - {pred.argmax()}:{idx2label[pred.argmax()]}")

inference(model,"./image/cab.png")
inference(model,"./image/cat.jpeg")

输出:

png

标签: 类别 - 468:cab

png

标签: 类别 - 282:tiger_cat

输出结果看起来是正确的!模型正确识别了图像。这展示了Vision Mamba可以在使用ROCm的AMD GPU上用于推理任务。

总结

在这篇博客中,我们探讨了在AMD GPU上使用ROCm进行Vision Mamba,展示了其在视觉任务中的能力和性能。Hipify后的Vision Mamba有效利用AMD硬件进行训练和推理,提供了一种传统模型的强大替代方案。我们鼓励读者在使用ROCm进行计算机视觉应用时,尝试使用Vision Mamba。


http://www.kler.cn/a/522198.html

相关文章:

  • 【Rust自学】14.6. 安装二进制crate
  • PySide(PyQT)进行SQLite数据库编辑和前端展示的基本操作
  • MotionLCM 部署笔记
  • 【电工基础】2.低压带电作业定义,范围,工作要求,电工基本工具
  • 1.27补题 回训练营
  • 【每日一A】2015NOIP真题 (二分+贪心) python
  • c语言版贪吃蛇(Pro Max版)附源代码
  • 题解 信息学奥赛一本通/AcWing 1118 分成互质组 DFS C++
  • 010 mybatis-PageHelper分页插件
  • 精通PCIe技术:协议解析与UVM验证实战
  • 大数据学习之SCALA分布式语言三
  • POWER SCHEDULER:一种与批次大小和token数量无关的学习率调度器
  • Mac Electron 应用签名(signature)和公证(notarization)
  • Mybatis初步了解
  • RU 19.26安装(手工安装各个补丁)
  • wxPython中wx.ListCtrl用法(四)
  • 66-《虞美人》
  • 从ai产品推荐到利用cursor快速掌握一个开源项目再到langchain手搓一个Text2Sql agent
  • 4.scala默认参数值
  • YOLO目标检测4
  • C#面试常考随笔6:ArrayList和 List的主要区别?
  • deepseek R1的确不错,特别是深度思考模式
  • excel如何查找一个表的数据在另外一个表是否存在
  • clean code阅读笔记——如何命名?
  • Nacos深度解析:构建高效微服务架构的利器
  • Python3 【高阶函数】项目实战:5 个学习案例