当前位置: 首页 > article >正文

Faster R-CNN原理详解以及Pytorch实现模型训练与推理

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于深度学习的行人跌倒检测系统】
9.【基于深度学习的PCB板缺陷检测系统】10.【基于深度学习的生活垃圾分类目标检测系统】
11.【基于深度学习的安全帽目标检测系统】12.【基于深度学习的120种犬类检测与识别系统】
13.【基于深度学习的路面坑洞检测系统】14.【基于深度学习的火焰烟雾检测系统】
15.【基于深度学习的钢材表面缺陷检测系统】16.【基于深度学习的舰船目标分类检测系统】
17.【基于深度学习的西红柿成熟度检测系统】18.【基于深度学习的血细胞检测与计数系统】
19.【基于深度学习的吸烟/抽烟行为检测系统】20.【基于深度学习的水稻害虫检测与识别系统】
21.【基于深度学习的高精度车辆行人检测与计数系统】22.【基于深度学习的路面标志线检测与识别系统】
23.【基于深度学习的智能小麦害虫检测识别系统】24.【基于深度学习的智能玉米害虫检测识别系统】
25.【基于深度学习的200种鸟类智能检测与识别系统】26.【基于深度学习的45种交通标志智能检测与识别系统】
27.【基于深度学习的人脸面部表情识别系统】28.【基于深度学习的苹果叶片病害智能诊断系统】
29.【基于深度学习的智能肺炎诊断系统】30.【基于深度学习的葡萄簇目标检测系统】
31.【基于深度学习的100种中草药智能识别系统】32.【基于深度学习的102种花卉智能识别系统】
33.【基于深度学习的100种蝴蝶智能识别系统】34.【基于深度学习的水稻叶片病害智能诊断系统】
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统】36.【基于深度学习的智能草莓病害检测与分割系统】
37.【基于深度学习的复杂场景下船舶目标检测系统】38.【基于深度学习的农作物幼苗与杂草检测系统】
39.【基于深度学习的智能道路裂缝检测与分析系统】40.【基于深度学习的葡萄病害智能诊断与防治系统】
41.【基于深度学习的遥感地理空间物体检测系统】42.【基于深度学习的无人机视角地面物体检测系统】
43.【基于深度学习的木薯病害智能诊断与防治系统】44.【基于深度学习的野外火焰烟雾检测系统】
45.【基于深度学习的脑肿瘤智能检测系统】46.【基于深度学习的玉米叶片病害智能诊断与防治系统】
47.【基于深度学习的橙子病害智能诊断与防治系统】48.【基于深度学习的车辆检测追踪与流量计数系统】
49.【基于深度学习的行人检测追踪与双向流量计数系统】50.【基于深度学习的反光衣检测与预警系统】
51.【基于深度学习的危险区域人员闯入检测与报警系统】52.【基于深度学习的高密度人脸智能检测与统计系统】
53.【基于深度学习的CT扫描图像肾结石智能检测系统】54.【基于深度学习的水果智能检测系统】
55.【基于深度学习的水果质量好坏智能检测系统】56.【基于深度学习的蔬菜目标检测与识别系统】
57.【基于深度学习的非机动车驾驶员头盔检测系统】58.【太基于深度学习的阳能电池板检测与分析系统】
59.【基于深度学习的工业螺栓螺母检测】60.【基于深度学习的金属焊缝缺陷检测系统】
61.【基于深度学习的链条缺陷检测与识别系统】62.【基于深度学习的交通信号灯检测识别】
63.【基于深度学习的草莓成熟度检测与识别系统】64.【基于深度学习的水下海生物检测识别系统】
65.【基于深度学习的道路交通事故检测识别系统】66.【基于深度学习的安检X光危险品检测与识别系统】
67.【基于深度学习的农作物类别检测与识别系统】68.【基于深度学习的危险驾驶行为检测识别系统】
69.【基于深度学习的维修工具检测识别系统】70.【基于深度学习的维修工具检测识别系统】
71.【基于深度学习的建筑墙面损伤检测系统】72.【基于深度学习的煤矿传送带异物检测系统】
73.【基于深度学习的老鼠智能检测系统】74.【基于深度学习的水面垃圾智能检测识别系统】
75.【基于深度学习的遥感视角船只智能检测系统】76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统】
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】
79.【基于深度学习的果园苹果检测与计数系统】80.【基于深度学习的半导体芯片缺陷检测系统】
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统】82.【基于深度学习的运动鞋品牌检测与识别系统】
83.【基于深度学习的苹果叶片病害检测识别系统】84.【基于深度学习的医学X光骨折检测与语音提示系统】
85.【基于深度学习的遥感视角农田检测与分割系统】86.【基于深度学习的运动品牌LOGO检测与识别系统】
87.【基于深度学习的电瓶车进电梯检测与语音提示系统】88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

  • 引言
  • Faster R-CNN架构
  • 第一阶段:区域提案网络(RPN)
    • 骨干网
    • 锚框
    • 目标分类
    • 边界框优化
  • 阶段2:目标分类和框细化
    • Region Proposals区域建议
    • ROI池化
    • 对象分类
    • 边界框优化(再次)
    • 多任务学习
  • 推理(测试/预测时间):
  • 训练
  • 在PyTorch中实现和微调Faster R-CNN
    • 步骤1:安装所需的库
    • 步骤2:导入所需库
    • 步骤3:加载预训练的Faster R-CNN模型
    • 步骤4:准备数据集
    • 步骤5:设置数据加载器
    • 步骤6:设置训练循环
    • 步骤7:评估模型
  • 8.推理

