当前位置: 首页 > article >正文

【人工智能学习之STGCN训练自己的数据集】

STGCN训练自己的数据集

  • 准备事项
  • 数据集制作
    • 视频转json
    • jsons转json
    • json转npy&pkl
  • 训练STGCN
    • 添加图结构
    • 修改训练参数
    • 开始训练
    • 测试

准备事项

  1. st-gcn代码下载与环境配置
git clone https://github.com/yysijie/st-gcn.git
cd st-gcn
pip install -r requirements.txt
cd torchlight
python setup.py install
cd ..
  1. 数据集结构
    可以使用open pose制作数据集,制作过程见下章。
    (参考openpose环境搭建和利用openpose提取自建数据集)

我的数据集结构如下:

dataset/
├── stgcn_data/ 				# 最终使用的数据集
│   ├── train  					# 训练数据集
│   └── val/  					# 验证数据集
├── video/  					# 视频数据集
│   └── fall  					# 分类0视频文件夹
│	└── normal					# 分类1视频文件夹
│	└── resized					# 视频缩放与json
│		└── data				# 视频对应的json
│		└── fall				# 分类0缩放视频
│		└── normal				# 分类1缩放视频
│		└── snippets			# 视频每一帧的json
│	└── label0.txt				# 分类0标签文本
│	└── label1.txt				# 分类1标签文本
└── pose_demo/  				# openpose目录
    └── bin/  					# 目录
    	└── openpose_demo.exe/  # openpose的exe可执行文件
    	└── XXX.dll/  			# 各种依赖项

数据集制作

视频转json

首先需要将视频放到对应的目录下,目录名称就是你的类名。

提示一下:用于st-gcn训练的数据集视频帧数不要超过300帧,5~6s的视频时长比较好,不要10几s的视频时长。要不然会报:index 300 is out of bounds for axis 1 with size 300 这种错误。很多博主使用的用FFmpeg对视频进行缩放,但我FFmpeg老是出问题,索性直接自己用cv2实现了。
【win64中FFmpegReader报错:Cannot find installation of real FFmpeg (which comes with ffprobe).】

以下是我video2json的代码,存于dataset目录下运行:

#!/usr/bin/env python
import os
import argparse
import json
import shutil

import numpy as np
import torch
# import skvideo
# ffmpeg_path = r"/dataset/ffmpeg-master-latest-win64-gpl/ffmpeg-master-latest-win64-gpl/bin"
# skvideo.setFFmpegPath(ffmpeg_path)
import skvideo.io

# from .io import IO
import tools
import tools.utils as utils

import cv2
import os


def resize_video(input_path, output_path, size=(340, 256), fps=30):
    # 打开视频文件
    cap = cv2.VideoCapture(input_path)

    # 获取视频帧率
    if fps == 0:
        fps = cap.get(cv2.CAP_PROP_FPS)

    # 获取视频编码格式
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    # 创建 VideoWriter 对象
    out = cv2.VideoWriter(output_path, fourcc, fps, size)

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # 调整帧大小
        resized_frame = cv2.resize(frame, size)

        # 写入输出文件
        out.write(resized_frame)

    # 释放资源
    cap.release()
    out.release()

def resize_all_video(originvideo_file,resizedvideo_file):
    # 获取视频文件名列表
    videos_file_names = [f for f in os.listdir(originvideo_file) if f.endswith('.mp4')]

    # 遍历并处理每个视频文件
    for file_name in videos_file_names:
        video_path = os.path.join(originvideo_file, file_name)
        outvideo_path = os.path.join(resizedvideo_file, file_name)
        resize_video(video_path, outvideo_path)
        print(f'{file_name} resize success')


def get_video_frames(video_path):
    """
    读取视频文件并返回所有帧
    :param video_path: 视频文件路径
    :return: 帧列表
    """
    frames = []
    cap = cv2.VideoCapture(video_path)

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)

    cap.release()
    return frames

