Training the HRNet Human Pose Estimation Model with the mmpose Framework
Before training, the data must be annotated and augmented. The annotation tool is covered in another post; here we start with data augmentation.
Data Augmentation
We apply simple color transforms and position transforms. The code is as follows:
from PIL import Image, ImageEnhance
import numpy as np
import os
import glob
import json
import torch
import torchvision.transforms as transforms

def color(flag, input_directory, output_directory):
    # Create the output directory
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    # Iterate over the image files in the input directory
    for filename in os.listdir(input_directory):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            # Load the image
            image_path = os.path.join(input_directory, filename)
            original_image = Image.open(image_path)
            # Color augmentation parameters; adjust as needed
            brightness_factor = np.random.uniform(0.6, 1.8)  # brightness
            contrast_factor = np.random.uniform(0.7, 1.7)    # contrast
            saturation_factor = np.random.uniform(0.7, 1.7)  # saturation
            # Apply the color transforms
            enhancer = ImageEnhance.Brightness(original_image)
            enhanced_image = enhancer.enhance(brightness_factor)
            enhancer = ImageEnhance.Contrast(enhanced_image)
            enhanced_image = enhancer.enhance(contrast_factor)
            enhancer = ImageEnhance.Color(enhanced_image)
            enhanced_image = enhancer.enhance(saturation_factor)
            new_name = os.path.splitext(filename)[0] + "-" + flag + ".jpg"
            # Copy the keypoint label, updating the image name it points at
            json_file = os.path.join(input_directory, os.path.splitext(filename)[0] + ".json")
            with open(json_file, 'r') as fj:
                data = json.load(fj)
            data["imagepath"] = new_name
            saved_json = os.path.join(output_directory, os.path.splitext(new_name)[0] + ".json")
            with open(saved_json, 'w') as fo:
                json.dump(data, fo, ensure_ascii=False, indent=2)
            # Save the augmented image
            output_path = os.path.join(output_directory, new_name)
            enhanced_image.save(output_path)
            print(output_path)
    print("Data augmentation completed.")

def position(flag, data_folder, output_folder):
    os.makedirs(output_folder, exist_ok=True)
    # Collect all JSON files in the folder
    json_files = glob.glob(os.path.join(data_folder, '*.json'))
    # transform = transforms.RandomAffine(degrees=30, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=10)  # unused: the crop below overrides it
    for json_file_path in json_files:
        flage = 0
        while True:
            # Reload the JSON annotation on every attempt
            with open(json_file_path, 'r') as json_file:
                data = json.load(json_file)
            # Image path
            image_path = os.path.splitext(json_file_path)[0] + ".jpg"
            # Load the image and get its size
            image = Image.open(image_path)
            image_width, image_height = image.size
            aspect_ratio = image_width / image_height
            # Random center crop that keeps the original aspect ratio
            crop_height = int(torch.randint(int(image_height * 0.5), image_height, (1,)).item())
            crop_width = int(crop_height * aspect_ratio)
            transform = transforms.CenterCrop((crop_height, crop_width))
            transformed_image = transform(image)
            # Size of the cropped image
            transformed_image_width, transformed_image_height = transformed_image.size
            crop_zero_point = ((image_width - transformed_image_width) / 2,
                               (image_height - transformed_image_height) / 2)
            # Shift the keypoints into the cropped coordinate frame
            keypoints = data["keypoints"][0]
            for i in range(0, len(keypoints), 3):
                x, y, v = keypoints[i], keypoints[i + 1], keypoints[i + 2]
                keypoints[i], keypoints[i + 1], keypoints[i + 2] = [
                    x - crop_zero_point[0], y - crop_zero_point[1], v]
            # Check whether any keypoint fell outside the cropped image
            if all(0 <= keypoints[i] < transformed_image_width and
                   0 <= keypoints[i + 1] < transformed_image_height
                   for i in range(0, len(keypoints), 3)):
                break  # all keypoints are inside the image: accept this crop
            flage += 1
            print(flage)
            if flage == 20:
                break
        if flage == 20:
            continue  # give up on this image after 20 failed crops
        # Resize back to the original resolution
        resized_image = transformed_image.resize((image_width, image_height))
        resize_ratio = image_width / transformed_image_width
        # Scale the keypoints back to absolute coordinates
        for i in range(0, len(keypoints), 3):
            x, y, v = [keypoints[i], keypoints[i + 1], keypoints[i + 2]]
            keypoints[i], keypoints[i + 1], keypoints[i + 2] = [
                x * resize_ratio, y * resize_ratio, v]
        # Update the annotation
        data["keypoints"] = [keypoints]
        data["imagepath"] = os.path.splitext(data["imagepath"])[0] + "-" + flag + ".jpg"
        # Save the augmented image
        transformed_image_path = os.path.join(
            output_folder, os.path.splitext(os.path.basename(image_path))[0] + "-" + flag + ".jpg")
        resized_image.save(transformed_image_path)
        # Save the updated JSON file
        transformed_json_file_path = os.path.join(
            output_folder, os.path.splitext(os.path.basename(json_file_path))[0] + "-" + flag + ".json")
        with open(transformed_json_file_path, 'w') as transformed_json_file:
            json.dump(data, transformed_json_file, indent=2)
        print(transformed_image_path)
    print("Data augmentation completed.")

input_path = ".../mmpose/data/coco/ori"
output_path = ".../mmpose/data/coco/col1"
flag = "col1"
color(flag, input_path, output_path)      # flag = col1 / col2
# position(flag, input_path, output_path)  # flag = posi
Typically I do one position transform and two color transforms (one on the original data, one on the position-transformed data), i.e. three separate runs. Remember to use a different flag for each run as the file-name suffix, so the outputs stay distinguishable once everything is merged together later. Put the images and the JSON labels together in one folder and use it as input_path.
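For example, a driver for that three-run schedule might look like this (a sketch; the col1/posi/col2 folder names simply follow the flag convention above):

ori = ".../mmpose/data/coco/ori"

color("col1", ori, ".../mmpose/data/coco/col1")     # color pass on the original data
position("posi", ori, ".../mmpose/data/coco/posi")  # position pass on the original data
color("col2", ".../mmpose/data/coco/posi",          # color pass on the
      ".../mmpose/data/coco/col2")                  # position-transformed data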
The images before and after augmentation are shown in the figure. The data lives under coco: the annotations and person_detection_results folders hold the COCO labels and the bbox labels respectively (covered later). After augmenting, create a kp_json and a det_json folder inside each data folder and move all the annotation labels into kp_json.
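The resulting layout looks roughly like this (a sketch of my understanding of the structure described above):

data/coco/
├── annotations/                # COCO keypoint label files (created below)
├── person_detection_results/   # bbox files (created below)
├── ori/                        # original images
│   ├── kp_json/                # keypoint labels
│   └── det_json/               # detection labels (created below)
├── col1/
├── posi/
└── col2/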
Creating the bbox File
Then use YOLOv5's yolov5x6.pt model (it is the most accurate at detecting pedestrians) to detect the images in ori and posi and produce the person boxes (top-down mode requires object detection before keypoint detection):
python detect.py --weights yolov5x6.pt --imgsz 1280 --source .../mmpose/data/coco/posi/ --save-txt --classes 0
After detection, copy each result set's labels folder and name the copy col; it serves as the detection result for that set's color-transformed images, since a color transform moves no pixels and a straight copy is therefore valid. Then batch-rename the label files inside col:
import os

path = ".../yolov5/runs/detect/posi/col"
files = os.listdir(path)
for file in files:
    file_name = os.path.join(path, file)
    new_f = os.path.splitext(file)[0] + "-col2.txt"
    new_name = os.path.join(path, new_f)
    os.rename(file_name, new_name)
    print(new_name)
If your suffix was col1, use col1 here as well; the suffixes must correspond.
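A small parameterized variant (my own sketch, not part of the original pipeline) keeps the suffix tied to the flag so the two cannot drift apart:

import os

def rename_labels(path, flag):
    # Append "-<flag>" so each label file matches its image's suffix.
    for file in os.listdir(path):
        new_f = os.path.splitext(file)[0] + "-" + flag + ".txt"
        os.rename(os.path.join(path, file), os.path.join(path, new_f))

rename_labels(".../yolov5/runs/detect/posi/col", "col2")  # use "col1" for the col1 set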
Next, convert the YOLO-format labels to labelme format. Run the conversion for every data folder and save the results into its det_json folder. The code is as follows:
# -*- coding: utf-8 -*-
import json
import cv2
from glob import glob
import os

txt_path = '.../yolov5/runs/detect/exp/labels/'  # darknet-format labels
saved_path = '.../mmpose/data/coco/posi/det_json/'
img_path = '.../mmpose/data/coco/posi/'
kp_path = ".../mmpose/data/coco/posi/kp_json/"

files = glob(img_path + "*.jpg")
files = [i.split('/')[-1].split('.jpg')[0] for i in files]
print(files)

err_n = 0
err_name = []
for file in files:
    print(file)
    txt_file = txt_path + file + '.txt'
    img_file = img_path + file + '.jpg'
    kp_file = kp_path + file + ".json"
    print(img_file)
    img = cv2.imread(img_file)
    imgw = img.shape[1]
    imgh = img.shape[0]
    flag = False
    xi = []
    yi = []
    xa = []
    ya = []
    Label = []
    if not os.path.exists(txt_file):
        # YOLO found nothing: fall back to the keypoint-derived box below
        flag = True
        Label.append('person')
    else:
        with open(txt_file, 'r') as f:  # read the txt file and collect the labels
            area = 0
            for line in f.readlines():
                line = line.strip('\n')
                a = line.split(' ')
                label = 'other'
                if a[0] == '0':
                    label = 'person'
                # elif a[0] == '1':
                #     label = 'hat'
                # map your own class ids to their names here
                Label.append(label)
                print(Label)
                # convert the normalized darknet box to pixel corners
                centerx = float(a[1]) * imgw
                centery = float(a[2]) * imgh
                w = float(a[3]) * imgw
                h = float(a[4]) * imgh
                x1 = centerx - w / 2
                x2 = centerx + w / 2
                y1 = centery - h / 2
                y2 = centery + h / 2
                x1, x2 = min(x1, x2), max(x1, x2)
                y1, y2 = min(y1, y2), max(y1, y2)
                with open(kp_file, 'r') as fj0:
                    data = json.load(fj0)
                # extract the keypoints and find their bounding extents
                keypoints = data['keypoints'][0]
                min_x, min_y = float('inf'), float('inf')
                max_x, max_y = float('-inf'), float('-inf')
                for i in range(0, len(keypoints), 3):
                    x, y, confidence = keypoints[i], keypoints[i + 1], keypoints[i + 2]
                    if confidence > 0:  # only consider keypoints with confidence > 0
                        min_x = min(min_x, x)
                        min_y = min(min_y, y)
                        max_x = max(max_x, x)
                        max_y = max(max_y, y)
                center_x = (min_x + max_x) / 2
                center_y = (min_y + max_y) / 2
                # keep only the largest box that contains the keypoint center
                if (x1 < center_x < x2) and (y1 < center_y < y2):
                    if int((x2 - x1) * (y2 - y1)) > area:
                        area = int((x2 - x1) * (y2 - y1))
                    else:
                        continue
                    xi = [x1]
                    yi = [y1]
                    xa = [x2]
                    ya = [y2]
    with open(kp_file, 'r') as fj:
        data = json.load(fj)
    # extract the keypoint extents again for the fallback / expansion checks
    keypoints = data['keypoints'][0]
    min_x, min_y = float('inf'), float('inf')
    max_x, max_y = float('-inf'), float('-inf')
    for i in range(0, len(keypoints), 3):
        x, y, confidence = keypoints[i], keypoints[i + 1], keypoints[i + 2]
        if confidence > 0:  # only consider keypoints with confidence > 0
            min_x = min(min_x, x)
            min_y = min(min_y, y)
            max_x = max(max_x, x)
            max_y = max(max_y, y)
    print(flag, len(xi))
    if flag or (len(xi) == 0):
        # no usable detection: expand the keypoint extents by 80 px and clamp
        xi = [min_x - 80]
        if xi[0] < 0:
            xi[0] = 0
        yi = [min_y - 80]
        if yi[0] < 0:
            yi[0] = 0
        xa = [max_x + 80]
        if xa[0] > imgw:
            xa[0] = imgw
        ya = [max_y + 80]
        if ya[0] > imgh:
            ya[0] = imgh
        err_n += 1
        err_name.append(file)
    else:
        # if a keypoint falls outside the detected box, push that edge out by 80 px
        if min_x < xi[0]:
            xi[0] = min_x - 80
            if xi[0] < 0:
                xi[0] = 0
        if min_y < yi[0]:
            yi[0] = min_y - 80
            if yi[0] < 0:
                yi[0] = 0
        if max_x > xa[0]:
            xa[0] = max_x + 80
            if xa[0] > imgw:
                xa[0] = imgw
        if max_y > ya[0]:
            ya[0] = max_y + 80
            if ya[0] > imgh:
                ya[0] = imgh
    labelme_formate = {
        "version": "3.16.7",
        "flags": {},
        "lineColor": [0, 255, 0, 128],
        "fillColor": [255, 0, 0, 128],
        "imagePath": file + ".jpg",
        "imageHeight": imgh,
        "imageWidth": imgw
    }
    labelme_formate['imageData'] = None
    shapes = []
    for i in range(0, len(xi)):
        s = {"label": Label[i], "line_color": None, "fill_color": None, "shape_type": "rectangle"}
        points = [
            [xi[i], yi[i]],
            [xa[i], ya[i]]
        ]
        s['points'] = points
        shapes.append(s)
    labelme_formate['shapes'] = shapes
    with open(saved_path + file + ".json", 'w') as fo:
        json.dump(labelme_formate, fo, ensure_ascii=False, indent=2)
    print(saved_path + file + ".json")
print(err_name)
print(err_n)
Because each of my images contains exactly one subject, the script filters the detections to guard against background people that YOLO may also pick up: it keeps only the largest box, and boxes containing no keypoints are discarded.
If the 17 annotated keypoints do not all fall inside the final box, the detection evidently failed to enclose the whole person. This case is rare, so those file names are printed at the end; open them in labelme and correct the boxes by hand.
If err_n (the failure count) is a small fraction of a large dataset and hunting the files down one by one is not worth it, that is also fine: for these cases the code already expands the keypoints' min/max extents outward by a fixed number of pixels, which is not very precise but does guarantee the person is enclosed.
When this is done, merge all the augmented image data, keypoint labels, and box labels into the original data folder (keeping the structure where kp_json and det_json sit at the same level as the images).
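A minimal merge sketch, assuming the col1/posi/col2 folder layout used above (the suffixed file names only collide if two runs shared a flag):

import os
import shutil

src_root = ".../mmpose/data/coco"
dst_root = ".../mmpose/data/coco/ori"
for folder in ["col1", "posi", "col2"]:
    for sub in ["", "kp_json", "det_json"]:
        src_dir = os.path.join(src_root, folder, sub)
        dst_dir = os.path.join(dst_root, sub)
        if not os.path.isdir(src_dir):
            continue
        for name in os.listdir(src_dir):
            src_file = os.path.join(src_dir, name)
            if os.path.isfile(src_file):  # skip the kp_json/det_json dirs themselves
                shutil.move(src_file, os.path.join(dst_dir, name))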
Now the results above can be converted into the bbox file:
import json
import os

# list that will hold one entry per annotation
annotations_list = []
# folder of labelme det json files
data_folder = ".../mmpose/data/coco/ori/det_json"
# output path of the bbox JSON file
bbox_file_path = '.../mmpose/data/coco/person_detection_results/ori_bbox.json'

image_id = 1
# iterate over the folder containing the JSON label files
for json_file in os.listdir(data_folder):
    if json_file.endswith('.json'):
        json_path = os.path.join(data_folder, json_file)
        with open(json_path, 'r') as f:
            # read and parse the JSON file
            data = json.load(f)
        # extract the rectangle and convert to COCO [x, y, w, h]
        points = data['shapes'][0]["points"]
        x1 = points[0][0]
        y1 = points[0][1]
        x2 = points[1][0]
        y2 = points[1][1]
        bbox = [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]
        category_id = 1
        score = 0.992816648089145264  # fixed dummy detection score
        # build the detection entry
        annotation = {
            "bbox": bbox,
            "category_id": category_id,
            "image_id": image_id,
            "score": score
        }
        annotations_list.append(annotation)
        print(image_id, json_file)
        image_id += 1

# save the list as a new JSON file
with open(bbox_file_path, 'w') as output_file:
    json.dump(annotations_list, output_file)
print(f'Bboxes file saved to {bbox_file_path}')
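A quick sanity check (my own addition) that the file parses and every entry has the four COCO-style detection fields:

import json

with open('.../mmpose/data/coco/person_detection_results/ori_bbox.json') as f:
    dets = json.load(f)
print(len(dets), "detections")
assert all({"bbox", "category_id", "image_id", "score"} <= set(d) for d in dets)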
Creating the COCO File
The code is as follows:
# Generate the COCO annotation file
import os
import json
import cv2

# directories containing the JSON label files
kp_json_path = ".../mmpose/data/coco/ori/kp_json"
det_json_path = ".../mmpose/data/coco/ori/det_json"
img_path = ".../mmpose/data/coco/ori"
output_file = ".../mmpose/data/coco/annotations/ori_coco.json"

# initialize the COCO dataset skeleton
coco_data = {
    "info": {"description": "COCO 2017 Dataset","url": "http://cocodataset.org","version": "1.0","year": 2017,"contributor": "COCO Consortium","date_created": "2017/09/01"},
    "licenses": [{"url": "http://creativecommons.org/licenses/by-nc-sa/2.0/","id": 1,"name": "Attribution-NonCommercial-ShareAlike License"},{"url": "http://creativecommons.org/licenses/by-nc/2.0/","id": 2,"name": "Attribution-NonCommercial License"},{"url": "http://creativecommons.org/licenses/by-nc-nd/2.0/","id": 3,"name": "Attribution-NonCommercial-NoDerivs License"},{"url": "http://creativecommons.org/licenses/by/2.0/","id": 4,"name": "Attribution License"},{"url": "http://creativecommons.org/licenses/by-sa/2.0/","id": 5,"name": "Attribution-ShareAlike License"},{"url": "http://creativecommons.org/licenses/by-nd/2.0/","id": 6,"name": "Attribution-NoDerivs License"},{"url": "http://flickr.com/commons/usage/","id": 7,"name": "No known copyright restrictions"},{"url": "http://www.usa.gov/copyright.shtml","id": 8,"name": "United States Government Work"}],
    "images": [],
    "annotations": [],
    "categories": [{"supercategory": "person","id": 1,"name": "person","keypoints": ["nose","left_eye","right_eye","left_ear","right_ear","left_shoulder","right_shoulder","left_elbow","right_elbow","left_wrist","right_wrist","left_hip","right_hip","left_knee","right_knee","left_ankle","right_ankle"],"skeleton": [[16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],[6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]}]
}

image_id = 1
annotation_id = 1
new_image_directory = ""

# iterate over all JSON files in the keypoint label directory;
# note: image_id must line up with the ids written by the bbox script,
# which assumes both scripts enumerate the files in the same order
for filename in os.listdir(kp_json_path):
    if filename.endswith(".json"):
        annotation_file = os.path.join(kp_json_path, filename)
        det_json_file = os.path.join(det_json_path, filename)
        img_file = os.path.join(img_path, os.path.splitext(filename)[0] + ".jpg")
        with open(annotation_file, "r") as file:
            annotation_data = json.load(file)
        old_file_name = annotation_data["imagepath"]
        # build the new file path
        file_name = os.path.basename(old_file_name)
        new_file_path = os.path.join(new_image_directory, file_name)
        img = cv2.imread(img_file)
        height, width = img.shape[:2]
        # image entry
        image_info = {
            "id": image_id,
            "file_name": new_file_path,
            "width": width,
            "height": height
        }
        coco_data["images"].append(image_info)
        if "keypoints" in annotation_data and annotation_data["keypoints"]:
            keypoints = annotation_data["keypoints"][0]
        else:
            continue
        keypoint_count = int(len(keypoints) / 3)
        keypoint_list = []
        for i in range(keypoint_count):
            x, y, v = keypoints[i * 3], keypoints[i * 3 + 1], keypoints[i * 3 + 2]
            keypoint_list.append([x, y, v])
        # read the matching labelme box and convert to [x, y, w, h]
        with open(det_json_file, 'r') as f2:
            data = json.load(f2)
        points = data['shapes'][0]["points"]
        x1 = points[0][0]
        y1 = points[0][1]
        x2 = points[1][0]
        y2 = points[1][1]
        bbox = [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]
        annotation = {
            "id": annotation_id,
            "image_id": image_id,
            "category_id": 1,
            "keypoints": [0] * (3 * keypoint_count),
            "num_keypoints": keypoint_count,
            "bbox": bbox,
            "area": bbox[2] * bbox[3],
            'ignore': 0,
            "iscrowd": 0
        }
        for i, keypoint in enumerate(keypoint_list):
            annotation["keypoints"][i * 3] = keypoint[0]
            annotation["keypoints"][i * 3 + 1] = keypoint[1]
            annotation["keypoints"][i * 3 + 2] = keypoint[2]
        coco_data["annotations"].append(annotation)
        # advance the image and annotation ids
        print(image_id, filename)
        image_id += 1
        annotation_id += 1

# save the COCO data as a JSON file
with open(output_file, "w") as json_file:
    json.dump(coco_data, json_file)
print(f'Annotations file saved to {output_file}')
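As a quick validity check (my own addition; requires pycocotools), the generated file can be loaded with the COCO API:

from pycocotools.coco import COCO

coco = COCO(".../mmpose/data/coco/annotations/ori_coco.json")
print(len(coco.getImgIds()), "images,", len(coco.getAnnIds()), "annotations")
# the 17 keypoint names should come back for the person category
print(coco.loadCats(1)[0]["keypoints"])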
Computing the Mean and Standard Deviation
The per-channel mean and standard deviation of the training images go into the config's data_preprocessor later; compute them as follows:
import os
import cv2
import numpy as np

count = 0

def calculate_mean_std_batch(folder_path, batch_size=100):
    global count
    image_paths = [os.path.join(folder_path, file) for file in os.listdir(folder_path)
                   if file.endswith('.jpg') or file.endswith('.png')]
    total_pixels = 0                  # total pixel count over all images
    pixel_sum = np.zeros(3)           # running sum of pixel values, per RGB channel
    pixel_squared_sum = np.zeros(3)   # running sum of squared pixel values, per RGB channel
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i + batch_size]
        batch_images = [cv2.imread(path) for path in batch_paths]
        batch_images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_images]
        for image in batch_images:
            total_pixels += image.size / 3  # pixel count of this image
            pixel_sum += np.sum(image, axis=(0, 1))
            # cast before squaring to avoid uint8 overflow
            pixel_squared_sum += np.sum(image.astype(np.float64) ** 2, axis=(0, 1))
        count += len(batch_paths)
        print(str(count))
    # mean and std from the accumulated first and second moments
    mean = pixel_sum / total_pixels
    std = np.sqrt(pixel_squared_sum / total_pixels - mean ** 2)
    return mean, std

folder_path = ".../mmpose/data/coco/ori"  # change to the folder containing the images
mean, std = calculate_mean_std_batch(folder_path)
print("Mean:", mean)
print("Standard Deviation:", std)
Training Configuration File
Choose the 2D body keypoint task in topdown_heatmap mode, then create a config file or modify one of the bundled examples. Here is one example config:
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=300, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=1e-5,
    weight_decay=1e-5,
))

# learning policy
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=210,
        milestones=[170, 200],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
    type='MSRAHeatmap', input_size=(288, 384), heatmap_size=(72, 96), sigma=3)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[108.36341352, 109.43100226, 103.20111829],
        std=[69.45071565, 64.45677341, 63.89262599],
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(48, 96)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(48, 96, 192)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(48, 96, 192, 384))),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrain/pose_hrnet_person_384x288.pth'),
    ),
    head=dict(
        type='HeatmapHead',
        in_channels=48,
        out_channels=17,
        deconv_out_channels=None,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec),
    test_cfg=dict(
        flip_test=True,
        flip_mode='heatmap',
        shift_heatmap=True,
    ))

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/barlegup_coco.json',
        data_prefix=dict(img='barlegup/'),
        pipeline=train_pipeline,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='annotations/barlegup_coco.json',
        bbox_file='data/coco/person_detection_results/barlegup_bbox.json',
        data_prefix=dict(img='barlegup/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/barlegup_coco.json')
test_evaluator = val_evaluator
First, change max_epochs, the maximum number of training epochs; I usually train for 300 epochs. milestones sets the epochs at which the learning rate decays. input_size is the network input size; note that the order here is (width, height). If most of the training images are portrait-oriented, or the camera feed you will run inference on is vertical, make the width small and the height large, so images fed into the network are not resized so aggressively that the person becomes tiny. heatmap_size must be one quarter of input_size.
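The quarter relationship is easy to sanity-check (a throwaway snippet, not part of the config):

input_size = (288, 384)  # (width, height), as in the codec above
heatmap_size = (input_size[0] // 4, input_size[1] // 4)
print(heatmap_size)      # (72, 96), matching the config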
Next, fill in the mean and standard deviation computed earlier. init_cfg points to the pretrained model; if you have trained your own model and want to use it as the pretrained weights, change the checkpoint path here. If GPU memory is insufficient, reduce batch_size; note that with multi-GPU training, this batch_size is what a single card's memory must support, not the total across all cards. If CPU performance is the bottleneck, reduce num_workers. ann_file is the path of the COCO file produced earlier, bbox_file is the path of the bbox file, and data_prefix is the path of the training data.
Training Command
For multi-GPU training, say with two GPUs, run the following from the repo root:
bash ./tools/dist_train.sh .../mmpose/configs/body_2d_keypoint/topdown_heatmap/my/xxx.py 2
The first argument is the path of the config file created above; the trailing number is the number of GPUs.
For single-GPU training:
python ./tools/train.py .../mmpose/configs/body_2d_keypoint/topdown_heatmap/my/xxx.py
which takes only the config file path.