MMFewShot
一、配置环境
先看看版本要求:
当然在我运行的时候,mmcv的版本还不能超过1.6.0,mmcv的安装跟torch相关,如果你的mmcv-full一直安装不上,那很大可能是你的torch版本不匹配,下面给出与torch版本匹配的mmcv-full安装方法。
安装环境指令:
conda create -n mmfewshot python=3.8
conda activate mmfewshot
#关于安装torch,建议直接使用pytorch官方安装链接
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
#安装mmcv-full
pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
# install mmclassification mmdetection
pip install mmcls==0.23.2 -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install mmdet==2.25.0 -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install mmfewshot==0.1.0
# install mmfewshot 如果出现Connection timed out,打开梯子即可
git clone https://github.com/open-mmlab/mmfewshot.git
cd mmfewshot
pip install -r requirements/build.txt
pip install -v -e . # or "python setup.py develop"
#后续:出现TypeError: FormatCode() got an unexpected keyword argument 'verify'
pip install yapf==0.40.0
创建好虚拟环境后就安装pytorch,pytorch官方链接,注意pytorch的版本不能太高,比如我之前安装的是1.13.0,最后导致mmcv-full版本太高(>1.6.0)所以后来改为了1.9.0,安装mmcv-full的指令根据:pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html来确定。然后后面的指令一次来就行。然后这里是官方环境安装文档:链接
二、下载mmfewshot
链接:GitHub - open-mmlab/mmfewshot: OpenMMLab FewShot Learning Toolbox and Benchmark
三、检验环境是否安装无误
运行tools->detection->train.py,如果出现
则表示环境安装成功,如果出现要求版本过高比如mmdet,mmcls\mmcv-full那就根据提示安装所能满足条件的最高版本,亲测有用!
再者,还有一种方法:
进__init__.py当中修改也可以:
四、数据集准备
如果使用coco格式,使用labelme2coco.py转换为coco格式的json文件
五、训练base_model
1.修改mmfewshot\detection\datasets\coco.py中的ALL_CLASSES、NOVEL_CLASSES和BASE_CLASSES,否则回报错:IndexError: FewShotCocoDataset: list index out of range
2.修改configs\detection\fsce\coco\fsce_r101_fpn_coco_base-training.py中的num_classes和max_iters
3.修改configs/detection/_base_/schedules/schedule.py学习率
4.查看配置文件所有参数请运行python tools/misc/print_config.py /PATH/TO/CONFIG
六、转换权重文件
1.修改tools\detection\misc\initialize_bbox_head.py中的COCO_NOVEL_CLASSES和COCO_BASE_CLASSES,与json文件里面的category_id有关,不是从0开始的,本例子hu是1,shu是2
2.修改tools\detection\misc\initialize_bbox_head.py中的COCO_TAR_SIZE
3.python -m tools.detection.misc.initialize_bbox_head --src1 work_dir/latest.pth --method random_init --save-dir ./ --coco。这里我是在initialize_bbox_head文件当中直接修改默认值不习惯使用终端,运行后会生成一个新的权重问价文件。
七、finetune
1.准备数据集,有需要的可以用prepare_coco_few_shot.py提取,并复制路径:data/few_shot_ann/coco/benchmark_10shot/下,benchmark_10shot文件夹是mmfewshot/detection/datasets/coco.py文件创建的,如果你跟我一样是不使用终端那么下边还会对路径进行修改,也就是修改相对路径
2.修改configs/detection/_base_/datasets/fine_tune_based/few_shot_coco.py当中的val的ann_file路径
3.修改mmfewshot/detection/datasets/coco.py中coco_benchmark的路径(终端好像不用修改)
4.修改configs\detection\fsce\coco\fsce_r101_fpn_coco_10shot-fine-tuning.py中的num_novel_shots和num_base_shots,max_iters
5.修改configs/detection/fsce/coco/fsce_r101_fpn_coco_10shot-fine-tuning.py的load_from权重路径
5.修改configs/detection/_base_/datasets/fine_tune_based/few_shot_coco.py中data_root的路径,该路径是存放图片的路径
Meta R-CNN:
1.数据集准备(按照VOC或者COCO的格式后面给出代码将自定义数据集修改为标准的VOC和COCO数据格式)
2.修改./mmfewshot/detection/datasets/voc.py下的VOC_SPLIT
3.修改configs/detection/_base_/datasets/nway_kshot/base_voc.py下的data_root路径、num_support_ways等参数
注意:
dataset=dict(
type='FewShotVOCDataset',
ann_cfg=[
dict(
type='ann_file',
ann_file=data_root +
'VOC2007/ImageSets/Main/trainval.txt'),
#规范化后的自定义数据集应该是没有VOC2012的,这里注释
# dict(
# type='ann_file',
# ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt')
],
img_prefix=data_root,
multi_pipelines=train_multi_pipelines,
classes=None,
use_difficult=True,
instance_wise=False,
dataset_name='query_dataset'),
4.修改./configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_base-training.py 当中的一些超参数比如max_iters、lr、num_classes、num_meta_classes。
5.训练结束后修改./tools/detection/misc/initialize_bbox_head.py下的VOC_TAR_SIZE(种类数)、--src1(该参数是第一阶段的权重路径)、save-dir(保存重构检测头后的权重路径)和method。
6.修改./meta_rcnn_mmfewshot/configs/detection/_base_/datasets/nway_kshot/few_shot_voc.py下的num_support_ways(总共类别数)、data_root
7.修改./meta_rcnn_mmfewshot/configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py下的load_from、optimizer
8.修改train.py下的config、work-dir
9.修改./meta_rcnn_mmfewshot/configs/detection/meta_rcnn/meta-rcnn_r101_c4.py下的num_classes、num_meta_classes
八、模型创建
#源代码
model = build_detector(cfg.model, logger=logger)
build_detector()是mmfewshot或者说是mmdetection的一种模型创建方法,
以上是使用过程中的一些细节,仅供参考哦