Vision Mamba在AMD GPU上使用ROCm
Vision Mamba on AMD GPU with ROCm — ROCm Blogs
2025年1月24日,作者:Sean Song
状态空间模型 (SSMs),如Mamba,已经成为Transformer模型的潜在替代方案。仅使用SSM的视觉骨干网络已显示出令人鼓舞的结果。有关SSMs和Mamba在AMD硬件上的性能的更多信息,请参阅<Mamba在AMD GPU上使用ROCm的相关信息>。本篇博客探讨了Vision Mamba (Vim),一种创新且高效的视觉任务骨干网络,并评估其在使用ROCm的AMD GPU上的性能。我们将从对Vision Mamba的简要介绍开始,接下来是使用Vision Mamba在AMD GPU上进行训练和推理的分步指南。
Vision Mamba
Vision Mamba (Vim)的灵感来自于语言模型中的Mamba,并将其原理扩展到视觉任务中。然而,由于语言和视觉任务之间的固有差异,直接将Mamba应用于视觉任务时效果并不理想。这是因为Mamba适用于顺序数据的单向建模缺乏对视觉任务至关重要的位置感知。为了解决这个问题,Vim引入了一个双向选择性状态空间模型(SSM)用于全局视觉上下文建模,并结合位置嵌入进行位置感知的视觉识别。
Vim将输入图像划分为补丁,并线性投影为tokens。这些补丁作为一个token序列传递给Vim块,它对token序列进行归一化并将其线性投影到x和z。x序列从前向和后向两个方向进行处理。前向和后向传递的输出通过z进行门控并组合,生成最终的输出token序列。位置嵌入提供空间感知,使Vim在进行密集预测任务时更加鲁棒。有关Vim的更多信息,请参阅Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model。
图像来源: Vision Mamba paper
准备与设置
Vision Mamba (Vim) 的源代码可以在 Vision Mamba 仓库 中找到。`causal-conv1d` 和 mamba-1p1p1
文件夹包含了用于双向 Mamba 的硬件自适应优化的 CUDA 源代码。如果要在 ROCm 上运行此代码,则需要将 CUDA 源代码转换为 HIP C++。HIP 是一个 C++ 运行时 API 和内核语言,使开发人员可以使用同一源代码为 AMD 和 NVIDIA GPU 创建可移植应用程序。PyTorch 使用一个名为 Hipify_torch 的工具将源代码从 CUDA 转换为 HIP,从而生成能够在 ROCm 上运行的自定义内核。在构建 CUDA 扩展时,转换将在 PyTorch 内部完成,从而在使用自定义内核时确保无缝体验。有关 HIPIFY 的更多信息,请参见 HIPIFY 文档。
有关设置的全面支持详细信息,请参阅 ROCm 文档。本文使用了以下配置创建。
-
硬件与操作系统:
-
AMD Instinct GPU
-
Ubuntu 22.04.3 LTS
-
-
软件:
-
ROCm 6.1+
-
PyTorch 2.1+ for ROCm
-
本文在一台配备 MI210 GPU 且安装了 AMD GPU 驱动程序版本 6.7.0 的 Linux 机器上使用了 rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2
docker 镜像。
开始使用
使用 rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2 Docker 镜像并在容器中构建 Vision Mamba。
docker pull rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2 docker run -it --name vision_mamba --rm --ipc=host \ --device=/dev/kfd --device=/dev/dri/ \ --group-add=video --shm-size 8G \ rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.1.2
在 AMD GPU 上通过 ROCm 构建并安装 Vision Mamba。
git clone https://github.com/AMD-AI/Vim.git cd Vim git checkout add_rocm_support pip install -r vim/vim_requirements.txt # Install hipified packages required by Vision Mamba pip install -e ./causal-conv1d pip install -e ./mamba-1p1p1
Vim 有三个版本的模型:
模型 | #param. | Top-1 Acc. | Top-5 Acc. |
---|---|---|---|
Vim-tiny | 7M | 76.1 | 93.0 |
Vim-tiny+ | 7M | 78.3 | 94.2 |
Vim-small | 26M | 80.5 | 95.1 |
Vim-small+ | 26M | 81.6 | 95.4 |
Vim-base | 98M | 81.9 | 95.8 |
本文在后续测试中使用 Vim-small+。
注 + 表示模型在较短时间内进行了细化调优。请使用以下命令下载 Vim-small 的权重。
wget https://huggingface.co/hustvl/Vim-small-midclstok/resolve/main/vim_s_midclstok_ft_81p6acc.pth
完成这些步骤后,您将获得 Vim-small 的权重文件(例如,vim_s_midclstok_ft_81p6acc.pth)。
注意: 如果您只需进行推理,则可以跳过以下数据集下载步骤。数据集仅在训练和准确性测试中需要。
使用 ImageNet 数据集 进行训练和测试。ImageNet 是视觉模型的一个流行基准。使用以下命令下载它。根据您的网络速度,此过程可能需要数小时。
mkdir image_dataset cd image_dataset wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar # Extract each .tar file into its own directory find . -name "*.tar" | while read -r NAME; do mkdir -p "${NAME%.tar}" tar -xvf "${NAME}" -C "${NAME%.tar}" && rm -f "${NAME}" done rm train/n04266014/n04266014_10835.JPEG cd .. mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
一切都已经完成,可以在使用ROCm的AMD GPU上进行Vision Mamba的准确性测试、训练和推理。
准确性测试
在这一部分,我们将评估Vision Mamba(小型)在ImageNet数据集上的表现。准确性测试将帮助验证模型在AMD GPU上与ROCm正常运行。
cd Vim python ./vim/main.py --eval --resume ./vim_s_midclstok_ft_81p6acc.pth \ --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \ --data-path ./image_dataset
输出应该类似于下面的内容。完整的训练日志可以在ROCm博客仓库中找到。
Namespace(batch_size=64, epochs=300, bce_loss=False, unscale_lr=False, model='vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, cosub=False, finetune='', attn_only=False, data_path='./image1k2012/tarfile', data_set='IMNET', inat_category='name', output_dir='', device='cuda', seed=0, resume='./vim_s_midclstok_ft_81p6acc.pth', start_epoch=0, eval=True, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, distributed=False, world_size=1, dist_url='env://', if_amp=True, if_continue_inf=False, if_nan2num=False, if_random_cls_token_position=False, if_random_token_rank=False, local_rank=0, gpu=None) Creating model: vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 number of params: 26001256 Test: [ 0/27] eta: 0:22:19 loss: 0.5674 (0.5674) acc1: 88.4896 (88.4896) acc5: 98.3333 (98.3333) time: 49.6074 data: 18.4177 max mem: 49186 Test: [10/27] eta: 0:03:52 loss: 0.7132 (0.6941) acc1: 83.8021 (85.2510) acc5: 97.3438 (97.2064) time: 13.6715 data: 1.6747 max mem: 49193 Test: [20/27] eta: 0:01:23 loss: 0.8571 (0.8317) acc1: 81.5625 (82.3289) acc5: 95.4167 (95.6324) time: 10.0834 data: 0.0003 max mem: 49193 Test: [26/27] eta: 0:00:11 loss: 0.9202 (0.8809) acc1: 80.2604 (81.5600) acc5: 94.2188 (95.4420) time: 9.6160 data: 0.0002 max mem: 49193 Test: Total time: 0:05:02 (11.2071 s / it) * Acc@1 81.560 Acc@5 95.442 loss 0.881 Accuracy of the network on the 50000 test images: 81.6%
这个结果与作者报告的结果几乎一致,确认了启用了ROCm的配置和使用Hipify_torch进行的修改工作正常。
在使用ROCm的AMD GPU上进行Vision Mamba分布式数据并行训练
训练可能会花费很长时间(即几天),如果你使用的是小型GPU。为了加速训练过程,可以使用DistributedDataParallel(分布式数据并行)来跨所有8块AMD Instinct MI210 GPU运行多个进程。根据你的GPU类型和内存容量,你需要调整batch_size和num_workers的值——要么增加它们以最大化资源利用率,要么减少它们以防止内存不足(OOM)问题。根据我们的测试,使用8块AMD Instinct MI210 GPU训练大约需要5个小时,使用8块AMD Instinct MI300X GPU训练大约需要2个小时。请注意,这些时间仅供参考,并未进行最快训练速度的优化。实际性能可能会根据你的具体设置而有所不同。
cd Vim torchrun --nnodes 1 --nproc_per_node 8 ./vim/main.py \ --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \ --batch-size 128 --lr 5e-6 --min-lr 1e-5 --warmup-lr 1e-5 --drop-path 0.0 --weight-decay 1e-8 --num_workers 8 \ --data-path ./image_dataset --output_dir ./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2\ --epochs 1 --finetune ./vim_s_midclstok_ft_81p6acc.pth --no_amp
Outputs: ```text [2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] [2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] ***************************************** [2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] 设置每个进程的OMP_NUM_THREADS环境变量为1以避免系统过载,请根据需要进一步调优该变量以获得最佳性能。 [2024-09-09 04:43:53,537] torch.distributed.run: [WARNING] ***************************************** Namespace(batch_size=128, epochs=1, bce_loss=False, unscale_lr=False, model='vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', input_size=224, drop=0.0, drop_path=0.0, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=1e-08, sched='cosine', lr=5e-06, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-05, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, cosub=False, finetune='./vim_s_midclstok_ft_81p6acc.pth', attn_only=False, data_path='./image_dataset', data_set='IMNET', inat_category='name', output_dir='./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=8, pin_mem=True, distributed=True, world_size=8, dist_url='env://', if_amp=False, if_continue_inf=False, if_nan2num=False, if_random_cls_token_position=False, if_random_token_rank=False, local_rank=0, gpu=0, rank=0, dist_backend='nccl') 创建模型: vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 参数数量: 26001256 开始训练,时长为1个epoch Epoch: [0] [ 0/1251] 预计用时: 4:05:42 学习率: 0.000010 损失: 2.8620 (2.8620) 时间: 11.7846 数据: 1.4589 最大内存: 58000 Epoch: [0] [ 10/1251] eta: 1:21:45 lr: 0.000010 loss: 2.8620 (2.7414) time: 3.9532 data: 0.1329 max mem: 58204 Epoch: [0] [ 20/1251] eta: 1:13:30 lr: 0.000010 loss: 2.7010 (2.7007) time: 3.1725 data: 0.0003 max mem: 58204 Epoch: [0] [ 30/1251] eta: 1:10:40 lr: 0.000010 loss: 2.7748 (2.7117) time: 3.2089 data: 0.0003 max mem: 58205 Epoch: [0] [ 40/1251] eta: 1:08:46 lr: 0.000010 loss: 2.7195 (2.6884) time: 3.2240 data: 0.0003 max mem: 58205 Epoch: [0] [ 50/1251] eta: 1:07:18 lr: 0.000010 loss: 2.6703 (2.6825) time: 3.1924 data: 0.0003 max mem: 58206 ... Epoch: [0] [1210/1251] eta: 0:02:13 lr: 0.000010 loss: 2.6030 (2.5552) time: 3.1808 data: 0.0004 max mem: 58206 Epoch: [0] [1220/1251] eta: 0:01:41 lr: 0.000010 loss: 2.5304 (2.5541) time: 3.1812 data: 0.0004 max mem: 58206 Epoch: [0] [1230/1251] eta: 0:01:08 lr: 0.000010 loss: 2.5217 (2.5542) time: 3.1813 data: 0.0004 max mem: 58206 Epoch: [0] [1240/1251] eta: 0:00:35 lr: 0.000010 loss: 2.6710 (2.5543) time: 3.1810 data: 0.0006 max mem: 58206 Epoch: [0] [1250/1251] eta: 0:00:03 lr: 0.000010 loss: 2.5232 (2.5541) time: 3.1798 data: 0.0005 max mem: 58206 Epoch: [0] Total time: 1:08:01 (3.2629 s / it) Averaged stats: lr: 0.000010 loss: 2.5232 (2.5604) Test: [ 0/14] eta: 1:01:08 loss: 0.7193 (0.7193) acc1: 83.8802 (83.8802) acc5: 97.0313 (97.0313) time: 262.0067 data: 77.2542 max mem: 99122 Test: [10/14] eta: 0:02:53 loss: 0.8385 (0.8442) acc1: 82.8906 (81.9058) acc5: 95.1042 (95.4380) time: 43.4364 data: 7.0233 max mem: 99137 Test: [13/14] eta: 0:00:37 loss: 0.8385 (0.9020) acc1: 81.4583 (81.5280) acc5: 95.0000 (95.3680) time: 37.1473 data: 5.5183 max mem: 99137 Test: Total time: 0:08:40 (37.1782 s / it) * Acc@1 81.469 Acc@5 95.298 loss 0.906 Accuracy of the network on the 50000 test images: 81.5% Max accuracy: 81.47% Training time 1:16:43
根据论文,`vim_s_midclstok_ft_81p6acc.pth` 检查点是在 ImageNet 数据集上进行微调的。在本博客中进行训练的目的是为了验证我们使用 “Hipify_torch” 进行修改后的分布式数据并行(DDP)训练在搭载 ROCm 的 AMD GPU 上能否正常工作,而不是通过进一步调整设置来超越现有结果。
训练完成后,检查点将存放在 ./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
目录中。
在AMD GPU上使用ROCm进行Vision Mamba推理
Vim可以用于推理任务,如图像分类、分割和检测。为了展示Vim在使用AMD GPU和ROCm时的推理能力,我们将使用前一步生成的模型进行图像分类任务。测试过程中使用的图片(`cab.png` 和 cat.jpeg
)和文件(`imagenet_class_index.json`)可以从 ROCm博客库获取。
pip install rope
import torch from PIL import Image from torchvision import transforms from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from vim.models_mamba import VisionMamba from vim.models_mamba import ( vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2) # 创建Vim模型并加载权重 model = vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2() sd = torch.load("./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2/best_checkpoint.pth") model.load_state_dict(sd["model"]) model.eval() model.to("cuda") def inference(model, image): ## 预处理图像 test_image = Image.open(image).convert('RGB') test_image.show() test_image = test_image.resize((224, 224)) image_as_tensor = transforms.ToTensor()(test_image) normalized_tensor = transforms.Normalize( IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD )(image_as_tensor) ## 使用Vision Mamba进行推理 x = normalized_tensor.unsqueeze(0).cuda() pred = model(x) ## 解码输出并打印类别 import json f = open('./src/imagenet_class_index.json') class_idx = json.load(f) idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))] print(f"label: class - {pred.argmax()}:{idx2label[pred.argmax()]}") inference(model,"./image/cab.png") inference(model,"./image/cat.jpeg")
输出:
标签: 类别 - 468:cab
标签: 类别 - 282:tiger_cat
输出结果看起来是正确的!模型正确识别了图像。这展示了Vision Mamba可以在使用ROCm的AMD GPU上用于推理任务。
总结
在这篇博客中,我们探讨了在AMD GPU上使用ROCm进行Vision Mamba,展示了其在视觉任务中的能力和性能。Hipify后的Vision Mamba有效利用AMD硬件进行训练和推理,提供了一种传统模型的强大替代方案。我们鼓励读者在使用ROCm进行计算机视觉应用时,尝试使用Vision Mamba。