基于detectron2框架的深度学习模型载入自定义数据集
基于detectron2框架的深度学习模型载入自定义数据集
一、前言
最近在做微光目标检测的研究工作,使用了Rank_DETR;这个模型是基于detrex框架,而detrex框架又是基于detectron2的。找了一圈没找到载入数据集的地方,后面查阅了资料得知要用API进行注册。
二、步骤
-
注册数据集:
在脚本中,我们首先要注册数据集。Detectron2 提供了多种注册数据集的方式,常用的是register_coco_instances
,用于 COCO 格式的数据集。您可以在脚本的开头或配置文件中添加如下代码来注册您的数据集:from detectron2.data.datasets import register_coco_instances register_coco_instances("my_dataset_train", {}, "path/to/train_annotations.json", "path/to/train_images/") register_coco_instances("my_dataset_val", {}, "path/to/val_annotations.json", "path/to/val_images/")
"my_dataset_train"
和"my_dataset_val"
是数据集的名称,您可以按需更改。path/to/train_annotations.json
和path/to/val_annotations.json
分别是训练和验证数据集的 COCO 格式标注文件路径。path/to/train_images/
和path/to/val_images/
是训练和验证图像的路径。
-
在配置文件中引用数据集:
在您使用的配置文件中,需要确保数据加载器 (dataloader
) 中引用了您刚才注册的数据集。通常,您需要修改以下内容:cfg.dataloader.train.dataset.names = "my_dataset_train" cfg.dataloader.test.dataset.names = "my_dataset_val"
这确保了训练和验证时使用的是您自定义的数据集。
三、示例代码集成
如果您已经在脚本中集成了以上步骤,代码可能如下所示:
def main(args):
cfg = LazyConfig.load(args.config_file)
cfg = LazyConfig.apply_overrides(cfg, args.opts)
default_setup(cfg, args)
register_coco_instances("exdark_train", {},
"/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_train2017.json",
"/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/train2017")
register_coco_instances("exdark_test", {},
"/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_val2017.json",
"/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/val2017")
cfg.dataloader.train.dataset.names = "exdark_train"
cfg.dataloader.test.dataset.names = "exdark_test"
if args.eval_only:
model = instantiate(cfg.model)
model.to(cfg.train.device)
model = create_ddp_model(model)
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
print(do_test(cfg, model))
else:
do_train(args, cfg)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--use_wandb", action="store_true", help="Whether to use wandb.")
parser.add_argument("--wandb_key", type=str, help="Wandb API key.")
args = parser.parse_args()
if args.use_wandb:
wandb.login(key=args.wandb_key)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)