class PreProcess():
    """
        利用openpose提取自建数据集的骨骼点数据
    """

    def start(self):

        ###########################修改处################
        type_number = 2
        gongfu_filename_list = ['fall','normal']
        #################################################

        for process_index in range(type_number):

            gongfu_filename = gongfu_filename_list[process_index]
            # 标签信息
            # labelgongfu_name = 'xxx_{}'.format(process_index)
            labelAction_name = '{}'.format(gongfu_filename)
            label_no = process_index

            # 视频所在文件夹
            originvideo_file = './video/{}/'.format(gongfu_filename)
            # resized视频输出文件夹
            resizedvideo_file = './video/resized/{}/'.format(gongfu_filename)
            '''
            videos_file_names = os.listdir(originvideo_file)

            # 1. Resize文件夹下的视频到340x256 30fps
            for file_name in videos_file_names:
                video_path = '{}{}'.format(originvideo_file, file_name)
                outvideo_path = '{}{}'.format(resizedvideo_file, file_name)
                writer = skvideo.io.FFmpegWriter(outvideo_path,
                                                 outputdict={'-f': 'mp4', '-vcodec': 'libx264', '-s': '340x256',
                                                             '-r': '30'})
                reader = skvideo.io.FFmpegReader(video_path)
                for frame in reader.nextFrame():
                    writer.writeFrame(frame)
                writer.close()
                print('{} resize success'.format(file_name))
'''
            # 1. Resize文件夹下的视频到340x256 30fps cv处理
            resize_all_video(originvideo_file,resizedvideo_file)
            # 2. 利用openpose提取每段视频骨骼点数据
            resizedvideos_file_names = os.listdir(resizedvideo_file)
            for file_name in resizedvideos_file_names:
                outvideo_path = '{}{}'.format(resizedvideo_file, file_name)

                # openpose = '{}/examples/openpose/openpose.bin'.format(self.arg.openpose)
                openpose = 'C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/pose_demo/bin/OpenPoseDemo.exe'
                video_name = file_name.split('.')[0]
                output_snippets_dir = './video/resized/snippets/{}'.format(video_name)
                print(f"Output snippets directory: {output_snippets_dir}")
                output_sequence_dir = 'video/resized/data'
                output_sequence_path = '{}/{}.json'.format(output_sequence_dir, video_name)

                label_name_path = 'video/label{}.txt'.format(process_index)
                with open(label_name_path) as f:
                    label_name = f.readlines()
                    label_name = [line.rstrip() for line in label_name]

                # pose estimation
                openpose_args = dict(
                    video=outvideo_path,
                    write_json=output_snippets_dir,
                    display=0,
                    render_pose=0,
                    model_pose='COCO')
                command_line = openpose + ' '
                command_line += ' '.join(['--{} {}'.format(k, v) for k, v in openpose_args.items()])
                print(f"Running command: {command_line}")
                shutil.rmtree(output_snippets_dir, ignore_errors=True)
                os.makedirs(output_snippets_dir)
                os.system(command_line)

                # pack openpose ouputs
                # video = utils.video.get_video_frames(outvideo_path)
                video = get_video_frames(outvideo_path)
                height, width, _ = video[0].shape


                # 这里可以修改label, label_index
                video_info = utils.openpose.json_pack(
                    output_snippets_dir, video_name, width, height, labelAction_name , label_no)

                if not os.path.exists(output_sequence_dir):
                    os.makedirs(output_sequence_dir)

                with open(output_sequence_path, 'w') as outfile:
                    json.dump(video_info, outfile)
                if len(video_info['data']) == 0:
                    print('{} Can not find pose estimation results.'.format(file_name))
                    return
                else:
                    print('{} pose estimation complete.'.format(file_name))

if __name__ == '__main__':
    p=PreProcess()
    p.start()

注意:
如果你的openpose装的是cpu版本,会比较慢,跑很多视频生成json文件的时间就会很长。这个时候晚上如果去睡觉了一定一定一定要设置为从不熄屏休眠!!!从不!!!
不然就会这样:
在这里插入图片描述
经过漫长的等待之后。。。。终于运行完成了,我们可以得到

  1. 缩放过后的视频。

  2. 每一帧的json:在snippets目录下,以视频名称命名的文件下的每一帧json文件:在这里插入图片描述打开是这样的:在这里插入图片描述

  3. 每一个视频的json:在data目录下,以视频名称命名的json文件:打开是这样的:在这里插入图片描述
    到这里,得到了每个视频的json文件,数据集的制作就完成了一半了。

jsons转json

得到了每个视频的json文件之后,需要我们手动将data目录下的json文件分配到train和val下,一般按照9:1划分。放好之后运行下面这段jsons2json.py:

以下是我jsons2json的代码,存于stgcn_data目录下运行:

import json
import os

