
Training the HRNet human pose estimation model with the MMPose framework

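Before training, the data has to be annotated and augmented. The annotation tool is covered in a separate post; this one starts with data augmentation.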

Data augmentation

Apply simple color transforms and position transforms. The code is as follows:

from PIL import Image, ImageEnhance
import numpy as np
import os
import glob
import json
import torch
import torchvision.transforms as transforms



def color(flag, input_directory, output_directory):
    # Create the output directory if it does not exist yet
    os.makedirs(output_directory, exist_ok=True)

    # Iterate over the image files in the input directory
    for filename in os.listdir(input_directory):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            # Read the image; convert to RGB so it can always be saved as JPEG
            image_path = os.path.join(input_directory, filename)
            original_image = Image.open(image_path).convert("RGB")

            # Color jitter ranges; adjust them as needed
            brightness_factor = np.random.uniform(0.6, 1.8)  # brightness
            contrast_factor = np.random.uniform(0.7, 1.7)    # contrast
            saturation_factor = np.random.uniform(0.7, 1.7)  # saturation

            # Apply the color enhancements one after another
            enhancer = ImageEnhance.Brightness(original_image)
            enhanced_image = enhancer.enhance(brightness_factor)

            enhancer = ImageEnhance.Contrast(enhanced_image)
            enhanced_image = enhancer.enhance(contrast_factor)

            enhancer = ImageEnhance.Color(enhanced_image)
            enhanced_image = enhancer.enhance(saturation_factor)

            stem = os.path.splitext(filename)[0]
            new_name = stem + "-" + flag + ".jpg"

            # Copy the keypoint JSON and point it at the new image name
            json_file = os.path.join(input_directory, stem + ".json")
            with open(json_file, 'r') as fj:
                data = json.load(fj)
            data["imagepath"] = new_name
            saved_json = os.path.join(output_directory, stem + "-" + flag + ".json")
            with open(saved_json, 'w') as fj:
                json.dump(data, fj, ensure_ascii=False, indent=2)

            # Save the augmented image
            output_path = os.path.join(output_directory, new_name)
            enhanced_image.save(output_path)
            print(output_path)

    print("Data augmentation finished.")

def position(flag, data_folder, output_folder):
    os.makedirs(output_folder, exist_ok=True)

    # Collect the paths of all JSON files in the folder
    json_files = glob.glob(os.path.join(data_folder, '*.json'))

    max_attempts = 20  # give up on an image after this many failed crops

    for json_file_path in json_files:
        attempts = 0
        while True:
            # Reload the JSON on every attempt, since the keypoints are shifted in place below
            with open(json_file_path, 'r') as json_file:
                data = json.load(json_file)

            # Path of the image that belongs to this JSON file
            image_path = os.path.splitext(json_file_path)[0] + ".jpg"

            # Read the image and get its width and height
            image = Image.open(image_path)
            image_width, image_height = image.size

            # Random center crop of at least 50% of the height, keeping the aspect ratio
            aspect_ratio = image_width / image_height
            crop_height = int(torch.randint(int(image_height * 0.5), image_height, (1,)).item())
            crop_width = int(crop_height * aspect_ratio)

            transform = transforms.CenterCrop((crop_height, crop_width))
            transformed_image = transform(image)

            # Width and height of the cropped image
            transformed_image_width, transformed_image_height = transformed_image.size

            # Top-left corner of the crop in original image coordinates
            crop_zero_point = ((image_width - transformed_image_width) / 2, (image_height - transformed_image_height) / 2)

            # Shift the keypoints into crop coordinates
            keypoints = data["keypoints"][0]
            for i in range(0, len(keypoints), 3):
                keypoints[i] -= crop_zero_point[0]
                keypoints[i + 1] -= crop_zero_point[1]

            # Check whether any keypoint falls outside the cropped image
            if all(0 <= keypoints[i] < transformed_image_width and 0 <= keypoints[i+1] < transformed_image_height for i in range(0, len(keypoints), 3)):
                break  # all keypoints are inside the image, so accept this crop

            attempts += 1
            print(attempts)
            if attempts == max_attempts:
                break

        if attempts == max_attempts:
            continue  # no valid crop found; skip this image

        # Resize the crop back to the original size and scale the keypoints to match
        resized_image = transformed_image.resize((image_width, image_height))
        scale_x = image_width / transformed_image_width
        scale_y = image_height / transformed_image_height
        for i in range(0, len(keypoints), 3):
            keypoints[i] *= scale_x
            keypoints[i + 1] *= scale_y

        # Update the annotation data
        data["keypoints"] = [keypoints]
        data["imagepath"] = os.path.splitext(data["imagepath"])[0] + "-" + flag + ".jpg"

        # Save the augmented image
        transformed_image_path = os.path.join(output_folder, os.path.splitext(os.path.basename(image_path))[0] + "-" + flag + ".jpg")
        resized_image.save(transformed_image_path)

        # Save the updated JSON file
        transformed_json_file_path = os.path.join(output_folder, os.path.splitext(os.path.basename(json_file_path))[0] + "-" + flag + ".json")
        with open(transformed_json_file_path, 'w') as transformed_json_file:
            json.dump(data, transformed_json_file, indent=2)

        print(transformed_image_path)

    print("Data augmentation completed.")

input_path = ".../mmpose/data/coco/ori"
output_path = ".../mmpose/data/coco/col1"
flag = "col1"
color(flag, input_path, output_path)         # flag = col1 col2
# position(flag, input_path, output_path)    # flag = posi

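Typically one position transform and two color transforms are applied (one color pass on the original data, one on the position-transformed data), so the script is run three times. Use a different flag for each run as the file suffix: it keeps the outputs easy to tell apart, and they all get mixed together later. The images and their JSON files all go into one folder, which serves as input_path.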
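
The keypoint bookkeeping inside position() can be looked at in isolation. The sketch below is a hypothetical helper (remap_keypoints_center_crop is not part of the original script) that assumes a center crop followed by a resize back to the original size, with keypoints stored as a flat [x, y, v, ...] list:

```python
def remap_keypoints_center_crop(keypoints, image_size, crop_size):
    """Map flat [x, y, v, ...] keypoints through a center crop and a resize back to image_size."""
    image_w, image_h = image_size
    crop_w, crop_h = crop_size
    # Top-left corner of the centered crop in original image coordinates
    offset_x = (image_w - crop_w) / 2
    offset_y = (image_h - crop_h) / 2
    # Scale factors for resizing the crop back to the original size
    scale_x = image_w / crop_w
    scale_y = image_h / crop_h
    remapped = []
    for i in range(0, len(keypoints), 3):
        x, y, v = keypoints[i:i + 3]
        remapped += [(x - offset_x) * scale_x, (y - offset_y) * scale_y, v]
    return remapped


# A point at the image center stays at the center; other points move outward.
print(remap_keypoints_center_crop([100, 100, 2, 60, 80, 2], (200, 200), (100, 100)))
```

A quick check like this on one hand-picked point is a cheap way to confirm the augmented labels still line up with the augmented images.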

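All data is stored under coco. The annotations and person_detection_results folders hold the COCO and bbox label files, which are covered later. After augmentation, create a kp_json folder and a det_json folder inside every data folder, and move the annotated keypoint labels into kp_json.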
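
The folder housekeeping described above can be scripted. This is a minimal sketch; move_labels_to_kp_json is a hypothetical helper name, and it assumes that every .json file sitting directly in the data folder is a keypoint label:

```python
import os
import shutil


def move_labels_to_kp_json(data_dir):
    """Create kp_json/ and det_json/ inside data_dir and move the keypoint JSON files into kp_json/."""
    kp_dir = os.path.join(data_dir, "kp_json")
    det_dir = os.path.join(data_dir, "det_json")
    os.makedirs(kp_dir, exist_ok=True)
    os.makedirs(det_dir, exist_ok=True)

    moved = []
    for name in sorted(os.listdir(data_dir)):
        src = os.path.join(data_dir, name)
        # Only plain files are moved, so the two new sub-folders are left alone
        if os.path.isfile(src) and name.endswith(".json"):
            shutil.move(src, os.path.join(kp_dir, name))
            moved.append(name)
    return moved
```

Call it once per data folder (the original, position and color folders); the images stay where they are.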

Creating the bbox file

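Next, run YOLOv5 with the yolov5x6.pt model on the images in ori and in posi (I chose it because it detected people most accurately) to obtain the person boxes. Top-down pose estimation needs object detection first and keypoint detection second.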

python detect.py --weights yolov5x6.pt --imgsz 1280 --source .../mmpose/data/coco/posi/ --save-txt --classes 0

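After detection, copy the labels folder of each result group and name the copy col. It serves as the detection result for that group's color-transformed images: a color transform does not move any point, so a plain copy is enough. Then batch-rename the label files inside col: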

import os
path = ".../yolov5/runs/detect/posi/col"

files = os.listdir(path)
for file in files:
    file_name = os.path.join(path, file)
    new_f = os.path.splitext(file)[0] + "-col2.txt"
    new_name = os.path.join(path, new_f)
    os.rename(file_name, new_name)
    print(new_name)

If the images use the suffix col1, use col1 here as well; the two must match.
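
Running the rename twice by accident would append the suffix twice. A slightly more defensive variant is sketched below; add_suffix_to_labels is a hypothetical helper, and it assumes only .txt files in the folder are label files:

```python
import os


def add_suffix_to_labels(label_dir, suffix):
    """Append "-<suffix>" to every .txt label file that does not carry it yet."""
    tag = "-" + suffix
    renamed = []
    for name in sorted(os.listdir(label_dir)):
        stem, ext = os.path.splitext(name)
        # Skip non-label files and files that were already renamed
        if ext != ".txt" or stem.endswith(tag):
            continue
        new_name = stem + tag + ext
        os.rename(os.path.join(label_dir, name), os.path.join(label_dir, new_name))
        renamed.append(new_name)
    return renamed
```

The suffix argument is the same flag that was used for the color-transformed images, for example "col1" or "col2".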
Next, convert the YOLO format to labelme format. Do this for every data folder and save the results in its det_json folder. The code is as follows:

# -*- coding: utf-8 -*-
import json
import cv2
from glob import glob
import os

txt_path = '.../yolov5/runs/detect/exp/labels/'  # YOLO (darknet) format labels
saved_path = '.../mmpose/data/coco/posi/det_json/'
img_path = '.../mmpose/data/coco/posi/'
kp_path = ".../mmpose/data/coco/posi/kp_json/"

files = glob(img_path + "*.jpg")
# Keep only the file name, without directory and extension
files = [os.path.splitext(os.path.basename(i))[0] for i in files]
print(files)

err_n = 0
err_name = []
for file in files:
    print(file)
    txt_file = txt_path + file + '.txt'
    img_file = img_path + file + '.jpg'
    kp_file = kp_path + file + ".json"
    print(img_file)
    img = cv2.imread(img_file)
    # print(img)
    imgw = img.shape[1]
    imgh = img.shape[0]
    flag = False
    xi = []
    yi = []
    xa = []
    ya = []
    Label = []

    if not os.path.exists(txt_file):
        # No detection file for this image: the keypoint extent is used further down
        flag = True
        Label.append('person')
    else:
        # Center of the labelled keypoints, used to reject boxes around background objects
        with open(kp_file, 'r') as fj0:
            data = json.load(fj0)
        keypoints = data['keypoints'][0]

        min_x, min_y = float('inf'), float('inf')
        max_x, max_y = float('-inf'), float('-inf')

        # Find the minimum and maximum coordinates over the keypoints
        for i in range(0, len(keypoints), 3):
            x, y, confidence = keypoints[i], keypoints[i + 1], keypoints[i + 2]
            if confidence > 0:  # only consider keypoints with visibility > 0
                min_x = min(min_x, x)
                min_y = min(min_y, y)
                max_x = max(max_x, x)
                max_y = max(max_y, y)
        center_x = (min_x + max_x) / 2
        center_y = (min_y + max_y) / 2

        with open(txt_file, 'r') as f:  # read every detection line of the txt file
            area = 0
            for line in f.readlines():
                line = line.strip('\n')
                if not line.strip():
                    continue  # ignore blank lines
                a = line.split(' ')
                label = 'other'
                if a[0] == '0':
                    label = 'person'
                # Map further class ids to your own class names here if needed

                Label.append(label)
                print(Label)

                # Convert the normalized YOLO box to pixel corner coordinates
                centerx = float(a[1]) * imgw
                centery = float(a[2]) * imgh
                w = float(a[3]) * imgw
                h = float(a[4]) * imgh
                x1 = centerx - w / 2
                x2 = centerx + w / 2
                y1 = centery - h / 2
                y2 = centery + h / 2
                x1, x2 = min(x1, x2), max(x1, x2)
                y1, y2 = min(y1, y2), max(y1, y2)

                # Skip boxes that do not contain the keypoint center
                if not ((x1 < center_x < x2) and (y1 < center_y < y2)):
                    continue

                # Of the remaining boxes, keep only the largest one
                box_area = int((x2 - x1) * (y2 - y1))
                if box_area > area:
                    area = box_area
                    xi = [x1]
                    yi = [y1]
                    xa = [x2]
                    ya = [y2]

    with open(kp_file, 'r') as fj:
        data = json.load(fj)
        # Extract the keypoints
        keypoints = data['keypoints'][0]

        min_x, min_y = float('inf'), float('inf')
        max_x, max_y = float('-inf'), float('-inf')

        # Find the minimum and maximum coordinates over the keypoints
        for i in range(0, len(keypoints), 3):
            x, y, confidence = keypoints[i], keypoints[i + 1], keypoints[i + 2]
            if confidence > 0:  # only consider keypoints with visibility > 0
                min_x = min(min_x, x)
                min_y = min(min_y, y)
                max_x = max(max_x, x)
                max_y = max(max_y, y)

        print(flag, len(xi))
        if flag or (len(xi) == 0):
            xi = [min_x - 80]
            if xi[0] < 0:
                xi[0] = 0
            yi = [min_y - 80]
            if yi[0] < 0:
                yi[0] = 0
            xa = [max_x + 80]
            if xa[0] > imgw:
                xa[0] = imgw
            ya = [max_y + 80]
            if ya[0] > imgh:
                ya[0] = imgh
            err_n += 1
            err_name.append(file)
        else:
            if min_x < xi[0]:
                xi[0] = min_x - 80
                if xi[0] < 0:
                    xi[0] = 0
            if min_y < yi[0]:
                yi[0] = min_y - 80
                if yi[0] < 0:
                    yi[0] = 0
            if max_x > xa[0]:
                xa[0] = max_x + 80
                if xa[0] > imgw:
                    xa[0] = imgw
            if max_y > ya[0]:
                ya[0] = max_y + 80
                if ya[0] > imgh:
                    ya[0] = imgh


    # for j in range(0, len(files)):
    labelme_formate = {
        "version": "3.16.7",
        "flags": {},
        "lineColor": [0, 255, 0, 128],
        "fillColor": [255, 0, 0, 128],
        "imagePath": file + ".jpg",
        "imageHeight": imgh,
        "imageWidth": imgw
    }
    labelme_formate['imageData'] = None
    shapes = []
    for i in range(0, len(xi)):
        s = {"label": Label[i], "line_color": None, "fill_color": None, "shape_type": "rectangle"}
        points = [
            [xi[i], yi[i]],
            [xa[i], ya[i]]
        ]
        s['points'] = points
        shapes.append(s)

    labelme_formate['shapes'] = shapes
    with open(saved_path + file + ".json", 'w') as f_out:
        json.dump(labelme_formate, f_out, ensure_ascii=False, indent=2)
    print(saved_path + file + ".json")
print(err_name)
print(err_n)

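Each of my images contains exactly one target. To keep YOLO detections of background clutter out, the script filters the boxes: boxes that do not contain the center of the annotated keypoints are discarded, and of the remaining boxes the largest one is kept.
If the 17 annotated keypoints do not all fall inside the chosen box, the detection is not accurate enough to enclose the whole person. This is fairly rare; in that case the script pushes the offending edges out to 80 pixels beyond the keypoints. Images with no usable detection at all get a box built from the keypoint extent padded by 80 pixels, and their file names are collected in err_name and printed at the end, so that they can be opened in labelme and corrected by hand.
err_n is the number of such files. If it is a small share of a large dataset and checking them one by one is too tedious, leaving them as they are is fine: the padded keypoint box is not very precise, but it does guarantee that the person is enclosed.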
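
The selection rule described above can be condensed into a few small functions. This is a sketch under the same assumptions as the script (one person per image, flat [x, y, v, ...] keypoints, an 80 pixel pad); the function names are hypothetical and not part of the original code:

```python
def yolo_to_xyxy(cx, cy, w, h, img_w, img_h):
    """Convert a normalized YOLO box (center, size) to pixel corners (x1, y1, x2, y2)."""
    return ((cx - w / 2) * img_w, (cy - h / 2) * img_h,
            (cx + w / 2) * img_w, (cy + h / 2) * img_h)


def keypoint_extent(keypoints):
    """Bounding extent (min_x, min_y, max_x, max_y) of the keypoints with visibility > 0."""
    xs = [keypoints[i] for i in range(0, len(keypoints), 3) if keypoints[i + 2] > 0]
    ys = [keypoints[i + 1] for i in range(0, len(keypoints), 3) if keypoints[i + 2] > 0]
    return min(xs), min(ys), max(xs), max(ys)


def choose_person_box(boxes, keypoints, img_w, img_h, pad=80):
    """Return (box, used_fallback) for a single-person image."""
    min_x, min_y, max_x, max_y = keypoint_extent(keypoints)
    cx, cy = (min_x + max_x) / 2, (min_y + max_y) / 2

    # Keep only the boxes that contain the center of the keypoints
    candidates = [b for b in boxes if b[0] < cx < b[2] and b[1] < cy < b[3]]
    if not candidates:
        # No usable detection: pad the keypoint extent and clamp it to the image
        box = (max(min_x - pad, 0), max(min_y - pad, 0),
               min(max_x + pad, img_w), min(max_y + pad, img_h))
        return box, True

    # Largest remaining box wins
    x1, y1, x2, y2 = max(candidates, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]))

    # Push out any edge that a keypoint sticks past
    if min_x < x1:
        x1 = max(min_x - pad, 0)
    if min_y < y1:
        y1 = max(min_y - pad, 0)
    if max_x > x2:
        x2 = min(max_x + pad, img_w)
    if max_y > y2:
        y2 = min(max_y + pad, img_h)
    return (x1, y1, x2, y2), False
```

The second return value plays the role of the err_name bookkeeping: it is True exactly when the box came from the keypoints rather than from a detection.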
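Once this is done, merge all augmented images, keypoint files and box files into the folder of the original data, keeping the structure in which kp_json and det_json sit at the same level as the images.
The results above can now be converted into the bbox file: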

import json
import os

# List that collects one detection entry per image
annotations_list = []

# Folder holding the det_json files
data_folder = ".../mmpose/data/coco/ori/det_json"
# Path of the combined bbox JSON file to write
bbox_file_path = '.../mmpose/data/coco/person_detection_results/ori_bbox.json'

image_id = 1

# Iterate over the JSON files in sorted order, so image_id is assigned deterministically
for json_file in sorted(os.listdir(data_folder)):
    if json_file.endswith('.json'):
        json_path = os.path.join(data_folder, json_file)

        with open(json_path, 'r') as f:
            # Read and parse the JSON file
            data = json.load(f)

            # Corner points of the person box
            points = data['shapes'][0]["points"]

            x1 = points[0][0]
            y1 = points[0][1]
            x2 = points[1][0]
            y2 = points[1][1]

            # COCO box format: [x, y, width, height]
            bbox = [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]

            category_id = 1
            score = 0.992816648089145264  # fixed placeholder confidence

            # Detection entry for this image
            annotation = {
                "bbox": bbox,
                "category_id": category_id,
                "image_id": image_id,
                "score": score
            }

            annotations_list.append(annotation)
            print(image_id, json_file)

            image_id += 1


with open(bbox_file_path, 'w') as output_file:
    json.dump(annotations_list, output_file)

print(f'Bboxes file saved to {bbox_file_path}')
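
The box conversion above assumes the first labelme point is the top-left corner. For rectangles that were redrawn by hand this may not hold, so a more forgiving conversion is sketched here; labelme_rect_to_coco_bbox is a hypothetical helper, not something the script above defines:

```python
def labelme_rect_to_coco_bbox(points):
    """Convert two opposite rectangle corners (in any order) to a COCO [x, y, width, height] box."""
    (xa, ya), (xb, yb) = points
    x1, x2 = min(xa, xb), max(xa, xb)
    y1, y2 = min(ya, yb), max(ya, yb)
    return [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]


print(labelme_rect_to_coco_bbox([[110.5, 220.0], [10.5, 20.0]]))
```

Sorting the corners first guarantees a positive width and height regardless of how the rectangle was drawn.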

Creating the COCO file

The code is as follows:

# Generate the COCO annotation file

import os
import json

import cv2

# Directories containing the JSON annotation files
kp_json_path = ".../mmpose/data/coco/ori/kp_json"
det_json_path = ".../mmpose/data/coco/ori/det_json"
img_path = ".../mmpose/data/coco/ori"
output_file = ".../mmpose/data/coco/annotations/ori_coco.json"

# Initialize the COCO dataset skeleton
coco_data = {
    "info": {"description": "COCO 2017 Dataset","url": "http://cocodataset.org","version": "1.0","year": 2017,"contributor": "COCO Consortium","date_created": "2017/09/01"},
    "licenses": [{"url": "http://creativecommons.org/licenses/by-nc-sa/2.0/","id": 1,"name": "Attribution-NonCommercial-ShareAlike License"},{"url": "http://creativecommons.org/licenses/by-nc/2.0/","id": 2,"name": "Attribution-NonCommercial License"},{"url": "http://creativecommons.org/licenses/by-nc-nd/2.0/","id": 3,"name": "Attribution-NonCommercial-NoDerivs License"},{"url": "http://creativecommons.org/licenses/by/2.0/","id": 4,"name": "Attribution License"},{"url": "http://creativecommons.org/licenses/by-sa/2.0/","id": 5,"name": "Attribution-ShareAlike License"},{"url": "http://creativecommons.org/licenses/by-nd/2.0/","id": 6,"name": "Attribution-NoDerivs License"},{"url": "http://flickr.com/commons/usage/","id": 7,"name": "No known copyright restrictions"},{"url": "http://www.usa.gov/copyright.shtml","id": 8,"name": "United States Government Work"}],
    "images": [],
    "annotations": [],
    "categories": [{"supercategory": "person","id": 1,"name": "person","keypoints": ["nose","left_eye","right_eye","left_ear","right_ear","left_shoulder","right_shoulder","left_elbow","right_elbow","left_wrist","right_wrist","left_hip","right_hip","left_knee","right_knee","left_ankle","right_ankle"],"skeleton": [[16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],[6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]}]
}

'''
# Add the COCO category
category = {
    "id": 1,
    "name": "person",
    "supercategory": "person"
}
coco_data["categories"].append(category)
'''

image_id = 1
annotation_id = 1

new_image_directory = ""

# Iterate over all JSON files in the directory. Sorted order keeps image_id
# deterministic; the bbox file must number the images in the same order.
for filename in sorted(os.listdir(kp_json_path)):
    if filename.endswith(".json"):
        annotation_file = os.path.join(kp_json_path, filename)
        det_json_file = os.path.join(det_json_path, filename)
        img_file = os.path.join(img_path, os.path.splitext(filename)[0] + ".jpg")

        with open(annotation_file, "r") as file:
            annotation_data = json.load(file)

        old_file_name = annotation_data["imagepath"]
        # Build the new file path
        file_name = os.path.basename(old_file_name)
        new_file_path = os.path.join(new_image_directory, file_name)
        img = cv2.imread(img_file)
        height, width = img.shape[:2]

        # Image entry
        image_info = {
            "id": image_id,
            "file_name": new_file_path,
            "width": width,
            "height": height
        }
        coco_data["images"].append(image_info)

        if "keypoints" in annotation_data and annotation_data["keypoints"]:
            keypoints = annotation_data["keypoints"][0]
        else:
            # No keypoints: keep the image entry, but still advance image_id
            # so that the next image does not reuse this id
            image_id += 1
            continue
        keypoint_count = len(keypoints) // 3

        with open(det_json_file, 'r') as f2:
            # Read and parse the detection JSON
            data = json.load(f2)

        # Corner points of the person box
        points = data['shapes'][0]["points"]

        x1 = points[0][0]
        y1 = points[0][1]
        x2 = points[1][0]
        y2 = points[1][1]

        bbox = [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]

        annotation = {
            "id": annotation_id,
            "image_id": image_id,
            "category_id": 1,
            "keypoints": list(keypoints[:keypoint_count * 3]),
            # COCO counts only the labelled keypoints (v > 0)
            "num_keypoints": sum(1 for i in range(keypoint_count) if keypoints[i * 3 + 2] > 0),
            "bbox": bbox,
            "area": bbox[2] * bbox[3],
            'ignore': 0,
            "iscrowd": 0
        }

        coco_data["annotations"].append(annotation)

        # Advance the image and annotation ids
        print(image_id, filename)
        image_id += 1
        annotation_id += 1

# Save the COCO data as a JSON file
with open(output_file, "w") as json_file:
    json.dump(coco_data, json_file)

print(f'Annotations file saved to {output_file}')
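
The bbox file and the COCO file are generated by two separate scripts, each with its own image_id counter, so it is worth checking that the two agree before training. A minimal sketch of such a check follows; check_bbox_against_coco is a hypothetical helper that works on the already loaded JSON objects:

```python
def check_bbox_against_coco(coco, detections):
    """Return a list of human-readable problems; an empty list means the two files agree."""
    problems = []
    image_ids = {img["id"] for img in coco["images"]}
    if len(image_ids) != len(coco["images"]):
        problems.append("duplicate image ids in the COCO file")

    for det in detections:
        if det["image_id"] not in image_ids:
            problems.append(f"detection refers to unknown image_id {det['image_id']}")
        box = det["bbox"]
        if len(box) != 4 or box[2] <= 0 or box[3] <= 0:
            problems.append(f"invalid bbox {box} for image_id {det['image_id']}")

    # Every image should have at least one detection box
    detected_ids = {det["image_id"] for det in detections}
    for image_id in sorted(image_ids - detected_ids):
        problems.append(f"image_id {image_id} has no detection")
    return problems


coco = {"images": [{"id": 1}, {"id": 2}]}
detections = [{"image_id": 1, "bbox": [0, 0, 10, 20]}]
print(check_bbox_against_coco(coco, detections))
```

In practice both arguments would come from json.load on the two generated files. The check only compares ids, so it cannot tell whether id 1 points at the same picture in both files; that still depends on both scripts walking the files in the same order.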

Computing the mean and standard deviation

import os
import cv2
import numpy as np

def calculate_mean_std_batch(folder_path, batch_size=100):
    image_paths = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if
                   file.endswith('.jpg') or file.endswith('.png')]

    total_pixels = 0  # total number of pixels over all images
    pixel_sum = np.zeros(3)  # per-channel sum of pixel values (RGB)
    pixel_squared_sum = np.zeros(3)  # per-channel sum of squared pixel values (RGB)

    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i + batch_size]

        for path in batch_paths:
            image = cv2.imread(path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float64)
            total_pixels += image.shape[0] * image.shape[1]
            pixel_sum += np.sum(image, axis=(0, 1))
            pixel_squared_sum += np.sum(image ** 2, axis=(0, 1))

        # Progress: number of images processed so far
        print(min(i + batch_size, len(image_paths)))

    # Mean over the whole dataset
    mean = pixel_sum / total_pixels

    # Standard deviation from E[x^2] - mean^2, so every pixel is measured against the final mean
    std = np.sqrt(np.maximum(pixel_squared_sum / total_pixels - mean ** 2, 0))

    return mean, std


folder_path = ".../mmpose/data/coco/ori"  # change this to the folder that contains the images
mean, std = calculate_mean_std_batch(folder_path)
print("Mean:", mean)
print("Standard Deviation:", std)
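
When the statistics are accumulated batch by batch, the deviation has to be measured against the mean of the whole dataset, not against a running mean. One way to do this in a single pass is the identity std = sqrt(E[x^2] - E[x]^2). The sketch below shows it on plain numbers for one channel (the per-channel image case works the same way); streaming_mean_std is a hypothetical helper, and the result is compared with the standard library:

```python
import math
import statistics


def streaming_mean_std(batches):
    """Population mean and standard deviation over batches, using running sums only."""
    count = 0
    total = 0.0
    total_sq = 0.0
    for batch in batches:
        for value in batch:
            count += 1
            total += value
            total_sq += value * value
    mean = total / count
    # Clamp at zero to guard against tiny negative values from rounding
    variance = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)


batches = [[10, 20, 30], [40, 50], [60]]
mean, std = streaming_mean_std(batches)
flat = [v for batch in batches for v in batch]
print(mean, statistics.mean(flat))
print(std, statistics.pstdev(flat))
```

Both printed lines should show matching values, whatever the batch boundaries are.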

Training configuration file

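Choose 2D human keypoint detection in topdown_heatmap mode, and either create a configuration file or adapt one of the examples. Here is one such configuration file: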

_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=300, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=1e-5,
    weight_decay=1e-5,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=210,
        milestones=[170, 200],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
    type='MSRAHeatmap', input_size=(288, 384), heatmap_size=(72, 96), sigma=3)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[108.36341352, 109.43100226, 103.20111829],   
        std=[69.45071565, 64.45677341, 63.89262599],    
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(48, 96)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(48, 96, 192)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(48, 96, 192, 384))),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrain/pose_hrnet_person_384x288.pth'),
    ),
    head=dict(
        type='HeatmapHead',
        in_channels=48,
        out_channels=17,
        deconv_out_channels=None,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec),
    test_cfg=dict(
        flip_test=True,
        flip_mode='heatmap',
        shift_heatmap=True,
    ))

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/barlegup_coco.json',
        data_prefix=dict(img='barlegup/'),
        pipeline=train_pipeline,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/barlegup_coco.json',
        bbox_file='data/coco/person_detection_results/barlegup_bbox.json',
        data_prefix=dict(img='barlegup/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/barlegup_coco.json')
test_evaluator = val_evaluator

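First, set max_epochs, the maximum number of training epochs; I usually train for 300.
milestones lists the epochs at which the learning rate decays: once training reaches a listed epoch, the lr is multiplied by gamma.
input_size is the input size used for training. Note that the order is (width, height). If most training images are portrait, or the camera feed used for inference is portrait, make the width the smaller value and the height the larger one; that way the image is not shrunk so much on its way into the network that the person becomes tiny. heatmap_size is correspondingly one quarter of input_size.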
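
The values discussed above can be sanity-checked with a few lines. This is a sketch, not MMPose code: lr_at_epoch and heatmap_size_for are hypothetical helpers, the schedule assumes the usual multi-step behavior (multiply by gamma at each milestone) and ignores the warm-up and the begin/end settings, and the defaults are copied from the example config:

```python
from bisect import bisect_right


def lr_at_epoch(epoch, base_lr=1e-5, milestones=(170, 200), gamma=0.1):
    """Learning rate after applying gamma once for every milestone already reached."""
    return base_lr * gamma ** bisect_right(milestones, epoch)


def heatmap_size_for(input_size, stride=4):
    """Heatmap (width, height) for a given input (width, height) and output stride."""
    width, height = input_size
    return (width // stride, height // stride)


for epoch in (0, 169, 170, 200, 299):
    print(epoch, lr_at_epoch(epoch))
print(heatmap_size_for((288, 384)))
```

With 300 epochs and milestones at 170 and 200, the last third of training runs at the lowest rate; if that is not intended, the milestones can be moved along with max_epochs.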
Then fill in the mean and standard deviation computed earlier as mean and std in data_preprocessor.
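The checkpoint under init_cfg is the pretrained model. If you have trained a model of your own and want to use it as the pretrained model, change it here.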
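If GPU memory is insufficient, reduce batch_size. Note that for multi-GPU training this batch_size is the per-GPU value, that is, what the memory of a single card can hold, not the total across all cards.
If the CPU is not powerful enough, reduce num_workers.
ann_file is the path of the COCO file generated earlier, bbox_file is the path of the bbox file, and data_prefix is the location of the training images.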
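
Because batch_size is per GPU, the effective batch size grows with the number of cards. The config also sets auto_scale_lr with base_batch_size=512; assuming learning-rate auto-scaling is switched on and follows the usual linear scaling rule (both are assumptions here, not something the config above enforces by itself), the arithmetic looks like this, with hypothetical helper names:

```python
def effective_batch_size(per_gpu_batch_size, num_gpus):
    """Total number of samples per optimizer step across all GPUs."""
    return per_gpu_batch_size * num_gpus


def linearly_scaled_lr(base_lr, actual_batch_size, base_batch_size=512):
    """Linear scaling rule: the learning rate scales with the batch size ratio."""
    return base_lr * actual_batch_size / base_batch_size


total = effective_batch_size(32, 2)
print(total)
print(linearly_scaled_lr(1e-5, total))
```

With two cards at batch_size=32 the effective batch is 64, an eighth of the base batch size, so a scaled learning rate would be an eighth of the configured one.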

Training commands

For multi-GPU training, with two GPUs as an example, run the following command from the repository root:

bash ./tools/dist_train.sh .../mmpose/configs/body_2d_keypoint/topdown_heatmap/my/xxx.py 2

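The first argument is the path of the configuration file created above; the number after it is the number of GPUs.
For single-GPU training: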

python ./tools/train.py .../mmpose/configs/body_2d_keypoint/topdown_heatmap/my/xxx.py

Here the only argument is the path of the configuration file.

