
Pyskl Custom Dataset


1. Installing Pyskl

Pyskl repository: https://github.com/kennymckormick/pyskl
You need Anaconda installed first.
Linux installation:

git clone https://github.com/kennymckormick/pyskl.git
cd pyskl
# This command runs well with conda 22.9.0, if you are running an early conda version and got some errors, try to update your conda first
conda env create -f pyskl.yaml
conda activate pyskl
pip install -e .

For Windows installation, see my other blog post, "Installing pyskl on Windows".

After installation succeeds, you can run a quick test:

# Running the demo with STGCN++ trained on NTURGB+D 120 (Joint Modality). The input file is demo/ntu_sample.avi, the output file is demo/demo.mp4
python demo/demo_skeleton.py demo/ntu_sample.avi demo/demo.mp4 --config configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py --checkpoint http://download.openmmlab.com/mmaction/pyskl/ckpt/stgcnpp/stgcnpp_ntu120_xsub_hrnet/j.pth


The output is:
(demo screenshot omitted)

2. Dataset Preparation

For this step I followed another blogger's reference post; following those instructions is enough. The dataset used here is Weizmann, with each clip named <person>_<label>.avi (for example moshe_0.avi) and split into train and test folders.
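
For reference, the scripts below assume a layout like this (a sketch based on the paths used in the code; adjust to your own data):

./data/Weizmann/
├── train/    # training clips, named <person>_<label>.avi
│   └── ...
└── test/     # test clips, same naming
    └── ...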

3. Generating train.json and test.json

import os
import decord
import json


def writeJson(path_train, jsonpath):
    # One record per video: name without extension, integer label parsed from
    # the "<person>_<label>.avi" filename, and the frame range of the clip.
    output_list = []
    trainfile_list = os.listdir(path_train)
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')
        traindit['vid_name'] = train_name.replace('.avi', '')
        traindit['label'] = int(sp[1].replace('.avi', ''))
        traindit['start_frame'] = 0

        video_path = os.path.join(path_train, train_name)
        vid = decord.VideoReader(video_path)
        traindit['end_frame'] = len(vid)  # total frame count of the video
        output_list.append(traindit)
    with open(jsonpath, 'w') as outfile:
        json.dump(output_list, outfile)

Usage:

# The first argument is the dataset path, the second the JSON filename to save
writeJson('./data/Weizmann/test', 'test.json')

The generated content is:
[{"vid_name": "moshe_0", "label": 0, "start_frame": 0, "end_frame": 62}, {"vid_name": "moshe_1", "label": 1, "start_frame": 0, "end_frame": 61}, {"vid_name": "moshe_2", "label": 2, "start_frame": 0, "end_frame": 105}, {"vid_name": "moshe_3", "label": 3, "start_frame": 0, "end_frame": 39}, {"vid_name": "moshe_4", "label": 4, "start_frame": 0, "end_frame": 45}, {"vid_name": "moshe_5", "label": 5, "start_frame": 0, "end_frame": 64}, {"vid_name": "moshe_6", "label": 6, "start_frame": 0, "end_frame": 46}, {"vid_name": "moshe_7", "label": 7, "start_frame": 0, "end_frame": 77}, {"vid_name": "moshe_8", "label": 8, "start_frame": 0, "end_frame": 111}, {"vid_name": "moshe_9", "label": 9, "start_frame": 0, "end_frame": 60}, {"vid_name": "shahar_0", "label": 0, "start_frame": 0, "end_frame": 59}, {"vid_name": "shahar_1", "label": 1, "start_frame": 0, "end_frame": 61}, {"vid_name": "shahar_2", "label": 2, "start_frame": 0, "end_frame": 103}, {"vid_name": "shahar_3", "label": 3, "start_frame": 0, "end_frame": 38}, {"vid_name": "shahar_4", "label": 4, "start_frame": 0, "end_frame": 56}, {"vid_name": "shahar_5", "label": 5, "start_frame": 0, "end_frame": 67}, {"vid_name": "shahar_6", "label": 6, "start_frame": 0, "end_frame": 43}, {"vid_name": "shahar_7", "label": 7, "start_frame": 0, "end_frame": 68}, {"vid_name": "shahar_8", "label": 8, "start_frame": 0, "end_frame": 120}, {"vid_name": "shahar_9", "label": 9, "start_frame": 0, "end_frame": 61}]

4. Generating the list file needed by tools/data/custom_2d_skeleton.py

The code is as follows:

import os

from pyskl.smp import mwlines


def writeList(dirpath, name):
    # Collect "<path> <label>" lines for every train and test clip and write
    # them to a single list file with pyskl's mwlines helper.
    path_train = os.path.join(dirpath, 'train')
    path_test = os.path.join(dirpath, 'test')
    trainfile_list = os.listdir(path_train)
    testfile_list = os.listdir(path_test)

    train = []
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')
        traindit['vid_name'] = train_name
        traindit['label'] = sp[1].replace('.avi', '')
        train.append(traindit)
    test = []
    for test_name in testfile_list:
        testdit = {}
        sp = test_name.split('_')
        testdit['vid_name'] = test_name
        testdit['label'] = sp[1].replace('.avi', '')
        test.append(testdit)

    tmpl1 = os.path.join(path_train, '{}')
    lines1 = [(tmpl1 + ' {}').format(x['vid_name'], x['label']) for x in train]

    tmpl2 = os.path.join(path_test, '{}')
    lines2 = [(tmpl2 + ' {}').format(x['vid_name'], x['label']) for x in test]
    lines = lines1 + lines2
    mwlines(lines, os.path.join(dirpath, name))

dirpath is the dataset path, here dirpath = './data/Weizmann'.
name is the filename of the generated list file, here 'Weizmann.list'.

writeList('./data/Weizmann', 'Weizmann.list')
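
Each line of the generated Weizmann.list has the form "<video path> <label>", one line per clip, train entries first. Illustrative lines (<person> stands for the actual performer name):

./data/Weizmann/train/<person>_0.avi 0
./data/Weizmann/train/<person>_1.avi 1
...
./data/Weizmann/test/moshe_0.avi 0
...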

5. Running custom_2d_skeleton.py to generate the pkl file used for training

I directly used the code that the blogger 大脸猫105 had already adapted, with only minor changes of my own; if there is a problem with my version, please go back to 大脸猫105's original post.

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
# import pdb
import pyskl
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model
import decord
import mmcv
import numpy as np
# import torch.distributed as dist
from tqdm import tqdm
import mmdet
# import mmpose
# from pyskl.smp import mrlines
import cv2

from pyskl.smp import mrlines


def extract_frame(video_path):
    vid = decord.VideoReader(video_path)
    return [x.asnumpy() for x in vid]


def detection_inference(model, frames):
    model = model.cuda()
    results = []
    for frame in frames:
        result = inference_detector(model, frame)
        results.append(result)
    return results


def pose_inference(model, frames, det_results):
    model = model.cuda()
    assert len(frames) == len(det_results)
    total_frames = len(frames)
    num_person = max([len(x) for x in det_results])
    kp = np.zeros((num_person, total_frames, 17, 3), dtype=np.float32)

    for i, (f, d) in enumerate(zip(frames, det_results)):
        # Align input format
        d = [dict(bbox=x) for x in list(d)]
        pose = inference_top_down_pose_model(model, f, d, format='xyxy')[0]
        for j, item in enumerate(pose):
            kp[j, i] = item['keypoints']
    return kp

pyskl_root = osp.dirname(pyskl.__path__[0])
default_det_config = f'{pyskl_root}/demo/faster_rcnn_r50_fpn_1x_coco-person.py'
default_det_ckpt = (
    'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/'
    'faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth')
default_pose_config = f'{pyskl_root}/demo/hrnet_w32_coco_256x192.py'
default_pose_ckpt = (
    'https://download.openmmlab.com/mmpose/top_down/hrnet/'
    'hrnet_w32_coco_256x192-c78dce93_20200708.pth')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate 2D pose annotations for a custom video dataset')
    # * Both mmdet and mmpose should be installed from source
    # parser.add_argument('--mmdet-root', type=str, default=default_mmdet_root)
    # parser.add_argument('--mmpose-root', type=str, default=default_mmpose_root)

    # parser.add_argument('--det-config', type=str, default='../refe/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco-person.py')
    # parser.add_argument('--det-ckpt', type=str,
    #                     default='../refe/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth')
    parser.add_argument(
        '--det-config',
        # default='../refe/faster_rcnn_r50_fpn_2x_coco.py',
        default=default_det_config,
        help='human detection config file path (from mmdet)')

    parser.add_argument(
        '--det-ckpt',
        default=default_det_ckpt,
        help='human detection checkpoint file/url')

    parser.add_argument('--pose-config', type=str, default=default_pose_config)
    parser.add_argument('--pose-ckpt', type=str, default=default_pose_ckpt)
    # * Only det boxes with score larger than det_score_thr will be kept
    parser.add_argument('--det-score-thr', type=float, default=0.7)
    # * Only det boxes with large enough sizes will be kept,
    parser.add_argument('--det-area-thr', type=float, default=1300)
    # * Accepted formats for each line in video_list are:
    # * 1. "xxx.mp4" ('label' is missing, the dataset can be used for inference, but not training)
    # * 2. "xxx.mp4 label" ('label' is an integer (category index),
    # * the result can be used for both training & testing)
    # * All lines should take the same format.
    parser.add_argument('--video-list', type=str, help='the list of source videos')
    # * out should ends with '.pkl'
    parser.add_argument('--out', type=str, help='output pickle name')
    parser.add_argument('--tmpdir', type=str, default='tmp')
    parser.add_argument('--local_rank', type=int, default=1)
    # pdb.set_trace()

    # if 'RANK' not in os.environ:
    #     os.environ['RANK'] = str(args.local_rank)
    #     os.environ['WORLD_SIZE'] = str(1)
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12345'

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out.endswith('.pkl')

    lines = mrlines(args.video_list)
    lines = [x.split() for x in lines]

    assert len(lines[0]) in [1, 2]
    if len(lines[0]) == 1:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0]) for x in lines]
    else:
        annos = [dict(frame_dir=osp.basename(x[0]).split('.')[0], filename=x[0], label=int(x[1])) for x in lines]

    rank = 0  # added: run as a single process instead of distributed
    world_size = 1  # added
    os.makedirs(args.tmpdir, exist_ok=True)  # added: the distributed version created tmpdir under rank 0 (see commented code below)

    # init_dist('pytorch', backend='nccl')
    # rank, world_size = get_dist_info()
    #
    # if rank == 0:
    #     os.makedirs(args.tmpdir, exist_ok=True)
    # dist.barrier()
    my_part = annos
    # my_part = annos[rank::world_size]
    print("from det_model")
    det_model = init_detector(args.det_config, args.det_ckpt, 'cuda')
    assert det_model.CLASSES[0] == 'person', 'A detector trained on COCO is required'
    print("from pose_model")
    pose_model = init_pose_model(args.pose_config, args.pose_ckpt, 'cuda')
    n = 0
    for anno in tqdm(my_part):
        frames = extract_frame(anno['filename'])
        print("anno['filename", anno['filename'])
        det_results = detection_inference(det_model, frames)
        # * Get detection results for human
        det_results = [x[0] for x in det_results]
        for i, res in enumerate(det_results):
            # * filter boxes with small scores
            res = res[res[:, 4] >= args.det_score_thr]
            # * filter boxes with small areas
            box_areas = (res[:, 3] - res[:, 1]) * (res[:, 2] - res[:, 0])
            assert np.all(box_areas >= 0)
            res = res[box_areas >= args.det_area_thr]
            det_results[i] = res

        pose_results = pose_inference(pose_model, frames, det_results)
        shape = frames[0].shape[:2]
        anno['img_shape'] = anno['original_shape'] = shape
        anno['total_frames'] = len(frames)
        anno['num_person_raw'] = pose_results.shape[0]
        anno['keypoint'] = pose_results[..., :2].astype(np.float16)
        anno['keypoint_score'] = pose_results[..., 2].astype(np.float16)
        anno.pop('filename')

    mmcv.dump(my_part, osp.join(args.tmpdir, f'part_{rank}.pkl'))
    # dist.barrier()

    if rank == 0:
        parts = [mmcv.load(osp.join(args.tmpdir, f'part_{i}.pkl')) for i in range(world_size)]
        rem = len(annos) % world_size
        if rem:
            for i in range(rem, world_size):
                parts[i].append(None)

        ordered_results = []
        for res in zip(*parts):
            ordered_results.extend(list(res))
        ordered_results = ordered_results[:len(annos)]
        mmcv.dump(ordered_results, args.out)


if __name__ == '__main__':
    # default_mmdet_root = osp.dirname(mmcv.__path__[0])
    # default_mmpose_root = osp.dirname(mmcv.__path__[0])
    main()


Run the following command to extract the 2D skeletons:

python tools/data/custom_2d_skeleton.py --video-list  ./data/Weizmann/Weizmann.list --out  ./data/Weizmann/train.pkl

The generation process shows a tqdm progress bar per video:
(screenshot omitted)
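
Before training, it can help to sanity-check the generated pkl. A minimal sketch; the key names follow the fields set in main() above:

import mmcv

annos = mmcv.load('./data/Weizmann/train.pkl')
print(len(annos))                      # one entry per video
sample = annos[0]
print(sample['frame_dir'], sample['label'])
print(sample['keypoint'].shape)        # (num_person, total_frames, 17, 2)
print(sample['keypoint_score'].shape)  # (num_person, total_frames, 17)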

6. Training the Model

Using the train.pkl generated above together with train.json and test.json, build the pkl file that training will use.

import os

from mmcv import load, dump


def traintest(dirpath, pklname, newpklname):
    # Combine the skeleton annotations with the train/val split read from
    # train.json and test.json, in the format pyskl's PoseDataset expects.
    os.chdir(dirpath)
    train = load('train.json')
    test = load('test.json')
    annotations = load(pklname)
    split = dict()
    split['xsub_train'] = [x['vid_name'] for x in train]
    split['xsub_val'] = [x['vid_name'] for x in test]
    dump(dict(split=split, annotations=annotations), newpklname)
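
Usage is as follows; the output filename must match the ann_file path used in the training config below:

traintest('./data/Weizmann', 'train.pkl', 'wei_xsub_stgn++_ch.pkl')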

Pick the model to use. I chose STGCN++ and used configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py.

A few places in it need to be modified:

# Change num_classes to the number of classes in your own dataset (10 here)
model = dict(
    type='RecognizerGCN',
    backbone=dict(
        type='STGCN',
        gcn_adaptive='init',
        gcn_with_res=True,
        tcn_type='mstcn',
        graph_cfg=dict(layout='coco', mode='spatial')),
    cls_head=dict(type='GCNHead', num_classes=10, in_channels=256))
 
dataset_type = 'PoseDataset'
# Change ann_file to the path of the pkl file generated above
ann_file = './data/Weizmann/wei_xsub_stgn++_ch.pkl'
# In train_pipeline, val_pipeline and test_pipeline below, num_person can be set
# to 1: it is the maximum number of persons FormatGCNInput keeps per sample,
# and each Weizmann clip contains a single actor
train_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
val_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100, num_clips=1, test_mode=True),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
test_pipeline = [
    dict(type='PreNormalize2D'),
    dict(type='GenSkeFeat', dataset='coco', feats=['j']),
    dict(type='UniformSample', clip_len=100, num_clips=10, test_mode=True),
    dict(type='PoseDecode'),
    dict(type='FormatGCNInput', num_person=1),
    dict(type='Collect', keys=['keypoint', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['keypoint'])
]
# split='xsub_train' and split='xsub_val' can be renamed to whatever keys you
# used when writing the pkl, but they must match the keys stored in wei_xsub_stgn++_ch.pkl
data = dict(
    videos_per_gpu=16,
    workers_per_gpu=2,
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type='RepeatDataset',
        times=5,
        dataset=dict(type=dataset_type, ann_file=ann_file, pipeline=train_pipeline, split='xsub_train')),
        
    val=dict(type=dataset_type, ann_file=ann_file, pipeline=val_pipeline, split='xsub_val'),
    test=dict(type=dataset_type, ann_file=ann_file, pipeline=test_pipeline, split='xsub_val'))
    
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0, by_epoch=False)
# total_epochs sets the number of training epochs
total_epochs = 100
checkpoint_config = dict(interval=1)
evaluation = dict(interval=1, metrics=['top_k_accuracy'])
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
 
# runtime settings
log_level = 'INFO'
# work_dir is where the training result files are saved; change it as you like
work_dir = './work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei5'
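
A quick way to confirm that the split keys stored in the pkl match the config (a minimal check sketch):

from mmcv import load

data = load('./data/Weizmann/wei_xsub_stgn++_ch.pkl')
print(data['split'].keys())      # should contain 'xsub_train' and 'xsub_val'
print(len(data['annotations']))  # total number of annotated videos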

Run the training command (the last argument, 1, is the number of GPUs):

bash tools/dist_train.sh configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py 1 --validate --test-last --test-best

For other training and testing operations, see the project README.md.
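
For example, testing a trained checkpoint follows the same pattern as training (command form per the pyskl README; the checkpoint path is the one produced by the run above):

bash tools/dist_test.sh configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py ./work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei5/best_top1_acc_epoch_11.pth 1 --eval top_k_accuracy --out result.pkl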

Finally

All of the files obtained in the end are shown below (trained weights not included):
(screenshot omitted)
I put all of the functions used above into a single util.py, so they can be called directly.

import os
import decord
import json
from mmcv import load, dump

from pyskl.smp import mwlines


def writeJson(path_train, jsonpath):
    output_list = []
    trainfile_list = os.listdir(path_train)
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')
        traindit['vid_name'] = train_name.replace('.avi', '')
        traindit['label'] = int(sp[1].replace('.avi', ''))
        traindit['start_frame'] = 0

        video_path = os.path.join(path_train, train_name)
        vid = decord.VideoReader(video_path)
        traindit['end_frame'] = len(vid)
        output_list.append(traindit)
    with open(jsonpath, 'w') as outfile:
        json.dump(output_list, outfile)


def writeList(dirpath, name):
    path_train = os.path.join(dirpath, 'train')
    path_test = os.path.join(dirpath, 'test')
    trainfile_list = os.listdir(path_train)
    testfile_list = os.listdir(path_test)

    train = []
    for train_name in trainfile_list:
        traindit = {}
        sp = train_name.split('_')

        traindit['vid_name'] = train_name
        traindit['label'] = sp[1].replace('.avi', '')
        train.append(traindit)
    test = []
    for test_name in testfile_list:
        testdit = {}
        sp = test_name.split('_')
        testdit['vid_name'] = test_name
        testdit['label'] = sp[1].replace('.avi', '')
        test.append(testdit)

    tmpl1 = os.path.join(path_train, '{}')
    lines1 = [(tmpl1 + ' {}').format(x['vid_name'], x['label']) for x in train]

    tmpl2 = os.path.join(path_test, '{}')
    lines2 = [(tmpl2 + ' {}').format(x['vid_name'], x['label']) for x in test]
    lines = lines1 + lines2
    mwlines(lines, os.path.join(dirpath, name))


def traintest(dirpath, pklname, newpklname):
    os.chdir(dirpath)
    train = load('train.json')
    test = load('test.json')
    annotations = load(pklname)
    split = dict()
    split['xsub_train'] = [x['vid_name'] for x in train]
    split['xsub_val'] = [x['vid_name'] for x in test]
    dump(dict(split=split, annotations=annotations), newpklname)


if __name__ == '__main__':

    dirpath = './data/Weizmann'
    pklname = 'train.pkl'
    newpklname = 'wei_xsub_stgn++_ch.pkl'  # must match ann_file in the training config
    # writeJson('./data/Weizmann/test', 'test.json')
    traintest(dirpath, pklname, newpklname)
    # writeList('./data/Weizmann', 'Weizmann.list')

Generating a demo with your own trained model
You need to create a label-map file for the dataset under ./tools/data/label_map, with class names listed in ascending label order; otherwise the labels drawn onto the output video will be wrong.
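
For example, a Weizmann.txt would contain one class name per line, where line i is the name for label i. The ten Weizmann actions are bend, jack, jump, pjump, run, side, skip, walk, wave1 and wave2, but the order below is only an assumption; it has to match the label indices encoded in your filenames:

bend
jack
jump
pjump
run
side
skip
walk
wave1
wave2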

python demo/demo_skeleton.py video/shahar_1.avi video/shahar_1_demo.mp4 \
    --config ./configs/stgcn++/stgcn++_ntu120_xsub_hrnet/j.py \
    --checkpoint ./work_dirs/stgcn++/stgcn++_ntu120_xsub_hrnet/j_Wei5/best_top1_acc_epoch_11.pth \
    --label-map ./tools/data/label_map/Weizmann.txt

The output is as follows:

(shahar_1_demo video omitted)