引言

目前大多数SOTA模型都是建立在Faster-RCNN模型的基础之上。Faster R-CNN是一种对象检测模型,它可以识别图像中的对象,并在它们周围绘制边界框,同时还可以对这些对象进行分类。这是一个两级探测器:

  1. 阶段1:提出图像中可能包含对象的潜在区域。这是由**Region Proposal Network(RPN)**处理的。
  2. 阶段2:使用这些建议的区域来预测对象的类别,并细化边界框以更好地匹配对象。

Faster R-CNN架构

img

第一阶段:区域提案网络(RPN)

骨干网

  • 图像通过卷积网络(如ResNet或VGG16)。
  • 这将从图像中提取重要特征并创建特征图

锚框

  • 锚框是放置在特征图上的点上的不同大小和形状的框。
  • 每个锚框表示可能的对象位置。
  • 在特征图上的每个点处,生成具有不同大小和纵横比的锚框。

目标分类

  • RPN预测每个锚框是背景(无对象)还是前景(包含对象)。
  • 正(前景)锚点:与实际对象高度重叠的框。
  • 负(背景)锚点:与对象几乎没有重叠的框。

边界框优化

  • RPN还通过预测偏移(调整)来细化锚框,以更好地将其与实际对象对齐。

损失函数

I)分类损失:帮助模型决定锚是背景还是前景。

II)回归损失:帮助调整锚框以更精确地适应对象。

阶段2:目标分类和框细化

Region Proposals区域建议

  • 在RPN之后,我们得到区域建议(可能包含对象的细化框)。

ROI池化

  • 区域建议有不同的大小,但神经网络需要固定大小的输入。
  • ROI池化通过将所有区域方案划分为较小的区域并应用池化来将它们调整为固定大小,从而使它们均匀一致。

对象分类

  • 每个区域建议通过一个小网络来预测类别(例如,狗、汽车等)里面的物体。
  • 交叉熵损失用于将对象分类到类别中。

边界框优化(再次)

  • 使用偏移量,再次细化区域建议,以更好地匹配实际对象。
  • 这使用回归损失来调整建议。

多任务学习

  • 第二阶段的网络同时学习预测对象类别和优化边界框。

推理(测试/预测时间):

  • Top Region Proposals:在测试过程中,模型会生成大量的区域建议,但只有顶级建议(具有最高分类分数)会被传递到第二阶段。
  • 最终预测:第二阶段预测最终的类别和边界框。
  • 非最大抑制:一种称为非最大抑制的技术被应用于删除重复或重叠的框,只保留最好的框。

训练

两种训练方式

  1. 分阶段训练:首先训练区域建议网络(RPN),然后训练分类器和回归器。
  2. 同时训练:同时训练两个阶段(更快,更有效)。

在PyTorch中实现和微调Faster R-CNN

步骤1:安装所需的库

pip install torch torchvision

步骤2:导入所需库

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

步骤3:加载预训练的Faster R-CNN模型

PyTorch的torchvision提供了一个在COCO上预训练的更快的R-CNN模型。您可以通过更改最后一层中的类数量来修改您自己的数据集。

# Load the pre-trained Faster R-CNN model with a ResNet-50 backbone
model = fasterrcnn_resnet50_fpn(pretrained=True)

# Number of classes (your dataset classes + 1 for background)
num_classes = 3  # For example, 2 classes + background

# Get the number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Replace the head of the model with a new one (for the number of classes in your dataset)
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

