当前位置: 首页 > article >正文

如何用mmclassification训练多标签多分类数据

这里使用的源码版本是 mmclassification-0.25.0
训练数据标签文件的格式如下:每行以最后一个空格为界,前面是图像路径(图像文件的绝对路径),后面是标签名。由于特殊需求,我这里每张图像都记录了三个标签,标签之间用英文逗号“,”分隔(标签个数按自己的需求确定);标签名内部及逗号前后都不能有空格。我的训练类别总数是17个。
(原文此处为标签文件截图,格式示例如下,类别名仅作示意:)

    /data/images/0001.jpg cat,dog,car
    /data/images/0002.jpg dog,tree,sky
训练参数配置文件如下:使用ResNet作为特征提取主干;多标签分类需要使用MultiLabelLinearClsHead作为分类头。数据集类型使用CustomDataset,并需要修改其定义文件(custom.py),详见后文。

# Training config for mmclassification 0.25: ResNet-18 backbone +
# MultiLabelLinearClsHead on a CustomDataset whose load_annotations is
# expected to be patched to produce multi-hot label vectors.

# checkpoint saving
# interval=1: save a checkpoint after every epoch (the runner is epoch-based).
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=100,  # log every 100 iterations
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')  # only relevant for distributed launches
log_level = 'INFO'
load_from = None  # set to a checkpoint path to fine-tune from it
resume_from = None  # set to a checkpoint path to resume an interrupted run
workflow = [('train', 1)]
# NOTE(review): lr=0.1 with samples_per_gpu=32 -- confirm it suits your
# total batch size (number of GPUs x samples_per_gpu).
optimizer = dict(lr=0.1, momentum=0.9, type='SGD', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
runner = dict(max_epochs=100, type='EpochBasedRunner')
# Step schedule: the learning rate is decayed at epochs 30, 60 and 90.
lr_config = dict(
    policy='step', step=[
        30,
        60,
        90,
    ])

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet',depth=18,num_stages=4,out_indices=(3, ),style='pytorch'), 
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiLabelLinearClsHead',
        # Must equal len(CLASSES) of the dataset, because load_annotations
        # builds label vectors of that length (17 labels here).
        num_classes=17,
        # Presumably the channel width of ResNet-18's last stage; update it
        # if the backbone depth changes.
        in_channels=512,
    ))

# CustomDataset with the patched multi-label load_annotations
# ('MultiLabelDataset' is mentioned as an alternative but not used here).
dataset_type = 'CustomDataset'          #'MultiLabelDataset'
# Per-channel mean/std on the 0-255 scale; to_rgb=True converts the channel
# order to RGB before normalising.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    # gt_label is the multi-hot vector built by load_annotations; convert it
    # to a tensor and collect it so the head's loss can use it.
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
# No gt_label here: presumably evaluation reads labels from the dataset
# itself rather than from the pipeline -- verify for your mmcls version.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

# Replace 'rootpath' with your dataset root. Each annotation file holds one
# "<image path> <label1,label2,...>" line per image.
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='rootpath/images',
        ann_file='rootpath/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='rootpath/images',
        ann_file='rootpath/val.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_prefix='rootpath/images',
        ann_file='rootpath/test.txt',
        pipeline=test_pipeline))

# 'accuracy' is computed by accuracy_torch. With multi-hot targets, the
# patched version compares top-k predictions against target.argmax(dim=1)
# (the first positive class), so it is only a rough multi-label indicator.
evaluation = dict(interval=1, metric='accuracy')

其他需要修改的地方:
1、修改加载数据的格式,将./mmclassification-0.25.0/mmcls/datasets/custom.py的CustomDataset里面的load_annotations函数替换成下面的函数:

    ### Multi-label annotation loading (replaces the stock single-label loader) ###
    def load_annotations(self):
        """Load image paths and multi-hot ``gt_label`` vectors.

        Each annotation line is ``<image path> <labels>``. Everything after the
        last space is the label field: a comma-separated list whose entries are
        class names from ``self.CLASSES`` (e.g. ``cat,dog``) or integer class
        indices (e.g. ``3``). A token that is also a class name is treated as
        a name. Without ``ann_file``, samples come from ``self._find_samples()``
        and their label may be a plain integer index.

        Returns:
            list[dict]: One dict per sample with ``img_prefix``, ``img_info``
            (``{'filename': ...}``) and ``gt_label``, a float vector of length
            ``len(self.CLASSES)`` with 1 at every listed class and 0 elsewhere.

        Raises:
            TypeError: If ``self.ann_file`` is neither a str nor None.
            ValueError: If a line has no label field, or a label is neither a
                known class name nor a valid class index.
        """
        if self.ann_file is None:
            samples = self._find_samples()
        elif isinstance(self.ann_file, str):
            lines = mmcv.list_from_file(
                self.ann_file, file_client_args=self.file_client_args)
            # Skip blank lines (e.g. a trailing newline): they cannot be
            # unpacked into (filename, labels).
            samples = [x.strip().rsplit(' ', 1) for x in lines if x.strip()]
        else:
            raise TypeError('ann_file must be a str or None')

        class_names = list(self.CLASSES)
        num_classes = len(class_names)

        def to_index(token, filename):
            """Map one label token (class name or index) to its class index."""
            if not isinstance(token, str):
                idx = int(token)  # integer label, e.g. from _find_samples()
            elif token in class_names:
                idx = class_names.index(token)
            elif token.isdecimal():
                idx = int(token)
            else:
                raise ValueError(
                    f'Unknown label {token!r} for {filename!r}: it is neither '
                    'in CLASSES nor a class index.')
            if not 0 <= idx < num_classes:
                raise ValueError(
                    f'Label index {idx} for {filename!r} is out of range '
                    f'[0, {num_classes}).')
            return idx

        data_infos = []
        for sample in samples:
            if len(sample) != 2:
                raise ValueError(
                    f'Malformed annotation {sample!r}: expected '
                    '"<image path> <labels>".')
            filename, gt_label = sample
            if isinstance(gt_label, str):
                tokens = [t.strip() for t in gt_label.split(',') if t.strip()]
            else:
                tokens = [gt_label]
            multi_hot = np.zeros(num_classes)
            for token in tokens:
                multi_hot[to_index(token, filename)] = 1
            data_infos.append({
                'img_prefix': self.data_prefix,
                'img_info': {'filename': filename.strip()},
                'gt_label': multi_hot,
            })
        return data_infos

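记得在初始化函数__init__里把类别列表CLASSES修改成自己要训练的类别。CLASSES中的顺序就是标签向量中各位的顺序。(原文此处为代码截图,示例:CLASSES = ['cat', 'dog', 'car', ...]。也可以尝试在配置文件的数据集dict中通过classes=[...]参数传入类别列表,请以所用版本为准。)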

2、修改评估函数:将./mmclassification-0.25.0/mmcls/models/losses/accuracy.py中的accuracy_torch函数替换为如下函数。这里只是增加了一些度量,方便查看多标签的各项指标,并没有修改其他地方,因此训练时验证的仍是原来的指标。函数中用到的Metric类需要自行定义或导入,可参考这篇文章:https://blog.csdn.net/u013250861/article/details/122727704 ;另外还需要安装scikit-learn。

def _log_multilabel_report(pred, target, thr=0.5):
    """Log multi-label diagnostics (sklearn report + ``Metric`` scores).

    Args:
        pred (torch.Tensor): Per-class scores, shape (N, C). Assumed to be
            probabilities (e.g. sigmoid outputs) -- TODO confirm for your head.
        target (torch.Tensor): Multi-hot ground truth, shape (N, C).
        thr (float): Scores strictly greater than this count as positive.
    """
    from mmcls.utils import get_root_logger
    from sklearn.metrics import classification_report

    logger = get_root_logger()

    # .numpy() only works on CPU tensors that do not require grad.
    pred_bin = (pred > thr).float().detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()

    # Per-class precision / recall / F1. Pass target_names=[...] to print
    # class names instead of indices.
    class_report = classification_report(target_np, pred_bin)
    logger.info("\nClassification Report:\n{}".format(class_report))

    # NOTE: `Metric` is not defined here; it must be defined in, or imported
    # into, this module by the user.
    metric = Metric(pred_bin, target_np)
    ham = metric.hamming_distance()
    avgPrecision, _ = metric.avgPrecision()
    avgRecall, _, _ = metric.avgRecall()
    ranking_loss = metric.get_ranking_loss()
    accuracy_multiclass = metric.accuracy_multiclass()
    logger.info("\nHam:{}\tAvgPrecision:{}\tAvgRecall:{}\tRanking_loss:{}\tAccuracy_Multilabel:{}".format(ham, avgPrecision, avgRecall, ranking_loss, accuracy_multiclass))


def accuracy_torch(pred, target, topk=(1,), thrs=0.):
    """Top-k accuracy, with extra diagnostics for multi-label targets.

    For a 2-D (multi-hot) ``target``, a multi-label report is logged and the
    target is then reduced to ``argmax`` (its first positive class), so top-k
    accuracy measures whether that class is among the top-k predictions.
    A 1-D ``target`` is treated as ordinary single-label class indices.

    Args:
        pred (torch.Tensor): Prediction scores, shape (N, C).
        target (torch.Tensor): Class indices of shape (N,), or a multi-hot
            matrix of shape (N, C).
        topk (tuple[int]): The k values to evaluate.
        thrs (Number | tuple[Number]): A prediction only counts as correct if
            its score is strictly greater than the threshold.

    Returns:
        list: One entry per k. Each entry is a 1-element tensor holding the
        accuracy in percent when ``thrs`` is a number, or a list of such
        tensors (one per threshold) when ``thrs`` is a tuple.

    Raises:
        TypeError: If ``thrs`` is neither a number nor a tuple.
    """
    if isinstance(thrs, Number):
        thrs = (thrs,)
        res_single = True
    elif isinstance(thrs, tuple):
        res_single = False
    else:
        raise TypeError(f'thrs should be a number or tuple, but got {type(thrs)}.')

    res = []
    maxk = max(topk)
    num = pred.size(0)
    pred = pred.float()

    if target.dim() == 2:
        # Multi-hot targets: log the multi-label metrics, then fall back to
        # the first positive class so the stock top-k logic below applies.
        _log_multilabel_report(pred, target)
        target = target.argmax(dim=1)

    pred_score, pred_label = pred.topk(maxk, dim=1)
    pred_label = pred_label.t()
    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
    for k in topk:
        res_thr = []
        for thr in thrs:
            # Only prediction values larger than thr are counted as correct
            _correct = correct & (pred_score.t() > thr)
            correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res_thr.append((correct_k.mul_(100. / num)))
        if res_single:
            res.append(res_thr[0])
        else:
            res.append(res_thr)
    return res

3、修改推理部分:将./mmclassification-0.25.0/mmcls/apis/inference.py中的inference_model函数修改如下。推理多标签时,会输出所有得分大于0.5的标签类别。

def inference_model(model, img, threshold=0.5):
    """Run multi-label inference on a single image.

    Args:
        model (nn.Module): The loaded classifier; must carry ``cfg`` and
            ``CLASSES`` attributes.
        img (str/ndarray): The image filename or loaded image.
        threshold (float): A class is reported when its score is strictly
            greater than this value. Defaults to 0.5.

    Returns:
        dict: ``pred_label`` (class indices), ``pred_score`` (float scores)
        and ``pred_class`` (entries of ``model.CLASSES``) as parallel lists
        in ascending class-index order. All lists are empty when no class
        exceeds ``threshold``.

    Note:
        ``model.cfg.data.test.pipeline`` is modified in place: a
        ``LoadImageFromFile`` step is inserted for paths and removed for
        arrays.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    if isinstance(img, str):
        if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
            cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_info=dict(filename=img), img_prefix=None)
    else:
        if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
            cfg.data.test.pipeline.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    with torch.no_grad():
        scores = model(return_loss=False, **data)

    # A single image was collated, so scores[0] is its per-class score vector.
    result = {'pred_label': [], 'pred_score': [], 'pred_class': []}
    for idx, score in enumerate(scores[0]):
        if score > threshold:
            result['pred_label'].append(idx)
            result['pred_score'].append(float(score))
            result['pred_class'].append(model.CLASSES[idx])
    return result

或者直接使用以下推理脚本:

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import warnings
import os
import mmcv
import torch
import numpy as np
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcls.datasets.pipelines import Compose
from mmcls.models import build_classifier


def init_model(config, checkpoint=None, device='cuda:0', options=None):
    """Build a classifier from a config and optionally load trained weights.

    Args:
        config (str or :obj:`mmcv.Config`): Path of a config file, or an
            already-loaded Config object.
        checkpoint (str, optional): Checkpoint to load. When None, no weights
            are loaded.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.
        options (dict, optional): Overrides merged into the config before the
            model is built.

    Returns:
        nn.Module: The classifier in eval mode, with ``cfg`` attached and,
        when a checkpoint is given, ``CLASSES`` attached too.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if not isinstance(config, (str, mmcv.Config)):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)

    if options is not None:
        config.merge_from_dict(options)
    config.model.pretrained = None
    model = build_classifier(config.model)

    if checkpoint is not None:
        # Load onto the CPU first: mapping weights straight to the GPU may
        # leak video memory, see
        # https://github.com/open-mmlab/mmdetection/pull/6405
        ckpt = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = ckpt.get('meta', {})
        if 'CLASSES' in meta:
            model.CLASSES = meta['CLASSES']
        else:
            from mmcls.datasets import ImageNet
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = ImageNet.CLASSES

    # Keep the config on the model; inference_model rebuilds its test
    # pipeline from model.cfg.
    model.cfg = config
    model.to(device)
    model.eval()
    return model


def inference_model(model, img, threshold=0.5):
    """Run multi-label inference on a single image.

    Args:
        model (nn.Module): The loaded classifier; must carry the ``cfg`` and
            ``CLASSES`` attributes set by ``init_model``.
        img (str/ndarray): The image filename or loaded image.
        threshold (float): A class is reported when its score is strictly
            greater than this value. Defaults to 0.5.

    Returns:
        dict: ``pred_label`` (class indices), ``pred_score`` (scores rounded
        to 4 decimals) and ``pred_class`` (entries of ``model.CLASSES``) as
        parallel lists in ascending class-index order. All lists are empty
        when no class exceeds ``threshold``.

    Note:
        ``model.cfg.data.test.pipeline`` is modified in place: a
        ``LoadImageFromFile`` step is inserted for paths and removed for
        arrays.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    if isinstance(img, str):
        if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
            cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_info=dict(filename=img), img_prefix=None)
    else:
        if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
            cfg.data.test.pipeline.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    with torch.no_grad():
        scores = model(return_loss=False, **data)

    # A single image was collated, so scores[0] is its per-class score vector.
    result = {'pred_label': [], 'pred_score': [], 'pred_class': []}
    for idx, score in enumerate(scores[0]):
        if score > threshold:
            result['pred_label'].append(idx)
            result['pred_score'].append(round(float(score), 4))
            result['pred_class'].append(model.CLASSES[idx])
    return result


def show_result(img, result, out_file, show=True):
    """Draw an image titled with its predicted classes and scores.

    Args:
        img (ndarray): Image to draw. matplotlib displays it as-is, so it
            should be in RGB channel order.
        result (dict): Output of ``inference_model``; its ``pred_class`` and
            ``pred_score`` lists form the title.
        out_file (str | None): If given, the figure is saved to this path.
        show (bool): Whether to also call ``plt.show()``. Defaults to True.
    """
    import matplotlib.pyplot as plt

    # Use a fresh figure per call and close it afterwards. Drawing on the
    # implicit current figure can stack every image of a directory run onto
    # the same axes (e.g. on non-interactive backends) and keeps them in memory.
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title(f'{result["pred_class"]}: {result["pred_score"]}')
    ax.axis('off')
    if out_file is not None:
        fig.savefig(out_file)
    if show:
        plt.show()
    plt.close(fig)


def save_result(imgpath, result, outfile="result.txt"):
    """Append one line "<imgpath><TAB><class1,class2,...>" to ``outfile``.

    Args:
        imgpath (str): Image path written as the first column.
        result (dict): Output of ``inference_model``; only ``pred_class`` is
            used. An empty list yields an empty second column.
        outfile (str): Text file to append to (created if missing).
    """
    # Explicit UTF-8: class names may be non-ASCII (e.g. Chinese) and the
    # platform default encoding may not be able to represent them.
    with open(outfile, "a", encoding="utf-8") as f:
        f.write(imgpath + "\t" + ",".join(map(str, result["pred_class"])) + "\n")

def main():
    """Command-line entry point for multi-label inference.

    Without ``--img``, every readable image file in ``--imgpath`` is
    classified. Each result is printed, appended to
    ``<outpath>/result.txt`` and rendered to ``<outpath>/<stem>_res<ext>``.
    With ``--img``, only that image is classified, printed and appended to
    ``result.txt``.

    Raises:
        FileNotFoundError: If the requested image file or directory does not
            exist.
    """
    parser = ArgumentParser()
    parser.add_argument('--imgpath', default="./images", help='Directory of images to classify')
    parser.add_argument('--img', default=None, help='Single image file (takes precedence over --imgpath)')
    parser.add_argument('--outpath', default="./res", help='Output directory for result.txt and visualisations')
    parser.add_argument('--config', default="config.py",  help='Config file')
    parser.add_argument('--checkpoint', default="./epoch_100.pth",  help='Checkpoint file')
    parser.add_argument('--device', default='cuda:0', help='Device used for inference')
    parser.add_argument('--threshold', type=float, default=0.5, help='Score threshold for reporting a class')
    args = parser.parse_args()

    # Validate inputs before the (slow) model build so a typo fails fast.
    if args.img is not None:
        if not os.path.isfile(args.img):
            raise FileNotFoundError('No such file or directory: {}'.format(args.img))
    elif not os.path.isdir(args.imgpath):
        raise FileNotFoundError('No such file or directory: {}'.format(args.imgpath))

    os.makedirs(args.outpath, exist_ok=True)
    result_file = os.path.join(args.outpath, "result.txt")

    model = init_model(args.config, args.checkpoint, device=args.device)

    if args.img is not None:
        result = inference_model(model, args.img, threshold=args.threshold)
        print("img_path: ", args.img, result)
        save_result(args.img, result, outfile=result_file)
        return

    for imgname in sorted(os.listdir(args.imgpath)):
        img_path = os.path.join(args.imgpath, imgname)
        if not os.path.isfile(img_path):
            continue  # skip sub-directories; they are not images
        img = mmcv.imread(img_path)
        if img is None:
            continue  # not a decodable image
        result = inference_model(model, img, threshold=args.threshold)
        print("img_path: ", img_path, result)
        save_result(img_path, result, outfile=result_file)
        stem, ext = os.path.splitext(imgname)
        # mmcv.imread returns BGR, while matplotlib expects RGB.
        show_result(mmcv.bgr2rgb(img), result,
                    out_file=os.path.join(args.outpath, f'{stem}_res{ext}'))



if __name__ == '__main__':
    # Run the command-line interface only when executed as a script.
    main()

通过以上修改,就可以完成多标签分类模型的训练、评估和推理了。
由于我没有找到mmcls官方的多标签训练教程,因此做了上述修改。如果有其他更方便有效的多标签多分类方法或者项目,欢迎在该文章下面留言,非常感谢。

参考文章
https://blog.csdn.net/litt1e/article/details/125316552
https://blog.csdn.net/u013250861/article/details/122727704


http://www.kler.cn/news/367537.html

相关文章:

  • nginx 路径匹配,关于“/“对规则的影响
  • Spring 设计模式之工厂模式
  • AutoSar AP CM服务接口级别的数据类型总结
  • 数据结构------手撕链表(一)【不带头单向非循环】
  • 免费PDF页面提取小工具
  • 从汇编角度看C/C++函数指针与函数的调用差异
  • 如何理解前端与后端开发
  • entwine 和 conda环境下 使用和踩坑 详细步骤! 已解决
  • uptime kuma拨测系统
  • 身份证归属地查询接口-在线身份证归属地查询-身份证归属地查询API
  • 论文略读:Less is More: on the Over-Globalizing Problem in Graph Transformers
  • 2FA-双因素认证
  • 基于Python的智能求职分析系统
  • python 使用 企微机器人发送消息
  • 安全日志记录的重要性
  • 今天不分享技术,分享秋天的故事
  • Spring Boot框架下的厨艺社区开发
  • ALLO数据集:首个为月球轨道机器人近距离操作设计的异常检测基准开源数据集。
  • 安全知识见闻-脚本语言对与安全的重要性
  • Spring Boot驱动的厨艺分享社区开发
  • 5G工业路由器智能电网部署实录:一天内解决供电、网络
  • 手机在网状态查询接口-在线手机在网状态查询-手机在网状态查询API
  • vue2 关于组件
  • react mackDowan 渲染文本,图片,视频
  • Vue3实现获取验证码按钮倒计时效果
  • 深入解析机器学习算法