当前位置: 首页 > article >正文

偏标记学习+图像分类(论文复现)

偏标记学习+图像分类(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • 偏标记学习+图像分类(论文复现)
        • 概述
        • 算法原理
        • 核心逻辑
        • 效果演示
        • 使用方式

概述

本文复现论文提出的偏标记学习方法,随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的

在这里插入图片描述

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关

算法原理

传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下

在这里插入图片描述

在这里插入图片描述

核心逻辑

具体的核心逻辑如下所示:

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

def CE_loss(probs, targets):
    """交叉熵损失函数"""
    loss = -torch.sum(targets * torch.log(probs), dim = -1)
    loss_avg = torch.sum(loss)/probs.shape[0]
    return loss_avg

class Proden:
    def __init__(self, configs):
        self.configs = configs
    
    def train(self, save = False):
        configs = self.configs
        # 读取数据集
        dataset_path = configs['dataset path']
        if configs['dataset'] == 'CIFAR-10':
            train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 10
        elif configs['dataset'] == 'CIFAR-100':
            train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 100
        # 生成偏标记
        partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 计算数据的均值和方差,用于模型输入的标准化
        mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
        std = [np.std(train_data[:, i, :, :]) for i in range(3)]
        normalize = transforms.Normalize(mean, std)
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        if configs['model'] == 'ResNet18':
            model = models.ResNet18(output_dimension = output_dimension).to(device)
        elif configs['model'] == 'ConvNet':
            model = models.ConvNet(output_dimension = output_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning rate']
        weight_decay = configs['weight decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_step = configs['learning rate decay step']
        lr_decay = configs['learning rate decay rate']
        lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
        for epoch_id in range(configs['epoch count']):
            # 训练模型
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
            model.train()
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                # 标准化输入
                data = normalize(batch['data'].to(device))
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                optimizer.zero_grad()
                # 计算预测概率
                logits = model(data)
                probs = F.softmax(logits, dim=-1)
                # 更新软标签
                with torch.no_grad():
                    new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算交叉熵损失
                loss = CE_loss(probs, targets)
                loss.backward()
                # 更新模型参数
                optimizer.step()
            # 调整学习率
            lr_scheduler.step()
效果演示

我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下

在这里插入图片描述

由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:

在这里插入图片描述

在这里插入图片描述

使用方式

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置

pip install -r requirements.txt

运行如下命令以下载并解压数据集

bash download.sh

如果希望在本地训练模型,请运行如下命令

python main.py -c [你的配置文件路径] -r [选择下者之一:"train""test""infer"]

如果希望在线部署,请运行如下命令

python main-flask.py

文章代码资源点击附件获取


http://www.kler.cn/a/331610.html

相关文章:

  • C# winform 字符串通过枚举类型转成int类型的数据
  • 【Flutter_Web】Flutter编译Web第二篇(webview篇):flutter_inappwebview如何改造方法,变成web之后数据如何交互
  • Mybatis加密解密查询操作(sql前),where要传入加密后的字段时遇到的问题
  • 顺序表的操作
  • 学习反射(反射的使用,反射的应用场景)
  • fastAPI接口的请求与响应——基础
  • Unity实战案例全解析:RTS游戏的框选和阵型功能 总结
  • 学习docker第四弹----安装redis集群大厂面试
  • 灰度重心法求取图像重心
  • Updates were rejected because the tip of your current branch is behind 的解决方法
  • (功能测试)熟悉web项目及环境 测试流程
  • SQL Server—T-sql聚合函数详解
  • 如何在银河麒麟高级服务器操作系统V10搭建虚拟机管理器?
  • Django学习笔记八:发布RESTful API
  • 【数据结构与算法】LeetCode:堆和快排
  • 深入浅出MongoDB(二)
  • 网络编程-TCP
  • 关于Elastic Search与MySQL之间的数据同步
  • 二、MySQL的数据目录
  • 16.数据结构与算法-串,数组与广义表(串,BF算法,KMP算法)
  • linux第二课:常用命令
  • 828华为云征文|使用Flexus X实例创建FDS+Nginx服务实现图片上传功能
  • 微服务(二)
  • Electron 主进程与渲染进程、预加载preload.js
  • 使用rust实现rtsp码流截图
  • Stable Diffusion绘画 | 来训练属于自己的模型:秋叶训练器使用