当前位置: 首页 > article >正文

Resnet50进行迁移学习实现图片二分类

内容简介

本文使用预训练的Resnet50网络对皮肤病图片进行二分类,基于portch框架。

数据集说明

数据集存放目录为: used_dataset , 共200张图片,标签为:benign(良性)、malignant(患病)。

数据集划分如下:

代码目录介绍

  • args.py 存放训练和测试所用的各种参数。 --mode字段表示运行模式:train or test. --model_path字段是训练模型的保存路径。 其余字段都有默认值。
  • create_dataset.py 该脚本是用来读json中的数据的,可以忽略。
  • data_gen.py 该脚本实现划分数据集以及数据增强和数据加载。
  • main.py 包含训练、评估和测试。
  • transform.py 实现图片增强。
  • utils.py 存放一些工具函数。
  • models/Res.py 是重写的ResNet各种类型的网络。
  • checkpoints 保存模型

运行命令

# 训练模型
python main.py --mode=train
# 测试模型
python main.py --mode=test --model_path='训练好的模型文件路径'   

main.py 脚本介绍

main()函数 实现模型的训练和评估

step1: 加载数据

# data
    transformations = get_transforms(input_size=args.image_size,test_size=args.image_size)
    train_set = data_gen.Dataset(root=args.train_txt_path,transform=transformations['val_train'])
    train_loader = data.DataLoader(train_set,batch_size=args.batch_size,shuffle=True)

    val_set = data_gen.ValDataset(root=args.val_txt_path,transform=transformations['val_test'])
    val_loader = data.DataLoader(val_set,batch_size=args.batch_size,shuffle=False)

step2: 构建模型

model = make_model(args)
    if use_cuda:
        model.cuda()
    
    # define loss function and optimizer
    if use_cuda:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss()
	
    optimizer = get_optimizer(model,args)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=False)

step3: 模型的训练和评估

