
Training your own dataset with pyskl's stgcn++

pyskl

https://github.com/kennymckormick/pyskl — the repo contains models for many kinds of action recognition; many thanks to the author.

The training process mainly follows the example in the project:

examples/extract_diving48_skeleton/diving48_example.ipynb

I didn't know about this file at first and couldn't find much information online, so I took quite a few detours. Here I share my training process.

1. Prepare your own dataset

Here I use the Weizmann dataset, which has 10 classes and roughly 10 videos per class.

Split the data into a training set and a test set (train and test subdirectories). It is best to name the videos like 'videoname_label.mp4' — the point is that the file name contains the class name or class index, which makes later processing easier.

My video names look like daria_0.avi; I renamed the original videos.

Define the class labels as below. The class indices start from 0 and must be contiguous, otherwise training will fail later.

{'bend': '1', 'jack': '2', 'jump': '3', 'pjump': '4','run':'5','side':'6','skip':'7','walk':'8','wave1':'9','wave2':'0'}
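Below is a minimal renaming sketch based on this mapping (it assumes the original Weizmann files are named like daria_bend.avi, i.e. person_action.avi; the helper name is made up for illustration):

import os

# class name -> class index, as defined above
label_map = {'bend': '1', 'jack': '2', 'jump': '3', 'pjump': '4', 'run': '5',
             'side': '6', 'skip': '7', 'walk': '8', 'wave1': '9', 'wave2': '0'}

def rename_videos(video_dir):
    # e.g. 'daria_bend.avi' -> 'daria_1.avi'
    for fname in os.listdir(video_dir):
        if not fname.endswith('.avi'):
            continue
        person, action = fname.replace('.avi', '').split('_', 1)
        os.rename(os.path.join(video_dir, fname),
                  os.path.join(video_dir, f'{person}_{label_map[action]}.avi'))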

2. Generate train.json and test.json with the code below

You don't have to generate them exactly this way, but the contents of these json files will be needed later.

import os
import json

import decord


def writeJson(path_train, jsonpath):
    output_list = []
    trainfile_list = os.listdir(path_train)
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')
        traindit['vid_name'] = train_name.replace('.avi', '')
        traindit['label'] = int(sp[1].replace('.avi', ''))
        traindit['start_frame'] = 0
        video_path = os.path.join(path_train, train_name)
        vid = decord.VideoReader(video_path)
        traindit['end_frame'] = len(vid)
        output_list.append(traindit.copy())
    with open(jsonpath, 'w') as outfile:
        json.dump(output_list, outfile)
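For example, called once per split (the paths follow the directory layout described above):

writeJson('../data/Weizmann/train', '../data/Weizmann/train.json')
writeJson('../data/Weizmann/test', '../data/Weizmann/test.json')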

The generated json looks like the following: vid_name is the video file name without the extension, label is the class index defined above, start_frame is 0, and end_frame is the total number of frames in the video.

[
    {
        "vid_name": "lyova_3",
        "label": 3,
        "start_frame": 0,
        "end_frame": 40
    },
]

3. Generate the list file required by tools/data/custom_2d_skeleton.py

This Weizmann.list file contains both the training and test videos, one per line, in the format:

video path + a space + class index

../data/Weizmann/train/lyova_3.avi 3
../data/Weizmann/train/ira_1.avi 1

The code to generate the Weizmann.list file is as follows:

import os

from pyskl.smp import mwlines


def writeList(dirpath, name):
    path_train = os.path.join(dirpath, 'train')
    path_test = os.path.join(dirpath, 'test')
    trainfile_list = os.listdir(path_train)
    testfile_list = os.listdir(path_test)
    train = []
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')
        traindit['vid_name'] = train_name
        traindit['label'] = sp[1].replace('.avi', '')
        train.append(traindit)
    test = []
    for test_name in testfile_list:
        testdit = {}
        sp = test_name.split('_')
        testdit['vid_name'] = test_name
        testdit['label'] = sp[1].replace('.avi', '')
        test.append(testdit)
    tmpl1 = os.path.join(path_train, '{}')
    lines1 = [(tmpl1 + ' {}').format(x['vid_name'], x['label']) for x in train]
    tmpl2 = os.path.join(path_test, '{}')
    lines2 = [(tmpl2 + ' {}').format(x['vid_name'], x['label']) for x in test]
    lines = lines1 + lines2
    mwlines(lines, os.path.join(dirpath, name))

The arguments passed to the function:

dirpath is the dataset path, here dirpath = '../data/Weizmann'

name is the name of the generated list file, here 'Weizmann.list'
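So the call matching the values above would be:

writeList('../data/Weizmann', 'Weizmann.list')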

4. Run custom_2d_skeleton.py to generate the pkl file needed for training

Next, run custom_2d_skeleton.py. I followed another blogger's post, 基于pyskl的poseC3D训练自己的数据集 (骑走的小木马, CSDN blog), and modified the code of custom_2d_skeleton.py accordingly.

The models I used are a human detection model and a keypoint detection model (see the argparse defaults below); both can be found in mmdetection and mmpose, and you can swap in your own.

One more hiccup: for some reason the checkpoint below would not work even after I downloaded it; it kept raising errors, so in the end I switched to loading it from the URL.

faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth

(If you download the file and then get a "cannot find checkpoint" error at runtime, try both ways: either keep a local copy and set default to the local path, or load it directly over the network and set default to the URL.)

parser.add_argument(
    '--det-config',
    default='../refe/faster_rcnn_r50_fpn_2x_coco.py',
    help='human detection config file path (from mmdet)')
parser.add_argument(
    '--det-ckpt',
    default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
             'faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_'
             'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
    help='human detection checkpoint file/url')
parser.add_argument('--pose-config', type=str, default='../refe/hrnet_w32_coco_256x192.py')
parser.add_argument('--pose-ckpt', type=str, default='../refe/hrnet_w32_coco_256x192-c78dce93_20200708.pth')
# * Only det boxes with score larger than det_score_thr will be kept
parser.add_argument('--det-score-thr', type=float, default=0.7)
# * Only det boxes with large enough sizes will be kept
parser.add_argument('--det-area-thr', type=float, default=1300)

Files that the script would normally download over the network I downloaded in advance and placed under the refe folder.

In custom_2d_skeleton.py I also found that with the imports written as below, the program would hang as soon as it started. I never found the cause and spent a long time on this spot; in my modified version I keep only the from mmdet.apis import.

import mmdet
from mmdet.apis import inference_detector, init_detector

Below is my modified custom_2d_skeleton.py:

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
# import pdb
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model
import decord
import mmcv
import numpy as np
# import torch.distributed as dist
from tqdm import tqdm
# import mmdet
# import mmpose
import cv2
from pyskl.smp import mrlines


def extract_frame(video_path):
    # decode all frames of a video into numpy arrays
    vid = decord.VideoReader(video_path)
    return [x.asnumpy() for x in vid]


def detection_inference(model, frames):
    # run the human detector frame by frame
    model = model.cuda()
    results = []
    for frame in frames:
        result = inference_detector(model, frame)
        results.append(result)
    return results


def pose_inference(model, frames, det_results):
    # run top-down pose estimation on the detected person boxes
    model = model.cuda()
    assert len(frames) == len(det_results)
    total_frames = len(frames)
    num_person = max([len(x) for x in det_results])
    kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)

    for i, (f, d) in enumerate(zip(frames, det_results)):
        # Align input format
        d = [dict(bbox=x) for x in list(d)]
        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
        for j, item in enumerate(pose):
            kp[j, i] = item['keypoints']
    return kp


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate 2D pose annotations for a custom video dataset')
    # * Both mmdet and mmpose should be installed from source
    # parser.add_argument('--mmdet-root', type=str, default=default_mmdet_root)
    # parser.add_argument('--mmpose-root', type=str, default=default_mmpose_root)
    # parser.add_argument('--det-config', type=str, default='../refe/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco-person.py')
    # parser.add_argument('--det-ckpt', type=str,
    #                     default='../refe/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth')
    parser.add_argument(
        '--det-config',
        default='../refe/faster_rcnn_r50_fpn_2x_coco.py',
        help='human detection config file path (from mmdet)')
    parser.add_argument(
        '--det-ckpt',
        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
                 'faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')
    parser.add_argument('--pose-config', type=str, default='../refe/hrnet_w32_coco_256x192.py')
    parser.add_argument('--pose-ckpt', type=str, default='../refe/hrnet_w32_coco_256x192-c78dce93_20200708.pth')
    # * Only det boxes with score larger than det_score_thr will be kept
    parser.add_argument('--det-score-thr', type=float, default=0.7)
    # * Only det boxes with large enough sizes will be kept
    parser.add_argument('--det-area-thr', type=float, default=1300)
    # * Accepted formats for each line in video_list are:
    # * 1. "xxx.mp4" ('label' is missing, the dataset can be used for inference, but not training)
    # * 2. "xxx.mp4 label" ('label' is an integer (category index),
    # *    the result can be used for both training & testing)
    # * All lines should take the same format.
    parser.add_argument('--video-list', type=str, help='the list of source videos')
    # * out should end with '.pkl'
    parser.add_argument('--out', type=str, help='output pickle name')
    parser.add_argument('--tmpdir', type=str, default='tmp')
    parser.add_argument('--local_rank', type=int, default=1)
    # pdb.set_trace()
    # if 'RANK' not in os.environ:
    #     os.environ['RANK'] = str(args.local_rank)
    #     os.environ['WORLD_SIZE'] = str(1)
    #     os.environ['MASTER_ADDR'] = 'localhost'
    #     os.environ['MASTER_PORT'] = '12345'
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out.endswith('.pkl')

    lines = mrlines(args.video_list)
    lines = [x.split() for x in lines]

    assert len(lines[0]) in [1, 2]
    if len(lines[0]) == 1:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0]) for x in lines]
    else:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0], label=int(x[1])) for x in lines]

    rank = 0        # added: run in single-process mode instead of torch.distributed
    world_size = 1  # added
    # init_dist('pytorch', backend='nccl')
    # rank, world_size = get_dist_info()
    # dist.barrier()
    os.makedirs(args.tmpdir, exist_ok=True)

    my_part = annos
    # my_part = annos[rank::world_size]
    print("from det_model")
    det_model = init_detector(args.det_config, args.det_ckpt, 'cuda')
    assert det_model.CLASSES[0] == 'person', 'A detector trained on COCO is required'
    print("from pose_model")
    pose_model = init_pose_model(args.pose_config, args.pose_ckpt, 'cuda')

    for anno in tqdm(my_part):
        frames = extract_frame(anno['filename'])
        print("anno['filename']:", anno['filename'])
        det_results = detection_inference(det_model, frames)
        # * Get detection results for human
        det_results = [x[0] for x in det_results]
        for i, res in enumerate(det_results):
            # * filter boxes with small scores
            res = res[res[:, 4] >= args.det_score_thr]
            # * filter boxes with small areas
            box_areas = (res[:, 3] - res[:, 1]) * (res[:, 2] - res[:, 0])
            assert np.all(box_areas >= 0)
            res = res[box_areas >= args.det_area_thr]
            det_results[i] = res

        pose_results = pose_inference(pose_model, frames, det_results)
        shape = frames[0].shape[:2]
        anno['img_shape'] = anno['original_shape'] = shape
        anno['total_frames'] = len(frames)
        anno['num_person_raw'] = pose_results.shape[0]
        anno['keypoint'] = pose_results[..., :2].astype(np.float16)
        anno['keypoint_score'] = pose_results[..., 2].astype(np.float16)
        anno.pop('filename')

    mmcv.dump(my_part, osp.join(args.tmpdir, f'part_{rank}.pkl'))
    # dist.barrier()
    if rank == 0:
        parts = [mmcv.load(osp.join(args.tmpdir, f'part_{i}.pkl')) for i in range(world_size)]
        rem = len(annos) % world_size
        if rem:
            for i in range(rem, world_size):
                parts[i].append(None)

        ordered_results = []
        for res in zip(*parts):
            ordered_results.extend(list(res))
        ordered_results = ordered_results[:len(annos)]
        mmcv.dump(ordered_results, args.out)


if __name__ == '__main__':
    # default_mmdet_root = osp.dirname(mmcv.__path__[0])
    # default_mmpose_root = osp.dirname(mmcv.__path__[0])
    main()

Then run the following command:

python tools/data/custom_2d_skeleton.py --video-list  ../data/Weizmann/Weizmann.list --out  ../data/Weizmann/train.pkl
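To sanity-check the output, you can load the pkl and look at one entry (a small sketch; mmcv.load is the counterpart of the mmcv.dump call in the script above):

import mmcv

annos = mmcv.load('../data/Weizmann/train.pkl')
print(len(annos))                                  # number of videos
print(annos[0]['frame_dir'], annos[0]['label'])    # video name and class index
print(annos[0]['keypoint'].shape)                  # (num_person, num_frames, 17, 2)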

5. Train the model

Using the train.pkl generated above together with train.json and test.json, build the final pkl file used for training, where

dirpath = '../data/Weizmann'
pklname = 'train.pkl'
newpklname = 'Wei_xsub_stgn++.pkl'
import os

from mmcv import load, dump


def traintest(dirpath, pklname, newpklname):
    os.chdir(dirpath)
    train = load('train.json')
    test = load('test.json')
    annotations = load(pklname)
    split = dict()
    split['xsub_train'] = [x['vid_name'] for x in train]
    split['xsub_val'] = [x['vid_name'] for x in test]
    dump(dict(split=split, annotations=annotations), newpklname)
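Called with the values above:

traintest('../data/Weizmann', 'train.pkl', 'Wei_xsub_stgn++.pkl')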

Pick the model you want to use. I chose stgcn++ and used configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py.

A few places in it need to be modified:

# num_classes=10: change to the number of classes in your own dataset
model = dict(
    type='RecognizerGCN',
    backbone=dict(
        type='STGCN',
        gcn_adaptive='init',
        gcn_with_res=True,
        tcn_type='mstcn',
        graph_cfg=dict(layout='coco', mode='spatial')),
    cls_head=dict(type='GCNHead', num_classes=10, in_channels=256))

dataset_type = 'PoseDataset'
# ann_file: change to the path of the pkl file generated above
ann_file = './data/Weizmann/wei_xsub_stgn++_ch.pkl'
# In train_pipeline, val_pipeline and test_pipeline below, num_person can be set to 1;
# my guess is that it is the number of people in the video, but I have no proof.
train_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
val_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100, num_clips=1, test_mode=True),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
test_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100, num_clips=10, test_mode=True),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
# split='xsub_train' / split='xsub_val' can be renamed to whatever keys you wrote,
# but they must match the split keys stored in wei_xsub_stgn++_ch.pkl
data = dict(
    videos_per_gpu=16,
    workers_per_gpu=2,
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type='RepeatDataset',
        times=5,
        dataset=dict(type=dataset_type, ann_file=ann_file, pipeline=train_pipeline, split='xsub_train')),
    val=dict(type=dataset_type, ann_file=ann_file, pipeline=val_pipeline, split='xsub_val'),
    test=dict(type=dataset_type, ann_file=ann_file, pipeline=test_pipeline, split='xsub_val'))

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0, by_epoch=False)
# total_epochs: the number of training epochs, adjust as needed
total_epochs = 100
checkpoint_config = dict(interval=1)
evaluation = dict(interval=1, metrics=['top_k_accuracy'])
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
# runtime settings
log_level = 'INFO'
# work_dir: where the training outputs are saved, change as you like
work_dir = './work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei5'

Then run:

bash tools/dist_train.sh configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py 1 --validate --test-last --test-best

The best result I got from training was:

2022-07-29 11:02:37,424 - pyskl - INFO - Testing results of the best checkpoint
2022-07-29 11:02:37,424 - pyskl - INFO - top1_acc: 0.9000
2022-07-29 11:02:37,424 - pyskl - INFO - top5_acc: 1.0000

6. Testing

Note that the pth file should be the best checkpoint from training; test-res.json contains the per-class probabilities predicted for each video.

bash tools/dist_test.sh configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei4/best_top1_acc_epoch_39.pth 1 --out data/Weizmann/test-res.json --eval top_k_accuracy mean_class_accuracy
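To inspect the predictions, something like the following works (a rough sketch, assuming test-res.json holds one list of 10 class scores per video, in the same order as the xsub_val split):

import json
import numpy as np

with open('data/Weizmann/test-res.json') as f:
    scores = json.load(f)                    # list of per-video class score lists

pred = [int(np.argmax(s)) for s in scores]   # predicted class index per video
print(pred[:5])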

When running your own trained model (e.g. with demo/demo_skeleton.py), you first need to create a label file for your dataset under the ../tools/data/label_map folder, with the class names listed in ascending order of class index; otherwise the label drawn on the output video will be wrong.
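Based on the label mapping defined in step 1, tools/data/label_map/Weizmann.txt would look something like this (one class name per line, line i holding the name of class i):

wave2
bend
jack
jump
pjump
run
side
skip
walk
wave1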

python demo/demo_skeleton.py video/shahar_1.avi res/shahar_1_res.mp4 \
    --config ../configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py \
    --checkpoint ../work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei4/best_top1_acc_epoch_39.pth \
    --label-map ../tools/data/label_map/Weizmann.txt

I also trained on the KTH dataset and got 0.9167, which is not bad either.

Finally

stgcn++ gives only one action label per video. If you want to recognize multiple actions in one video, you need to split the video into segments, for example treat every 200 frames as one segment and feed each segment to the model to get a result (see the small sketch below). Such hard splitting hurts recognition accuracy. Multi-person action recognition is also possible; you just need to change the pose estimation and tracking part, which is purely a data-processing issue, so I won't go into it.
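As a rough illustration of the hard-splitting idea (the 200-frame segment length comes from the paragraph above; the chunking is plain index arithmetic, not a pyskl API):

def split_into_segments(total_frames, seg_len=200):
    # return (start, end) frame ranges covering the whole video
    return [(s, min(s + seg_len, total_frames)) for s in range(0, total_frames, seg_len)]

print(split_into_segments(520))   # [(0, 200), (200, 400), (400, 520)]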

When I trained the model on my own dataset the accuracy was very high; looking back, it was probably overfitting. There are many ways to deal with overfitting, but mine was only a demo, so I didn't go further.

Also, treat this post as a reference only; at the time I only built a demo to learn how to train a model on my own data. Feel free to discuss in the comments; I will reply when I see them.

But asking me for the source code won't get you far. It has been almost three months since I first wrote this post, and I'm not going to dig up the project code for you. pyskl is an open-source project anyway, and the process above is written out fairly completely; if you run into other problems, search a bit more and read other people's posts.
