当前位置: 首页 > article >正文

hrnet训练的pt模型结合目标检测进行关键点识别的更准确前向推理

本篇在将图像输入hrnet识别之前先进行目标检测来确定识别的位置,让识别更加精准。

本段代码设置了一个区域框BOX,让人走入区域内才开始检测,适用于考核等场景,也可以直接去掉BOX也是一样的效果。若画面背景中有多个行人,还是只取要检测的那个人,同理还是适用考核场景。
为了让检测效果更直观,在一些点位直接使用线连接起来模拟人体骨骼。

import os
import sys
import numpy as np
from mmpose.apis import init_model, inference_topdown
import cv2

import torch
sys.path.append("/home/yons/train/code/yolov5")
from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_boxes
from torchvision import transforms

# 配置文件路径和检查点文件路径
config_file = '/home/.../pose_td-hm_hrnet-w48_8xb32-210e_PullUp-det/pose_td-hm_hrnet-w48_8xb32-210e_PullUp-det.py'
checkpoint_file = '/home/.../pose_td-hm_hrnet-w48_8xb32-210e_PullUp-det/best_coco_AP_epoch_250.pth'

# 初始化姿态估计模型
pose_model = init_model(config_file, checkpoint_file, device='cuda:0')
yolo_model = attempt_load('yolov5x6.pt', device='cuda:0')  # 加载训练好的yolov5模型
pose_model.eval()
yolo_model.eval()

VIDEO_PATH = 'input.mp4' 
BOX = (300, 50, 300, 450)  # 区域框的左上角坐标和宽高
OUTPUT_VIDEO_PATH = 'output.mp4'