步骤4:准备数据集

  • Faster R-CNN需要图像和相应的注释(边界框和标签)。
  • **您的数据集应返回:包括边界框(*框*)和标签(*标签*)的图像和目标字典。

如有必要,请创建自定义数据集类。您可以使用 torchvision.datasets.ImageFolder 并在注释文件中提供边界框或创建自定义Dataset类。

# Define transformations (e.g., resizing, normalization)
transform = T.Compose([
    T.ToTensor(),
])
# Custom Dataset class or using an existing one
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, transforms=None):
        # Initialize dataset paths and annotations here
        self.transforms = transforms
        # Your dataset logic (image paths, annotations, etc.)

    def __getitem__(self, idx):
        # Load image
        img = ...  # Load your image here
        # Load corresponding bounding boxes and labels

        boxes = ...  # Load or define bounding boxes

        labels = ...  # Load or define labels

        # Create a target dictionary
        target = {}
        target["boxes"] = torch.tensor(boxes, dtype=torch.float32)
        target["labels"] = torch.tensor(labels, dtype=torch.int64)
        # Apply transforms
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target
    def __len__(self):
        # Return the length of your dataset
        return len(self.data)

步骤5:设置数据加载器

# Load dataset
dataset = CustomDataset(transforms=transform)
# Split into train and validation sets
indices = torch.randperm(len(dataset)).tolist()
train_dataset = torch.utils.data.Subset(dataset, indices[:-50])
valid_dataset = torch.utils.data.Subset(dataset, indices[-50:])
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, 
                                   collate_fn=lambda x: tuple(zip(*x)))
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, 
                                    collate_fn=lambda x: tuple(zip(*x)))

步骤6:设置训练循环

现在设置优化器和训练循环。对于Faster R-CNN,通常使用SGDAdam作为优化器。

# Move model to GPU if available
device = torch.device('cuda') if torch.cuda.is_available() 
                               else torch.device('cpu')
model.to(device)

# Set up the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, 
                                                   weight_decay=0.0005)
# Learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, 
                                                               gamma=0.1)
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

   # Training loop
    for images, targets in train_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass
        losses.backward()
        optimizer.step()
        train_loss += losses.item()

    # Update the learning rate
    lr_scheduler.step()
    print(f'Epoch: {epoch + 1}, Loss: {train_loss / len(train_loader)}')
print("Training complete!")

步骤7:评估模型

训练后,您可以在验证集上评估模型,或将其用于新图像的推理。

# Set the model to evaluation mode
model.eval()
# Test on a new image
with torch.no_grad():
    for images, targets in valid_loader:
        images = list(img.to(device) for img in images)
        predictions = model(images)
        # Example: print the bounding boxes and labels for the first image
        print(predictions[0]['boxes'])
        print(predictions[0]['labels'])

8.推理

要在新图像上运行推理,请执行以下操作:

import cv2
from PIL import Image
# Load image
img = Image.open("path/to/your/image.jpg")
# Apply the same transformation as for training
img = transform(img)
img = img.unsqueeze(0).to(device)
# Model prediction
model.eval()
with torch.no_grad():
    prediction = model([img])
# Print the predicted bounding boxes and labels
print(prediction[0]['boxes'])
print(prediction[0]['labels'])

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!


http://www.kler.cn/a/582015.html

相关文章:

  • Spring Boot 整合 RabbitMQ(在Spring项目中使用RabbitMQ)
  • Chrome 浏览器 134 版本新特性
  • 一周学会Flask3 Python Web开发-SQLAlchemy连接Mysql数据库
  • Html5-照片滤镜应用学习经验总结
  • Windows10下docker desktop命令行操作指南(大部分也适用于Linux)
  • Jmeter请求发送加密参数
  • K8S学习之基础二十三:k8s的持久化存储之nfs
  • 机器学习 Day02,matplotlib库绘图
  • TikTok多店铺网络安全搭建指南
  • 在资源有限中逆势突围:从抗战智谋到寒门高考的破局智慧
  • 数据库统计信息开启和关闭
  • 深入浅出Bearer Token:解析工作原理及其在Vue、Uni-app与Java中的实现Demo
  • Redis 设置密码(配置文件、docker容器、命令行3种场景)
  • HippoRAG 2 原理精读
  • Vue开发中计算属性与方法调用之间的区别与联系
  • LeNet-5卷积神经网络详解
  • 【沐渥科技】氮气柜日常如何维护?
  • QT与网页显示数据公式的方法
  • 企业未来 AI 应用中的工作重心
  • 在VMware上部署【Rocky Linux】保姆级