
Training TSM action recognition on your own dataset


1. Prepare the data

The dataset is organized in a format similar to UCF101.
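The scripts below assume one sub-folder per action class under videos/, with the raw clips inside. For illustration, using the two example classes that also appear in the demo later (your own class and file names will differ):

videos/
    Rowing/
        v_Rowing_g01_c01.avi
        ...
    YoYo/
        v_YoYo_g08_c01.avi
        ...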

1.1 Generate classInd.txt, trainlist.txt and testlist.txt

# makelabel.py
import os

# root folder of the video dataset (one sub-folder per class)
baseDir = r"D:\Downloads\PaddleVideo-develop\posvideo0/videos"
# output folder for the annotation files
targetDir = r"D:\Downloads\PaddleVideo-develop\posvideo0/annotations"
if not os.path.exists(targetDir):
    os.makedirs(targetDir)

# classInd.txt: one "<index> <class name>" line per class
labels = os.listdir(baseDir)
with open(os.path.join(targetDir, "classInd.txt"), "w+") as f:
    for i in range(len(labels)):
        line = "{} {}\n".format(i + 1, labels[i])
        f.write(line)

# map class name -> label index
labels = dict([(labels[i], i + 1) for i in range(len(labels))])
print(labels)

# trainlist.txt: "<class>/<video> <label>"
with open(os.path.join(targetDir, "trainlist.txt"), "w+") as f:
    for label in labels:
        index = os.listdir(os.path.join(baseDir, label))
        for i in index:
            line = "{}/{} {}\n".format(label, i, labels[label])
            f.write(line)

# testlist.txt: "<class>/<video>"
with open(os.path.join(targetDir, "testlist.txt"), "w+") as f:
    for label in labels:
        index = os.listdir(os.path.join(baseDir, label))
        for i in index:
            line = "{}/{}\n".format(label, i)
            f.write(line)
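With the example layout above, the generated annotation files look roughly as follows (class names taken from the example). Note that this simple script writes every video into both trainlist.txt and testlist.txt; split them yourself if you want a held-out test set.

classInd.txt:
    1 Rowing
    2 YoYo

trainlist.txt:
    Rowing/v_Rowing_g01_c01.avi 1
    YoYo/v_YoYo_g08_c01.avi 2

testlist.txt:
    Rowing/v_Rowing_g01_c01.avi
    YoYo/v_YoYo_g08_c01.avi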

Download the model code: https://github.com/mit-han-lab/temporal-shift-module

1.2 Extract frames from the videos, in vid2img_ucf101.py:

from __future__ import print_function, division
import os
import sys
import subprocess


def class_process(dir_path, dst_dir_path, class_name):
    class_path = os.path.join(dir_path, class_name)
    if not os.path.isdir(class_path):
        return

    dst_class_path = os.path.join(dst_dir_path, class_name)
    if not os.path.exists(dst_class_path):
        os.mkdir(dst_class_path)

    for file_name in os.listdir(class_path):
        if '.avi' not in file_name:
            continue
        name, ext = os.path.splitext(file_name)
        dst_directory_path = os.path.join(dst_class_path, name)

        video_file_path = os.path.join(class_path, file_name)
        try:
            if os.path.exists(dst_directory_path):
                if not os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')):
                    # note: 'rm -r' needs a Unix-style shell; use shutil.rmtree on plain Windows
                    subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True)
                    print('remove {}'.format(dst_directory_path))
                    os.mkdir(dst_directory_path)
                else:
                    continue
            else:
                os.mkdir(dst_directory_path)
        except:
            print(dst_directory_path)
            continue
        cmd = 'ffmpeg -i \"{}\" -vf scale=-1:480 \"{}/image_%05d.jpg\"'.format(video_file_path, dst_directory_path)
        print(cmd)
        subprocess.call(cmd, shell=True)
        print('\n')


if __name__ == "__main__":
    dir_path = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data3/videos'  # root folder of the source videos
    dst_dir_path = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data3/rawframes'  # output folder for the extracted frames
    for class_name in os.listdir(dir_path):
        class_process(dir_path, dst_dir_path, class_name)

If your videos are .mp4 files, change the '.avi' check to '.mp4'.
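For example, the extension check inside class_process can be made slightly more permissive (a minimal sketch; keep whichever container formats you actually use):

        # accept both .avi and .mp4 instead of only .avi
        if not file_name.lower().endswith(('.avi', '.mp4')):
            continue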

Result: each video is expanded into a folder of frames, rawframes/<class>/<video name>/image_00001.jpg, image_00002.jpg, and so on.

1.3 Generate the frame lists used for training; training from pre-extracted frames is said to be faster than decoding videos on the fly.

Open gen_label_ucf101.py and modify the paths:

import os
import glob
import fnmatch
import random
import sys

root = r"data4/rawframes"  # relative frame folder; this prefix is written into the txt lists and read back when the dataset is built


def parse_ucf_splits():
    class_ind = [x.strip().split() for x in open(r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data4/annotations\classInd.txt')]  # class index file
    class_mapping = {x[1]: int(x[0]) - 1 for x in class_ind}

    def line2rec(line):
        items = line.strip().split('/')
        label = class_mapping[items[0]]
        vid = items[1].split('.')[0]
        return vid, label

    splits = []
    for i in range(1, 4):
        train_list = [line2rec(x) for x in open(r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\trainlist.txt'.format(i))]  # training list
        test_list = [line2rec(x) for x in open(r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\testlist.txt'.format(i))]  # test list
        splits.append((train_list, test_list))
    return splits


split_parsers = dict()
split_parsers['ucf101'] = parse_ucf_splits()


def parse_split_file(dataset):
    sp = split_parsers[dataset]
    return tuple(sp)


def parse_directory(path, rgb_prefix='image_', flow_x_prefix='flow_x_', flow_y_prefix='flow_y_'):
    """
    Parse directories holding extracted frames from standard benchmarks
    """
    print('parse frames under folder {}'.format(path))
    frame_folders = []
    frame = glob.glob(os.path.join(path, '*'))
    for frame_name in frame:
        frame_path = glob.glob(os.path.join(frame_name, '*'))
        frame_folders.extend(frame_path)

    def count_files(directory, prefix_list):
        lst = os.listdir(directory)
        cnt_list = [len(fnmatch.filter(lst, x + '*')) for x in prefix_list]
        return cnt_list

    # check RGB
    rgb_counts = {}
    flow_counts = {}
    dir_dict = {}
    for i, f in enumerate(frame_folders):
        all_cnt = count_files(f, (rgb_prefix, flow_x_prefix, flow_y_prefix))
        k = f.split('\\')[-1]
        rgb_counts[k] = all_cnt[0]
        dir_dict[k] = f

        x_cnt = all_cnt[1]
        y_cnt = all_cnt[2]
        if x_cnt != y_cnt:
            raise ValueError('x and y direction have different number of flow images. video: ' + f)
        flow_counts[k] = x_cnt
        if i % 200 == 0:
            print('{} videos parsed'.format(i))
    print('frame folder analysis done')
    return dir_dict, rgb_counts, flow_counts


def build_split_list(split_tuple, frame_info, split_idx, shuffle=False):
    split = split_tuple[split_idx]

    def build_set_list(set_list):
        rgb_list, flow_list = list(), list()
        for item in set_list:
            frame_dir = frame_info[0][item[0]]
            frame_dir = root + '/' + frame_dir.split('\\')[-2] + '/' + frame_dir.split('\\')[-1]
            rgb_cnt = frame_info[1][item[0]]
            flow_cnt = frame_info[2][item[0]]
            rgb_list.append('{} {} {}\n'.format(frame_dir, rgb_cnt, item[1]))
            flow_list.append('{} {} {}\n'.format(frame_dir, flow_cnt, item[1]))
        if shuffle:
            random.shuffle(rgb_list)
            random.shuffle(flow_list)
        return rgb_list, flow_list

    train_rgb_list, train_flow_list = build_set_list(split[0])
    test_rgb_list, test_flow_list = build_set_list(split[1])
    return (train_rgb_list, test_rgb_list), (train_flow_list, test_flow_list)


spl = parse_split_file('ucf101')
f_info = parse_directory(r"D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\rawframes")  # folder holding the extracted frames
out_path = r"D:\learning\ActionSeqmentation\temporal-shift-module-master\data4"  # output folder for the generated list files
dataset = "ucf101"
for i in range(max(3, len(spl))):
    lists = build_split_list(spl, f_info, i)
    open(os.path.join(out_path, '{}_rgb_train_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[0][0])
    open(os.path.join(out_path, '{}_rgb_val_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[0][1])
    # open(os.path.join(out_path, '{}_flow_train_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[1][0])
    # open(os.path.join(out_path, '{}_flow_val_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[1][1])
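Each line of the resulting ucf101_rgb_train_split_1.txt / ucf101_rgb_val_split_1.txt has the form "<frame folder> <frame count> <0-based label>", for example (video names and frame counts here are hypothetical):

data4/rawframes/Rowing/v_Rowing_g01_c01 120 0
data4/rawframes/YoYo/v_YoYo_g08_c01 95 1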

2. Modify the configuration files

Download the pretrained weights (I used the Thunder download manager): https://hanlab.mit.edu/projects/tsm/models/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth and place the file where --tune_from in opts.py points (here, the repo's pretrain folder).

Modify ops/dataset_config.py:

# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import os

ROOT_DATASET = r'D:\learning\ActionSeqmentation/temporal-shift-module-master'  # '/data/jilin/'


def return_ucf101(modality):
    filename_categories = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data4/annotations\classInd.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + '/'
        filename_imglist_train = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data4/ucf101_rgb_train_split_1.txt'
        filename_imglist_val = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data4/ucf101_rgb_val_split_1.txt'
        prefix = 'image_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'UCF101/jpg'
        filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
        filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_hmdb51(modality):
    filename_categories = 51
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'HMDB51/images'
        filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
        prefix = 'img_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'HMDB51/images'
        filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_something(modality):
    filename_categories = 'something/v1/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1'
        filename_imglist_train = 'something/v1/train_videofolder.txt'
        filename_imglist_val = 'something/v1/val_videofolder.txt'
        prefix = '{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1-flow'
        filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
        prefix = '{:06d}-{}_{:05d}.jpg'
    else:
        print('no such modality:' + modality)
        raise NotImplementedError
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_somethingv2(modality):
    filename_categories = 'something/v2/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-frames'
        filename_imglist_train = 'something/v2/train_videofolder.txt'
        filename_imglist_val = 'something/v2/val_videofolder.txt'
        prefix = '{:06d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow'
        filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
        prefix = '{:06d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_jester(modality):
    filename_categories = 'jester/category.txt'
    if modality == 'RGB':
        prefix = '{:05d}.jpg'
        root_data = ROOT_DATASET + 'jester/20bn-jester-v1'
        filename_imglist_train = 'jester/train_videofolder.txt'
        filename_imglist_val = 'jester/val_videofolder.txt'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_kinetics(modality):
    filename_categories = 400
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'kinetics/images'
        filename_imglist_train = 'kinetics/labels/train_videofolder.txt'
        filename_imglist_val = 'kinetics/labels/val_videofolder.txt'
        prefix = 'img_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_dataset(dataset, modality):
    dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2,
                   'ucf101': return_ucf101, 'hmdb51': return_hmdb51,
                   'kinetics': return_kinetics}
    if dataset in dict_single:
        file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)
    else:
        raise ValueError('Unknown dataset ' + dataset)

    file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
    file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
    if isinstance(file_categories, str):
        file_categories = os.path.join(ROOT_DATASET, file_categories)
        with open(file_categories) as f:
            lines = f.readlines()
        categories = [item.rstrip() for item in lines]
    else:  # number of categories
        categories = [None] * file_categories
    n_class = len(categories)
    print('{}: {} classes'.format(dataset, n_class))
    return n_class, file_imglist_train, file_imglist_val, root_data, prefix
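A quick way to sanity-check the edited paths before training is to call return_dataset directly, for example with a small throwaway script run from the repository root (a convenience check, not part of the repo):

# check_config.py - prints the class count and the list files that training will use
from ops.dataset_config import return_dataset

n_class, train_list, val_list, root_data, prefix = return_dataset('ucf101', 'RGB')
print(n_class, prefix)
print(train_list)
print(val_list)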

Modify opts.py:

# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import argparse

parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
parser.add_argument('--dataset',
                    default='ucf101',
                    type=str)
parser.add_argument('--modality', type=str,
                    default='RGB',
                    choices=['RGB', 'Flow'])
parser.add_argument('--train_list', type=str, default="data4/ucf101_rgb_train_split_1.txt")
parser.add_argument('--val_list', type=str, default="data4/ucf101_rgb_val_split_1.txt")
parser.add_argument('--root_path', type=str, default=r"D:\learning\ActionSeqmentation/temporal-shift-module-master")

# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="resnet50")
parser.add_argument('--num_segments', type=int, default=8)
parser.add_argument('--consensus_type', type=str, default='avg')
parser.add_argument('--k', type=int, default=3)
parser.add_argument('--dropout', '--do', default=0.8, type=float,
                    metavar='DO', help='dropout ratio (default: 0.5)')
parser.add_argument('--loss_type', type=str, default="nll",
                    choices=['nll'])
parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
parser.add_argument('--suffix', type=str, default=None)
parser.add_argument('--pretrain', type=str, default='imagenet')
parser.add_argument('--tune_from', type=str, default=r'D:\learning\ActionSeqmentation/temporal-shift-module-master\pretrain\TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth', help='fine-tune from checkpoint')

# ========================= Learning Configs ==========================
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=4, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--lr_type', default='step', type=str,
                    metavar='LRtype', help='learning rate type')
parser.add_argument('--lr_steps', default=[10, 20], type=float, nargs="+",
                    metavar='LRSteps', help='epochs to decay learning rate by 10')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
                    metavar='W', help='gradient norm clipping (default: disabled)')
parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")

# ========================= Monitor Configs ==========================
parser.add_argument('--print-freq', '-p', default=20, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--eval-freq', '-ef', default=1, type=int,
                    metavar='N', help='evaluation frequency (default: 5)')

# ========================= Runtime Configs ==========================
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--snapshot_pref', type=str, default="")
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', default="", type=str)
parser.add_argument('--root_log', type=str, default='log')
parser.add_argument('--root_model', type=str, default='checkpoint')

parser.add_argument('--shift', default=True, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: stageres)')

parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')

parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')

3. Start training. main.py is modified as follows:

# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import os
import time
import shutil
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.nn.utils import clip_grad_norm_

from ops.dataset import TSNDataSet
from ops.models import TSN
from ops.transforms import *
from opts import parser
from ops import dataset_config
from ops.utils import AverageMeter, accuracy
from ops.temporal_shift import make_temporal_pool

from tensorboardX import SummaryWriter

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

best_prec1 = 0


def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                       args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)


def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if args.no_partialbn:
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 2))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()

        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)

        optimizer.step()
        optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output)
            log.write(output + '\n')
            log.flush()

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)


def validate(val_loader, model, criterion, epoch, log=None, tf_writer=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 2))

            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                output = ('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    i, len(val_loader), batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5))
                print(output)
                if log is not None:
                    log.write(output + '\n')
                    log.flush()

    output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
              .format(top1=top1, top5=top5, loss=losses))
    print(output)
    if log is not None:
        log.write(output + '\n')
        log.flush()

    if tf_writer is not None:
        tf_writer.add_scalar('loss/test', losses.avg, epoch)
        tf_writer.add_scalar('acc/test_top1', top1.avg, epoch)
        tf_writer.add_scalar('acc/test_top5', top5.avg, epoch)

    return top1.avg


def save_checkpoint(state, is_best):
    filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))


def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    if lr_type == 'step':
        decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
        lr = args.lr * decay
        decay = args.weight_decay
    elif lr_type == 'cos':
        import math
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs))
        decay = args.weight_decay
    else:
        raise NotImplementedError
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = decay * param_group['decay_mult']


def check_rootfolders():
    """Create log and model folder"""
    folders_util = [args.root_log, args.root_model,
                    os.path.join(args.root_log, args.store_name),
                    os.path.join(args.root_model, args.store_name)]
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            os.mkdir(folder)


if __name__ == '__main__':
    main()
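Since opts.py now carries defaults for the dataset, modality, list files and pretrained weights, training can be launched with no extra arguments; individual settings can still be overridden on the command line, for example:

python main.py
python main.py --epochs 25 --batch-size 8 --lr 0.001

Checkpoints are written to the checkpoint/ folder and TensorBoard logs to log/, as configured by --root_model and --root_log.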

The training results are shown above (original screenshot not reproduced).

4. Demo test

Create a new demo.py file:

import os
import time

from ops.models import TSN
from ops.transforms import *
import cv2
from PIL import Image

arch = 'resnet50'
num_class = 2
num_segments = 8
modality = 'RGB'
base_model = 'resnet50'
consensus_type = 'avg'
dataset = 'ucf101'
dropout = 0.1
img_feature_dim = 256
no_partialbn = True
pretrain = 'imagenet'
shift = True
shift_div = 8
shift_place = 'blockres'
temporal_pool = False
non_local = False
tune_from = None

# load model
model = TSN(num_class, num_segments, modality,
            base_model=arch,
            consensus_type=consensus_type,
            dropout=dropout,
            img_feature_dim=img_feature_dim,
            partial_bn=not no_partialbn,
            pretrain=pretrain,
            is_shift=shift, shift_div=shift_div, shift_place=shift_place,
            fc_lr5=not (tune_from and dataset in tune_from),
            temporal_pool=temporal_pool,
            non_local=non_local)
model = torch.nn.DataParallel(model, device_ids=None).cuda()
resume = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\checkpoint\TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e20\ckpt.best.pth.tar'  # trained weights (best checkpoint)
checkpoint = torch.load(resume)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# how to preprocess the sampled frames
input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]
normalize = GroupNormalize(input_mean, input_std)
transform_hyj = torchvision.transforms.Compose([
    GroupScale_hyj(input_size=320),
    Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
    ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
    normalize,
])

video_path = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data\posvideo\sketch/videos\YoYo/v_YoYo_g08_c01.avi'
pil_img_list = list()
cls_text = ['Rowing', 'YoYo']
cls_color = [(0, 255, 0), (0, 0, 255)]

cap = cv2.VideoCapture(video_path)  # path of the input video
start_time = time.time()
counter = 0
frame_numbers = 0
training_fps = 30
training_time = 2.5
fps = cap.get(cv2.CAP_PROP_FPS)  # average frame rate of the video
if fps < 1:
    fps = 30
duaring = int(fps * training_time / num_segments)
print(duaring)
# exit()
state = 0
while cap.isOpened():
    ret, frame = cap.read()
    if ret:
        frame_numbers += 1
        print(frame_numbers)
        # print(len(pil_img_list))
        if frame_numbers % duaring == 0 and len(pil_img_list) < 8:
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            pil_img_list.extend([frame_pil])
        if frame_numbers % duaring == 0 and len(pil_img_list) == 8:
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            pil_img_list.pop(0)
            pil_img_list.extend([frame_pil])
            input = transform_hyj(pil_img_list)
            input = input.unsqueeze(0).cuda()
            out = model(input)
            print(out)
            output_index = int(torch.argmax(out).cpu())
            state = output_index
        # press space to pause, q to quit
        key = cv2.waitKey(1) & 0xff
        if key == ord(" "):
            cv2.waitKey(0)
        if key == ord("q"):
            break
        counter += 1  # count frames for the FPS display
        if (time.time() - start_time) != 0:  # overlay predicted class and current FPS
            cv2.putText(frame, "{0} {1}".format((cls_text[state]), float('%.1f' % (counter / (time.time() - start_time)))), (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, cls_color[state], 3)
            cv2.imshow('frame', frame)
            counter = 0
            start_time = time.time()
        time.sleep(1 / fps)  # play back at roughly the original frame rate
        # time.sleep(2/fps)  # observe the output
    else:
        break
cap.release()
cv2.destroyAllWindows()
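One caveat: GroupScale_hyj is not one of the transforms shipped in the upstream ops/transforms.py, so it has to be defined (in ops/transforms.py or directly in demo.py) before this script will run. Below is a minimal sketch of what such a group resize might look like, assuming it simply resizes every sampled frame to a fixed square size; the class name, signature and behaviour here are guesses, not the author's actual code:

# hypothetical group resize; the real transform used by the author may differ
from PIL import Image

class GroupScale_hyj(object):
    def __init__(self, input_size=320):
        self.input_size = input_size  # target side length in pixels

    def __call__(self, img_group):
        # img_group is the list of PIL frames sampled from the video
        return [img.resize((self.input_size, self.input_size), Image.BILINEAR) for img in img_group]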
