当前位置: 首页 > article >正文

基于detectron2框架的深度学习模型载入自定义数据集

基于detectron2框架的深度学习模型载入自定义数据集

一、前言

最近在做微光目标检测的研究工作,使用了Rank_DETR;这个模型是基于detrex框架,而detrex框架又是基于detectron2的。找了一圈没找到载入数据集的地方,后面查阅了资料得知要用API进行注册。

二、步骤

  1. 注册数据集:
    在脚本中,我们首先要注册数据集。Detectron2 提供了多种注册数据集的方式,常用的是 register_coco_instances,用于 COCO 格式的数据集。您可以在脚本的开头或配置文件中添加如下代码来注册您的数据集:

    from detectron2.data.datasets import register_coco_instances
    
    register_coco_instances("my_dataset_train", {}, "path/to/train_annotations.json", "path/to/train_images/")
    register_coco_instances("my_dataset_val", {}, "path/to/val_annotations.json", "path/to/val_images/")
    
    • "my_dataset_train""my_dataset_val" 是数据集的名称,您可以按需更改。
    • path/to/train_annotations.jsonpath/to/val_annotations.json 分别是训练和验证数据集的 COCO 格式标注文件路径。
    • path/to/train_images/path/to/val_images/ 是训练和验证图像的路径。
  2. 在配置文件中引用数据集:
    在您使用的配置文件中,需要确保数据加载器 (dataloader) 中引用了您刚才注册的数据集。通常,您需要修改以下内容:

    cfg.dataloader.train.dataset.names = "my_dataset_train"
    cfg.dataloader.test.dataset.names = "my_dataset_val"
    

    这确保了训练和验证时使用的是您自定义的数据集。

三、示例代码集成

如果您已经在脚本中集成了以上步骤,代码可能如下所示:

def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    register_coco_instances("exdark_train", {},
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_train2017.json",
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/train2017")
    register_coco_instances("exdark_test", {},
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_val2017.json",
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/val2017")
    cfg.dataloader.train.dataset.names = "exdark_train"
    cfg.dataloader.test.dataset.names = "exdark_test"

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)


if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--use_wandb", action="store_true", help="Whether to use wandb.")
    parser.add_argument("--wandb_key", type=str, help="Wandb API key.")
    args = parser.parse_args()

    if args.use_wandb:
        wandb.login(key=args.wandb_key)
        
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

http://www.kler.cn/news/283789.html

相关文章:

  • 环境变量--永久 & 暂时
  • 设计模式 16 迭代器模式
  • OCI编程高级篇(十四) 直接路径装载设置字段信息
  • 数据结构与算法 第四天(串、数组、广义表)
  • HTTP分析
  • 高级java每日一道面试题-2024年8月30日-数据库篇-数据库的三范式是什么?
  • Java技术栈 —— Spark入门(三)之实时视频流
  • Dubbo如何传递链路追踪id?
  • 小琳AI课堂:使用ChatGPT API搭建系统(二)
  • innovus:如何让部分sink长到target insertion delay的长度
  • 关于OBI 在unity URP环境下使用的正确步骤
  • 网络编程(学习)2024.8.27
  • jQuery基础——选择器的补充方法——过滤方法、查找方法
  • python使用multiprocessing多进程通讯
  • 各种各样的正则表达式
  • 92. UE5 RPG 使用C++创建GE实现灼烧的负面效果
  • 达梦数据库-DM8 企业版安装指南
  • [java][代码] java中date格式化输出时间字符串
  • 《征服数据结构》LFU缓存
  • Vatee万腾平台:打造企业智能化转型的坚实后盾
  • 【Android】UIMode
  • fpga图像处理实战-双三次插值算法
  • Jmeter提取token并设置为全局变量
  • 聊聊STM32 MCU的BOOT0和BOOT1引脚
  • 浅谈Vue3和React18
  • 六个方面探讨企业为何迫切需要替换FTP
  • PyQt 迁移到 PySide
  • WPF ToolkitMVVM RelayCommand
  • 探究:Elasticsearch 文档的 _id 是 Lucene 的 docid 吗?
  • DNN学习平台(GoogleNet、SSD、FastRCNN、Yolov3)