ms-swift3 序列分类训练
目录
引言
一、数据集准备
二、训练/推理代码
2.1 训练
2.2 推理
三、性能验证
引言
swift 3.x支持了序列分类Command Line Parameters — swift 3.2.0.dev0 documentation
想尝试一下用多模态(图像)的序列分类与普通的图像分类任务有啥区别
一、数据集准备
根据官方给出的自定义序列分类数据集格式Custom Dataset — swift 3.2.0.dev0 documentation 和多模态数据集的格式Custom Dataset — swift 3.2.0.dev0 documentation
可以结合二者得到一个官方支持的多模态数据序列分类格式,简单来说就是把多模态SFT数据集中的assistant的字段改成label字段
{"messages": [{"role": "user", "content": "<image><image>What is the difference between the two images?"}], "images": ["/xxx/x.jpg", "/xxx/x.png"], "label": 0}
当然在这里也可以自行注册一个新数据集格式,不过太麻烦,而且可迁移性就降低了。
具体地,比如一个分类任务,可以构造这样的一个分类数据集cat_cls.jsonl文件
{"messages": [{"role": "user", "content": "<image>这是什么品种的猫?"}], "images": ["/xxx/1.jpg"], "label": 0}
{"messages": [{"role": "user", "content": "<image>这是什么品种的猫?"}], "images": ["/xxx/2.jpg"], "label": 1}
注意label是从0开始计数
二、训练/推理代码
2.1 训练
根据官方给的example:https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls
以qwen2-vl-base模型为例,训练的代码为:
# If `num_labels` is provided, it will be considered a classification task.
# You can also specify `--model Qwen/Qwen2-VL-2B-Instruct --use_chat_template true`.
CUDA_VISIBLE_DEVICES=0 \
MAX_PIXELS=1003520 \
swift sft \
--model Qwen/Qwen2-VL-2B \
--train_type lora \
--dataset 'tany0699/garbage265#20000' \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--gradient_accumulation_steps 16 \
--eval_steps 50 \
--save_steps 50 \
--save_total_limit 5 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--num_labels 265 \
--task_type seq_cls \
--use_chat_template false
其中主要修改--model,--dataset,--num_labels这三个参数。
注意开源instruct模型以及swift 3.x训练得到的模型可以直接用--model 单参数传模型文件路径。
但是swift 2.x训练保存的模型还需要额外加--model_type参数 比如--model_type qwen2_vl
原理是3.x保存的模型里已有这个参数 之前的版本保存的模型没有 会有一个不兼容的问题。
此外base模型不需要chat_template,但是instruct模型是需要的,不过实测在序列分类任务上这个参数对最后训练的模型性能影响不大。
具体地,--model 设置为本地保存的模型路径,--dataset为上面构造的数据集jsonl文件路径,--num_labels为类别数量,注意从0开始计数,有几类填几(int)
2.2 推理
训练保存的模型形式和推理的代码基本和普通SFT没区别:https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/infer.sh
## full 全参数训练
CUDA_VISIBLE_DEVICES=0 \
MAX_PIXELS=1003520 \
swift infer \
--model output/vx-xxx/checkpoint-xxx \
--load_data_args true
## lora训练
CUDA_VISIBLE_DEVICES=0 \
MAX_PIXELS=1003520 \
swift infer \
--adapter output/vx-xxx/checkpoint-xxx \
--load_data_args true
这里有一个坑,官方给出的example由于是lora训练,因此infer的脚本也只适配lora模型,具体来说就是用--adapter 指定保存下来的模型文件夹路径。但是对于全参数训练(full sft),需要用--model代替--adapter,不然实测输出全是某一个label。
此外注意确认训练的时候指定了val_dataset,否则--load_data_args要去掉 用--val_dataset代替。
三、性能验证
最终在某个自定义的图片分类数据集上试了一下,整体acc下降,但是推理速度能有提升。考虑到准确率,还是继续用causal_lm来进行图像分类训练了。
训练方式 | 精度-Acc | 推理速度 |
causal_lm | 82.5% | 3.79it/s |
seq_cls | 72.76% | 4.97it/s |