if __name__ == '__main__':
    train_json_path = './train'
    val_json_path = './val'
    test_json_path = './test'

    output_train_json_path = './train_label.json'
    output_val_json_path = './val_label.json'
    output_test_json_path = './test_label.json'

    train_json_names = os.listdir(train_json_path)
    val_json_names = os.listdir(val_json_path)
    test_json_names = os.listdir(test_json_path)

    train_label_json = dict()
    val_label_json = dict()
    test_label_json = dict()


    for file_name in train_json_names:
        name = file_name.split('.')[0]
        json_file_path = '{}/{}'.format(train_json_path, file_name)
        json_file = json.load(open(json_file_path))

        file_label = dict()
        if len(json_file['data']) == 0:
            file_label['has_skeleton'] = False
        else:
            file_label['has_skeleton'] = True
        file_label['label'] = json_file['label']
        file_label['label_index'] = json_file['label_index']

        train_label_json['{}'.format(name)] = file_label

        print('{} success'.format(file_name))

    with open(output_train_json_path, 'w') as outfile:
        json.dump(train_label_json, outfile)

    for file_name in val_json_names:
        name = file_name.split('.')[0]
        json_file_path = '{}/{}'.format(val_json_path, file_name)
        json_file = json.load(open(json_file_path))

        file_label = dict()
        if len(json_file['data']) == 0:
            file_label['has_skeleton'] = False
        else:
            file_label['has_skeleton'] = True
        file_label['label'] = json_file['label']
        file_label['label_index'] = json_file['label_index']

        val_label_json['{}'.format(name)] = file_label

        print('{} success'.format(file_name))

    with open(output_val_json_path, 'w') as outfile:
        json.dump(val_label_json, outfile)

    for file_name in test_json_names:
        name = file_name.split('.')[0]
        json_file_path = '{}/{}'.format(test_json_path, file_name)
        json_file = json.load(open(json_file_path))

        file_label = dict()
        if len(json_file['data']) == 0:
            file_label['has_skeleton'] = False
        else:
            file_label['has_skeleton'] = True
        file_label['label'] = json_file['label']
        file_label['label_index'] = json_file['label_index']

        test_label_json['{}'.format(name)] = file_label

        print('{} success'.format(file_name))

        with open(output_test_json_path, 'w') as outfile:
            json.dump(test_label_json, outfile)

运行完成后,我们就可以得到总的训练和验证的json文件了,里面包含了所有的视频动作关键点和标签信息。
还差一步就完成啦!
在这里插入图片描述

json转npy&pkl

第三步需要用到官方的工具文件kinetics_gendata.py。
所在的路径应该是:./st-gcn-master/tools/kinetics_gendata.py

这里有两处地方需要修改:

  1. 修改考虑参数
def gendata(
        data_path,
        label_path,
        data_out_path,
        label_out_path,
        num_person_in=3,  #observe the first 5 persons
        #这个参数指定了在每个视频帧中考虑的最大人数。
        #例如,如果设置为5,则脚本会尝试从每个视频帧中获取最多5个人的骨架信息。
        #这并不意味着每个视频帧都会有5个人,而是说脚本最多会处理5个人的数据。
        num_person_out=1,  #then choose 2 persons with the highest score
        #这个参数指定了最终每个序列(或视频)中保留的人数。
        #即使输入中有更多的人员,也只有得分最高的两个人会被选择出来用于训练。
        #这里的“得分”通常指的是骨架检测的置信度分数,即算法对某个点确实是人体某部位的信心程度。
        max_frame=300):

  1. 修改路径位置
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Kinetics-skeleton Data Converter.')
    parser.add_argument(
        '--data_path', default='./st-gcn-master/dataset/stgcn_data')
    parser.add_argument(
        '--out_folder', default='./st-gcn-master/dataset/stgcn_data')
    arg = parser.parse_args()

    part = ['train', 'val']
    for p in part:
        data_path = '{}/{}'.format(arg.data_path, p)
        label_path = '{}/{}_label.json'.format(arg.data_path, p)
        data_out_path = '{}/{}_data.npy'.format(arg.out_folder, p)
        label_out_path = '{}/{}_label.pkl'.format(arg.out_folder, p)

        if not os.path.exists(arg.out_folder):
            os.makedirs(arg.out_folder)
        gendata(data_path, label_path, data_out_path, label_out_path)