def draw_keypoints(frame, keypoints, box, det_box):
    # 在帧上绘制关键点
    # 这里假设关键点是一个 Nx2 的数组,其中 N 是关键点的数量
    # 并且关键点的坐标是相对于裁剪区域的
    x, y, w, h = box
    for kp in keypoints:
        kp_x, kp_y = kp
        x_rec1 = int(det_box[0] + x)
        y_rec1 = int(det_box[1] + y)
        x_rec2 = int(det_box[2] + x)
        y_rec2 = int(det_box[3] + y)
        cv2.rectangle(frame, (x_rec1, y_rec1), (x_rec2, y_rec2), (0, 0, 255), 2)
        x_cir = int(kp_x + det_box[0] + x)
        y_cir = int(kp_y + det_box[1] + y)
        cv2.circle(frame, (x_cir, y_cir), 3, (0, 255, 0), -1)
    lines = [(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11), (6, 12), (5, 6),
             (5, 7), (6, 8), (7, 9), (8, 10), (1, 2), (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6)]
    for line in lines:
        pt1 = (int(keypoints[line[0]][0] + det_box[0] + x), int(keypoints[line[0]][1] + det_box[1] + y))
        pt2 = (int(keypoints[line[1]][0] + det_box[0] + x), int(keypoints[line[1]][1] + det_box[1] + y))
        cv2.line(frame, pt1, pt2, (0, 255, 0), 2)

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resizes and pads image to new_shape with stride-multiple constraints, returns resized image, ratio, padding."""
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        print(im.shape)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

    return im, ratio, (dw, dh)

# 处理每一张图像
k = 0
if __name__ == '__main__':
    # 打开视频
    cap = cv2.VideoCapture(VIDEO_PATH)

    # 获取视频的一些属性
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # 创建 VideoWriter 对象
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 或者使用 'XVID'
    out = cv2.VideoWriter(OUTPUT_VIDEO_PATH, fourcc, fps, (width, height))

    while True:
        # 读取一帧
        ret, frame = cap.read()
        if not ret:
            break

        # 加载帧
        x, y, w, h = BOX
        img0 = frame[y:y+h, x:x+w, :]

        img_size = (1280, 1280)
        stride = max(int(yolo_model.stride.max()), 32)
        img = letterbox(img0, img_size, stride=stride, auto=True)[0]   # padded resize
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)  # contiguous

        # yolo_model.warmup(imgsz=(1, 3, *img_size))  # warmup
        img = torch.from_numpy(img).to('cuda:0')
        # img = img.half() if yolo_model.fp16 else img.float()  # uint8 to fp16/32
        img = img.float()
        img /= 255  # 0 - 255 to 0.0 - 1.0
        if len(img.shape) == 3:
            img = img[None]  # expand for batch dim

        with torch.no_grad():
            pred = yolo_model(img)  # Inference
        pred = non_max_suppression(pred, 0.25, 0.45, 0)  # NMS    0 for person
        # input()

        det_box = None
        # Process predictions
        # print(pred)
        for i, det in enumerate(pred):  # per image
            # print(det)
            if len(det):
                # Rescale boxes from img_size to img0 size
                det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()

                max_bb = None
                for x1, y1, x2, y2, conf, cls in reversed(det):
                    if max_bb is None:
                        max_bb = [x1, y1, x2, y2]
                    else:
                        if ((x2 - x1) * (y2 - y1)) > (
                                (max_bb[2] - max_bb[0]) * (max_bb[3] - max_bb[3])):
                            max_bb = [x1, y1, x2, y2]
                det_box = max_bb
                for idx in range(len(det_box)):
                    det_box[idx] = int(det_box[idx])

        x1, y1, x2, y2 = det_box
        # print(det_box)
        img_seg = img0[y1:y2, x1:x2, :]

        person_results = np.array([[0, 0, x2-x1, y2-y1]])
        # 推理得到关键点坐标
        pose_results = inference_topdown(pose_model, img_seg, person_results, bbox_format='xyxy')

        # 提取关键点坐标并检查是否检测出17个关键点
        keypoints = []
        if len(pose_results) > 0 and pose_results[0].pred_instances.keypoints.shape[1] == 17:
            keypoints = pose_results[0].pred_instances.keypoints[0]

        draw_keypoints(frame, keypoints, BOX, det_box)

        # 写入帧
        out.write(frame)

        # 显示帧
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()

input.mp4视频如下:

引体向上原始视频


output.mp4视频如下:

引体向上推理结果视频

大概原理是区域框内进行一系列处理后输入进yolo进行目标检测,在多个目标框内选出我们要检测的人物的目标框输入进hrnet得到关键点,关键点从目标框映射回区域框再映射回原图,得到最终结果。


http://www.kler.cn/a/323066.html

相关文章:

  • Java项目实战II基于微信小程序的电子商城购物平台(开发文档+数据库+源码)
  • 【征稿倒计时!华南理工大学主办 | IEEE出版 | EI检索稳定】2024智能机器人与自动控制国际学术会议 (IRAC 2024)
  • 深入探索 Kubernetes 安全容器:Kata Containers 与 gVisor
  • 【已解决】git push一直提示输入用户名及密码、fatal: Could not read from remote repository的问题
  • GitHub新手入门 - 从创建仓库到协作管理
  • Fish Agent V0.13B:Fish Audio的语音处理新突破,AI语音助手的未来已来!
  • PHP视频活体检测API接口示例-视频活体检测引领身份验证新潮流
  • mysql索引 -- 全文索引介绍(如何创建,使用),explain关键字
  • C#中NModbus4中常用的方法
  • 解决Mac 默认设置 wps不能双面打印的问题
  • DevExpress WPF中文教程:如何解决编辑单元格值的常见问题?
  • 1.6 物理层
  • 每天学习一个技术栈 ——【Django Channels】篇(1)
  • 《深度学习》—— 神经网络中的数据增强
  • PHP中如何使用三元条件运算符
  • 智能PPT行业赋能用户画像
  • Kafka系列之:安装部署CMAK,CMAK管理大型Kafka集群参数调优
  • 实现org.springframework.beans.factory.InitializingBean 接口--初始化bean
  • 渲染太慢?Maya云渲染教程
  • 转行大模型的必要性与未来前景:迎接智能时代的浪潮
  • 阅读CVPR论文——mPLUG-Owl2:革命性的多模态大语言模型与模态协作
  • 复杂网络(Complex Network)社团数据可视化分析(gephi)实验
  • 初识爬虫8
  • SwiftUI疑难杂症(1):sheet content多次执行
  • 在Java中,关于final、static关键字与方法的重写和继承【易错点】
  • io流(学习笔记01)--File知识点