# train
    for epoch in range(start_epoch,args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        train_loss,train_acc = train(train_loader,model,criterion,optimizer,epoch,use_cuda)
        test_loss,val_acc = val(val_loader,model,criterion,epoch,use_cuda)

        scheduler.step(test_loss)

        print(f'train_loss:{train_loss}\t val_loss:{test_loss}\t train_acc:{train_acc} \t val_acc:{val_acc}')

        # save_model
        is_best = val_acc >= best_acc
        best_acc = max(val_acc, best_acc)

        save_checkpoint({
                    'fold': 0,
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'train_acc':train_acc,
                    'acc': val_acc,
                    'best_acc': best_acc,
                    'optimizer' : optimizer.state_dict(),
                }, is_best, single=True, checkpoint=args.checkpoint)

    print("best acc = ",best_acc)

train()函数 每个epoch下的模型训练过程

主要实现每个批次下梯度的反向传播,计算accuarcy 和 loss, 并更新,最后返回其均值。

def train(train_loader,model,criterion,optimizer,epoch,use_cuda):
    model.train()
    losses = AverageMeter()
    train_acc = AverageMeter()

    for (inputs,targets) in tqdm(train_loader):
        if use_cuda:
            inputs,targets = inputs.cuda(),targets.cuda(async=True)
        inputs,targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # 梯度参数设为0
        optimizer.zero_grad()

        # forward
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        acc = accuracy(outputs.data,targets.data)
        # inputs.size(0)=32
        losses.update(loss.item(), inputs.size(0))
        train_acc.update(acc.item(),inputs.size(0))

    return losses.avg,train_acc.avg

val()函数 每个epoch下的模型评估过程

主要代码与train()函数一致,但没有梯度的计算,还有将model.train()改成model.eval()。

def val(val_loader,model,criterion,epoch,use_cuda):
    global best_acc
    losses = AverageMeter()
    val_acc = AverageMeter()

    model.eval() # 将模型设置为验证模式
    # 混淆矩阵
    confusion_matrix = meter.ConfusionMeter(args.num_classes)
    for _,(inputs,targets) in enumerate(val_loader):
        if use_cuda:
            inputs,targets = inputs.cuda(),targets.cuda()
        inputs,targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        confusion_matrix.add(outputs.data.squeeze(),targets.long())
        acc1 = accuracy(outputs.data,targets.data)

        # measure accuracy and record loss
        losses.update(loss.item(), inputs.size(0))
        val_acc.update(acc1.item(),inputs.size(0))
    return losses.avg,val_acc.avg

test()函数 模型的测试

def test(use_cuda):
    # data
    transformations = get_transforms(input_size=args.image_size,test_size=args.image_size)
    test_set = data_gen.TestDataset(root=args.test_txt_path,transform= transformations['test'])
    test_loader = data.DataLoader(test_set,batch_size=args.batch_size,shuffle=False)
    # load model
    model = make_model(args)
    
    if args.model_path:
        # 加载模型
        model.load_state_dict(torch.load(args.model_path))

    if use_cuda:
        model.cuda()

    # evaluate
    y_pred = []
    y_true = []
    img_paths = []
    with torch.no_grad():
        model.eval() # 设置成eval模式
        for (inputs,targets,paths) in tqdm(test_loader):
            y_true.extend(targets.detach().tolist())
            img_paths.extend(list(paths))
            if use_cuda:
                inputs,targets = inputs.cuda(),targets.cuda()
            inputs,targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
            # compute output
            outputs = model(inputs)  # (16,2)
            # dim=1 表示按行计算 即对每一行进行softmax
            # probability = torch.nn.functional.softmax(outputs,dim=1)[:,1].tolist()
            # probability = [1 if prob >= 0.5 else 0 for prob in probability]
            # 返回最大值的索引
            probability = torch.max(outputs, dim=1)[1].data.cpu().numpy().squeeze()
            y_pred.extend(probability)
        print("y_pred=",y_pred)

        accuracy = metrics.accuracy_score(y_true,y_pred)
        print("accuracy=",accuracy)
        confusion_matrix = metrics.confusion_matrix(y_true,y_pred)
        print("confusion_matrix=",confusion_matrix)
        print(metrics.classification_report(y_true,y_pred))
        # fpr,tpr,thresholds = metrics.roc_curve(y_true,y_pred)
        print("roc-auc score=",metrics.roc_auc_score(y_true,y_pred))
    
        res_dict = {
            'img_path':img_paths,
            'label':y_true,
            'predict':y_pred,

        }
        df = pd.DataFrame(res_dict)
        df.to_csv(args.result_csv,index=False)
        print(f"write to {args.result_csv} succeed ")

 实验结果


http://www.kler.cn/a/374669.html

相关文章:

  • 闪存学习_1:Flash-Aware Computing from Jihong Kim
  • mybatis 参数判断报错的问题
  • 【第一个qt项目的实现和介绍以及程序分析】【正点原子】嵌入式Qt5 C++开发视频
  • 【Python】动态与静态的较量:深入探讨Python的动态类型机制与类型提示的应用
  • c# WinForm弹出窗体时不获取焦点方法
  • Rust:文档注释 //! 和 ///
  • vue vxeui 上传组件 vxe-upload 全局配置上传方法,显示上传进度,最完美的配置方案
  • 音视频听译:助力多维度沟通与发展的大门
  • 预告帖|在MATLAB/Simulink中调用C语言的几种方法
  • 【neo4j】 neo4j cypher单一语句 optional 可选操作的技巧
  • 【CSS in Depth 2 精译_055】8.3 伪类 :is() 和 :where() 的正确打开方式
  • JS 字符串拼接并去重
  • Java 判断回文数
  • 乐鑫ESP32-S3无线AI语音方案,教育机器人交互应用,启明云端乐鑫代理商
  • Linux补基础之:网络配置
  • 笔试题 求空格分割的英文句子中,最大单词长度。
  • 大语言模型推理代码构建(基于llama3模型)
  • 2001-2023年A股上市公司数字化转型数据(MDA报告词频统计)(三种方法)
  • (51)MATLAB迫零均衡器系统建模与性能仿真
  • python使用pymysql
  • 关于我、重生到500年前凭借C语言改变世界科技vlog.13——深入理解指针(3)
  • Glide 简易教程
  • 【Rust标准库中的convert(AsRef,From,Into,TryFrom,TryInto)】
  • PyQt5信号与槽一
  • 【抽代复习笔记】34-群(二十八):不变子群的几道例题
  • .net core中间件Polly