ok,检查一下是否得到了npy和pkl文件
在这里插入图片描述

yes,到这里就可以开始训练啦。

训练STGCN

添加图结构

在net/utils/graph.py文件里面get_edge函数中保存的是不同的图结构。

注意这里的默认的layout如果符合自己定义的姿态就不用修改,否则需要自定义一个,本文采用的openpose即默认的openpose的18个关键点,不需要修改。其中num_node为关键点的个数,neighbor_link为关键点连接关系。如果自己的数据集是新定义的姿态点数不为18,在后续转换中可能还有修改需要保持一致。

修改训练参数

  1. 将config/st_gcn/kinetics-skeleton/train.yaml复制一份到根目录,重命名为mytrain.yaml,并修改其中参数。
  2. data_path和label_path修改为之前生成的文件路径;
  3. num_class改为自建数据集的行为类别个数;
  4. layout参数修改为之前添加的layout类别;
  5. strategy设置为spatial;
  6. 修改使用的GPU数量,单个设置device: [0];
  7. optim部分适当调整,base_lr: 0.1是基础学习率,step: [80, 120, 160, 200]:这表明使用了一种学习率衰减策略,num_epoch: 200:指定了整个训练过程将进行多少个周期(epochs)
  8. 不知道是不是我搞错了,居然不会自动保存best?而是每10轮自动保存一次;

以下是我的yaml:

work_dir: ./work_dir/recognition/kinetics_skeleton/ST_GCN

# feeder
feeder: feeder.feeder.Feeder
train_feeder_args:
  random_choose: True
  random_move: True
  window_size: 150 
  data_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/train_data.npy
  label_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/train_label.pkl
test_feeder_args:
  data_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/val_data.npy
  label_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/val_label.pkl

# model
model: net.st_gcn.Model
model_args:
  in_channels: 3
  num_class: 2
  edge_importance_weighting: True
  graph_args:
    layout: 'openpose'
    strategy: 'spatial'

# training
device: [0]
batch_size: 32
test_batch_size: 32

#optim
base_lr: 0.1
step: [80, 120, 160, 200]
num_epoch: 200




开始训练

训练指令(记得先激活环境,或者在pychar的终端运行)
yaml文件可以直接放在根目录下;也可以放在自己喜欢的位置(指令需要加上路径)

python main.py recognition -c mytrain.yaml

在这里插入图片描述
这里我的验证集俩分类各取了3个,总训练集200多个,所以top很高。

测试

可以像训练集那样仿照一个测试集出来。
也可以像我一样直接使用视频进行测试,不过我的代码涉及了一些实际的工业工作内容,无法提供,但可以给出一些思路:

  1. 通过其他关键点网络输出关键点(我用的YoloPose)
  2. 将某自定义时长内的所有关键点拼接在一起重塑为 (N, in_channels, T_in, V_in, M_in) 形状
  3. 通过网络输出获取这个时间段内的行为分类

http://www.kler.cn/a/417724.html

相关文章:

  • 文字加持:让 OpenCV 轻松在图像中插上文字
  • 服务端渲染技术
  • 录音质检,只质检录音,没有显卡的服务器配置分析
  • Hive on Spark优化
  • 扩展域并查集 带权并查集
  • Spring中@Conditional注解详解:条件装配的终极指南
  • 深度学习每周学习总结J7(对ResNeXt-50 算法的思考)
  • 洛谷P1443 马的遍历
  • 电路基础——相量法
  • 积鼎科技携手西北工业大学动力与能源学院共建复杂多相流仿真联合实验室
  • VUE 入门级教程:开启 Vue.js 编程之旅
  • 008静态路由-特定主机路由
  • 矩阵乘法实现填充矩阵F.padding
  • C语言模拟实现简单链表的复盘
  • 【机器学习算法】XGBoost原理
  • 【故障处理系列--业务官网无法打开】
  • 0017. shell命令--tac
  • Milvus的索引类型
  • 【经典论文阅读】Transformer(多头注意力 编码器-解码器)
  • 打造高质量技术文档的关键要素(结合MATLAB)
  • selinux、firewalld
  • JD - HotKey:缓存热 Key 管理的高效解决方案
  • Vue + Vite + Element Plus 与 Django 进行前后端对接
  • 【系统架构设计师】高分论文:论敏捷软件开发方法及其成用
  • TYUT设计模式大题
  • 架构04-透明多级分流系统