当前位置: 首页 > article >正文

OOD分类项目训练

一、项目地址

GitHub - LooKing9218/UIOS

二、label制作

      将训练、验证、测试数据的分类信息转换入.csv文件中,运行如下脚本即可:

import os
import csv
 
#要读取的训练、验证、测试文件的目录,该文件下保存着以各个类别命名的文件夹和对应的分类图片
root_path=r'/media/*********************/train' 
#类别种类
classes=['cls1','cls2']

def get_Write_file_infos(path):
    # 文件信息列表
    file_infos_list=[]
    typeclothes=os.listdir(path)
    for ii in typeclothes:
        everyfile=os.path.join(path , ii)
        for root, dirnames, filenames in os.walk(everyfile):
            for filename in filenames:
                file_infos = {}
                dirname=root
                 
                #根据自己的需求更改路径地址
                filename1 ='train/'+ii+'/'+ filename#.split('.jpg')[0]
                flag = filename1[-1]
                file_infos["ImageId"] = filename1
     
                file_infos["Flag"] = classes.index(ii)
                #将数据追加字典到列表中
                file_infos_list.append(file_infos)
                
    return file_infos_list
 
 
#写入csv文件
def write_csv(file_infos_list):
    with open('train_label.csv','a+',newline='') as csv_file_train:
        csv_writer = csv.DictWriter(csv_file_train,fieldnames=['ImageId','Flag'])
        csv_writer.writeheader()
        for each in file_infos_list:
            print(each)
            csv_writer.writerow(each)
            
def main():
    file_infos_list =get_Write_file_infos(root_path)
    write_csv(file_infos_list)
 
 
if __name__ == '__main__':
    main()
    print('The End!')

生成情况如下:

三、运行程序

     (1)修改参数文件 utils/config.py

# -*- coding: utf-8 -*-
class DefaultConfig(object):
    net_work = 'ResUnNet50'
    num_classes = 2
    num_epochs = 100
    batch_size = 256
    validation_step = 1
    root = "/media/code/"
    train_file = "train_label.csv"
    val_file = "val_label.csv"
    test_file = "test_label.csv"
    lr = 1e-4
    lr_mode = 'poly'
    momentum = 0.9
    weight_decay = 1e-4
    save_model_path = './Model_Saved'.format(net_work,lr)
    log_dirs = './Logs_Adam_0304'
    pretrained =True# False
    pretrained_model_path ='/media/code/UIOS-master/Trained/archive/data/99843712' #None
    cuda = 0
    num_workers = 4
    use_gpu = True
    trained_model_path = ''
    predict_fold = 'predict_mask'

(2)运行

   命令:

python train.py

(3)运行界面

四、踩坑记录

问题原因:ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

解决方法:

     (1)网上看了很多:

              方法1:添加 try-except

        try:
            epoch_train_auc = metrics.roc_auc_score(labels, outputs)

            writer.add_scalar('Train/train_auc', float(epoch_train_auc),
                          epoch)
            print('loss for train : {},{}'.format(loss_train_mean,round(epoch_train_auc,6)))

        except ValueError:
            pass

        方法2:DataLoader的参数设置shuffle=True

   train_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='train',
        data_file=args.train_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)
    val_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='val',
        data_file=args.val_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='test',
        data_file=args.test_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)

    方法3:增大batch_size

    (2)我的方法:

        其实是我马虎大意

       修改好config.py中的num_classes参数就行了,

       见谅(不好意思~( ̄▽ ̄)~*)


http://www.kler.cn/a/234388.html

相关文章:

  • LeetCode【0035】搜索插入位置
  • MySql结合element-plus pagination的分页查询
  • Jmeter基础篇(22)服务器性能监测工具Nmon的使用
  • 计算机毕业设计Python+Neo4j知识图谱医疗问答系统 大模型 机器学习 深度学习 人工智能 大数据毕业设计 Python爬虫 Python毕业设计
  • IPguard与Ping32全面对比——选择最适合企业的数据安全解决方案
  • 【Vue】Vue3.0(二十一)Vue 3.0中 的$event使用示例
  • kyuubi 接入starrocks | doris
  • Vue3中Setup概述和使用(三)
  • maven插件maven-jar-plugin构建jar文件详细使用
  • 一、西瓜书——绪论
  • 【大厂AI课学习笔记】【1.6 人工智能基础知识】(4)深度学习和机器学习
  • JavaScript 设计模式之原型模式
  • 【美团】酒旅用户增长-后端研发
  • Nginx实战:1-安装搭建
  • C# 字体大小的相关问题
  • 【博云2023】乘龙一跃腾云海,侧目抬手摘星河
  • 双向链表的插入、删除、按位置增删改查、栈和队列区别、什么是内存泄漏
  • 【Larry】英语学习笔记语法篇——从句=连词+简单句
  • Linux——动静态库
  • Python操作MySQL基础
  • Qt知识点总结目录
  • 1523.在区间范围内统计奇数数目(Java)
  • Python爬虫——请求库安装
  • ubuntu20.04 安装mysql(8.x)
  • 13. 串口接收模块的项目应用案例
  • 华为数通方向HCIP-DataCom H12-821题库(单选题:441-460)