The dataset follows a UCF101-like format.
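For reference, the scripts below assume one sub-folder per class under videos/, with the annotation files generated next to it. A sketch of the layout (Rowing and YoYo are the two class names used later in this post; file names are illustrative):

```text
posvideo0/
├── videos/
│   ├── Rowing/
│   │   ├── v_Rowing_g01_c01.avi
│   │   └── ...
│   └── YoYo/
│       ├── v_YoYo_g08_c01.avi
│       └── ...
└── annotations/        # created by makelabel.py below
```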
```python
# makelabel.py
import os

# Path to the video dataset (one sub-folder per class)
baseDir = r"D:\Downloads\PaddleVideo-develop\posvideo0\videos"
# Output folder for the annotation files
targetDir = r"D:\Downloads\PaddleVideo-develop\posvideo0\annotations"

if not os.path.exists(targetDir):
    os.makedirs(targetDir)

# classInd.txt: "<index> <class name>" per line, 1-based like UCF101
labels = os.listdir(baseDir)
with open(os.path.join(targetDir, "classInd.txt"), "w+") as f:
    for i in range(len(labels)):
        f.write("{} {}\n".format(i + 1, labels[i]))

# Map class name -> 1-based label index
labels = dict([(labels[i], i + 1) for i in range(len(labels))])
print(labels)

# trainlist.txt: "<class>/<video file> <label index>"
with open(os.path.join(targetDir, "trainlist.txt"), "w+") as f:
    for label in labels:
        for name in os.listdir(os.path.join(baseDir, label)):
            f.write("{}/{} {}\n".format(label, name, labels[label]))

# testlist.txt: "<class>/<video file>"
with open(os.path.join(targetDir, "testlist.txt"), "w+") as f:
    for label in labels:
        for name in os.listdir(os.path.join(baseDir, label)):
            f.write("{}/{}\n".format(label, name))
```
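Running the script on a two-class dataset produces annotation files like the following (file names are illustrative; the class names match the demo later in this post):

```text
# annotations/classInd.txt
1 Rowing
2 YoYo

# annotations/trainlist.txt
Rowing/v_Rowing_g01_c01.avi 1
YoYo/v_YoYo_g08_c01.avi 2

# annotations/testlist.txt
Rowing/v_Rowing_g01_c01.avi
YoYo/v_YoYo_g08_c01.avi
```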

Download the model code: https://github.com/mit-han-lab/temporal-shift-module. Then extract frames from the videos:
```python
from __future__ import print_function, division
import os
import shutil
import subprocess


def class_process(dir_path, dst_dir_path, class_name):
    class_path = os.path.join(dir_path, class_name)
    if not os.path.isdir(class_path):
        return

    dst_class_path = os.path.join(dst_dir_path, class_name)
    if not os.path.exists(dst_class_path):
        os.mkdir(dst_class_path)

    for file_name in os.listdir(class_path):
        if '.avi' not in file_name:
            continue
        name, ext = os.path.splitext(file_name)
        dst_directory_path = os.path.join(dst_class_path, name)

        video_file_path = os.path.join(class_path, file_name)
        try:
            if os.path.exists(dst_directory_path):
                if not os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')):
                    # Incomplete extraction from an earlier run: redo it.
                    # (shutil.rmtree instead of `rm -r`, which is unavailable on Windows)
                    shutil.rmtree(dst_directory_path)
                    print('remove {}'.format(dst_directory_path))
                    os.mkdir(dst_directory_path)
                else:
                    continue
            else:
                os.mkdir(dst_directory_path)
        except OSError:
            print(dst_directory_path)
            continue
        # Scale frames to 480 px height; width keeps the aspect ratio (-1).
        cmd = 'ffmpeg -i "{}" -vf scale=-1:480 "{}/image_%05d.jpg"'.format(video_file_path, dst_directory_path)
        print(cmd)
        subprocess.call(cmd, shell=True)
        print('\n')


if __name__ == "__main__":
    dir_path = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data3\videos'        # root folder of the videos
    dst_dir_path = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data3\rawframes'  # output folder for extracted frames

    if not os.path.exists(dst_dir_path):
        os.makedirs(dst_dir_path)
    for class_name in os.listdir(dir_path):
        class_process(dir_path, dst_dir_path, class_name)
```

If the videos are .mp4 files, change '.avi' to '.mp4' in the script above.
Result: each video now has its own folder of frames, e.g. rawframes/<class>/<video name>/image_00001.jpg.
Next, open the gen_label_ucf101.py file and modify the paths:
```python
import os
import glob
import fnmatch
import random

# Frame root as it should appear in the generated txt files
# (read during training when building the dataset)
root = r"data4/rawframes"


def parse_ucf_splits():
    # Class list: "<index> <name>" per line, 1-based
    class_ind = [x.strip().split() for x in open(
        r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\classInd.txt')]
    class_mapping = {x[1]: int(x[0]) - 1 for x in class_ind}

    def line2rec(line):
        items = line.strip().split('/')
        label = class_mapping[items[0]]
        vid = items[1].split('.')[0]
        return vid, label

    # Our annotations contain a single train/test list, so the same files are
    # reused for all three splits (UCF101 normally ships three).
    splits = []
    for i in range(1, 4):
        train_list = [line2rec(x) for x in open(
            r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\trainlist.txt')]
        test_list = [line2rec(x) for x in open(
            r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\testlist.txt')]
        splits.append((train_list, test_list))
    return splits


split_parsers = dict()
split_parsers['ucf101'] = parse_ucf_splits()


def parse_split_file(dataset):
    sp = split_parsers[dataset]
    return tuple(sp)


def parse_directory(path, rgb_prefix='image_', flow_x_prefix='flow_x_', flow_y_prefix='flow_y_'):
    """
    Parse directories holding extracted frames from standard benchmarks
    """
    print('parse frames under folder {}'.format(path))
    frame_folders = []
    for class_dir in glob.glob(os.path.join(path, '*')):
        frame_folders.extend(glob.glob(os.path.join(class_dir, '*')))

    def count_files(directory, prefix_list):
        lst = os.listdir(directory)
        return [len(fnmatch.filter(lst, x + '*')) for x in prefix_list]

    # check RGB
    rgb_counts = {}
    flow_counts = {}
    dir_dict = {}
    for i, f in enumerate(frame_folders):
        all_cnt = count_files(f, (rgb_prefix, flow_x_prefix, flow_y_prefix))
        k = f.split('\\')[-1]  # video name (assumes Windows path separators)
        rgb_counts[k] = all_cnt[0]
        dir_dict[k] = f

        x_cnt = all_cnt[1]
        y_cnt = all_cnt[2]
        if x_cnt != y_cnt:
            raise ValueError('x and y direction have different number of flow images. video: ' + f)
        flow_counts[k] = x_cnt
        if i % 200 == 0:
            print('{} videos parsed'.format(i))

    print('frame folder analysis done')
    return dir_dict, rgb_counts, flow_counts


def build_split_list(split_tuple, frame_info, split_idx, shuffle=False):
    split = split_tuple[split_idx]

    def build_set_list(set_list):
        rgb_list, flow_list = list(), list()
        for item in set_list:
            frame_dir = frame_info[0][item[0]]
            # Rewrite the absolute path as "<root>/<class>/<video>"
            frame_dir = root + '/' + frame_dir.split('\\')[-2] + '/' + frame_dir.split('\\')[-1]

            rgb_cnt = frame_info[1][item[0]]
            flow_cnt = frame_info[2][item[0]]
            rgb_list.append('{} {} {}\n'.format(frame_dir, rgb_cnt, item[1]))
            flow_list.append('{} {} {}\n'.format(frame_dir, flow_cnt, item[1]))
        if shuffle:
            random.shuffle(rgb_list)
            random.shuffle(flow_list)
        return rgb_list, flow_list

    train_rgb_list, train_flow_list = build_set_list(split[0])
    test_rgb_list, test_flow_list = build_set_list(split[1])
    return (train_rgb_list, test_rgb_list), (train_flow_list, test_flow_list)


spl = parse_split_file('ucf101')
f_info = parse_directory(r"D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\rawframes")  # extracted frames

out_path = r"D:\learning\ActionSeqmentation\temporal-shift-module-master\data4"  # where the label files go
dataset = "ucf101"

for i in range(len(spl)):
    lists = build_split_list(spl, f_info, i)
    open(os.path.join(out_path, '{}_rgb_train_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[0][0])
    open(os.path.join(out_path, '{}_rgb_val_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[0][1])
    # open(os.path.join(out_path, '{}_flow_train_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[1][0])
    # open(os.path.join(out_path, '{}_flow_val_split_{}.txt'.format(dataset, i + 1)), 'w').writelines(lists[1][1])
```
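Each line of the generated split files has the format `<frame_dir> <num_frames> <label>`, with 0-based labels (class_mapping subtracts 1). For example (frame counts are illustrative):

```text
# data4/ucf101_rgb_train_split_1.txt
data4/rawframes/Rowing/v_Rowing_g01_c01 120 0
data4/rawframes/YoYo/v_YoYo_g08_c01 95 1
```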

Download the pretrained weights; I used Thunder (Xunlei) to download them: https://hanlab.mit.edu/projects/tsm/models/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth
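Optionally, sanity-check that the downloaded file loads and contains the expected 'state_dict' key before wiring the paths in (a quick check of my own; the pretrain/ location is an assumption):

```python
import torch

# Load on CPU just to inspect the checkpoint structure.
sd = torch.load(r'pretrain\TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth',
                map_location='cpu')
print(sd.keys())                          # should include 'state_dict'
print(list(sd['state_dict'].keys())[:5])  # first few parameter names
```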
Modify ops/dataset_config.py:
```python
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import os

ROOT_DATASET = r'D:\learning\ActionSeqmentation\temporal-shift-module-master'  # '/data/jilin/'


def return_ucf101(modality):
    filename_categories = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\classInd.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + '/'
        filename_imglist_train = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\ucf101_rgb_train_split_1.txt'
        filename_imglist_val = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\ucf101_rgb_val_split_1.txt'
        prefix = 'image_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'UCF101/jpg'
        filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
        filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_hmdb51(modality):
    filename_categories = 51
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'HMDB51/images'
        filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
        prefix = 'img_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'HMDB51/images'
        filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_something(modality):
    filename_categories = 'something/v1/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1'
        filename_imglist_train = 'something/v1/train_videofolder.txt'
        filename_imglist_val = 'something/v1/val_videofolder.txt'
        prefix = '{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1-flow'
        filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
        prefix = '{:06d}-{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_somethingv2(modality):
    filename_categories = 'something/v2/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-frames'
        filename_imglist_train = 'something/v2/train_videofolder.txt'
        filename_imglist_val = 'something/v2/val_videofolder.txt'
        prefix = '{:06d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow'
        filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
        prefix = '{:06d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_jester(modality):
    filename_categories = 'jester/category.txt'
    if modality == 'RGB':
        prefix = '{:05d}.jpg'
        root_data = ROOT_DATASET + 'jester/20bn-jester-v1'
        filename_imglist_train = 'jester/train_videofolder.txt'
        filename_imglist_val = 'jester/val_videofolder.txt'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_kinetics(modality):
    filename_categories = 400
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'kinetics/images'
        filename_imglist_train = 'kinetics/labels/train_videofolder.txt'
        filename_imglist_val = 'kinetics/labels/val_videofolder.txt'
        prefix = 'img_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_dataset(dataset, modality):
    dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2,
                   'ucf101': return_ucf101, 'hmdb51': return_hmdb51,
                   'kinetics': return_kinetics}
    if dataset in dict_single:
        file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)
    else:
        raise ValueError('Unknown dataset ' + dataset)

    # os.path.join keeps absolute paths (like the ucf101 ones above) unchanged
    file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
    file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
    if isinstance(file_categories, str):
        file_categories = os.path.join(ROOT_DATASET, file_categories)
        with open(file_categories) as f:
            lines = f.readlines()
        categories = [item.rstrip() for item in lines]
    else:  # number of categories
        categories = [None] * file_categories
    n_class = len(categories)
    print('{}: {} classes'.format(dataset, n_class))
    return n_class, file_imglist_train, file_imglist_val, root_data, prefix
```
Modify opts.py:
```python
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import argparse

parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
parser.add_argument('--dataset', default='ucf101', type=str)
parser.add_argument('--modality', type=str, default='RGB', choices=['RGB', 'Flow'])
parser.add_argument('--train_list', type=str, default="data4/ucf101_rgb_train_split_1.txt")
parser.add_argument('--val_list', type=str, default="data4/ucf101_rgb_val_split_1.txt")
parser.add_argument('--root_path', type=str, default=r"D:\learning\ActionSeqmentation\temporal-shift-module-master")

# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="resnet50")
parser.add_argument('--num_segments', type=int, default=8)
parser.add_argument('--consensus_type', type=str, default='avg')
parser.add_argument('--k', type=int, default=3)

parser.add_argument('--dropout', '--do', default=0.8, type=float,
                    metavar='DO', help='dropout ratio (default: 0.8)')
parser.add_argument('--loss_type', type=str, default="nll", choices=['nll'])
parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
parser.add_argument('--suffix', type=str, default=None)
parser.add_argument('--pretrain', type=str, default='imagenet')
parser.add_argument('--tune_from', type=str,
                    default=r'D:\learning\ActionSeqmentation\temporal-shift-module-master\pretrain\TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth',
                    help='fine-tune from checkpoint')

# ========================= Learning Configs ==========================
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=4, type=int,
                    metavar='N', help='mini-batch size (default: 4)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--lr_type', default='step', type=str,
                    metavar='LRtype', help='learning rate type')
parser.add_argument('--lr_steps', default=[10, 20], type=float, nargs="+",
                    metavar='LRSteps', help='epochs to decay learning rate by 10')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
                    metavar='W', help='gradient norm clipping (default: disabled)')
parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")

# ========================= Monitor Configs ==========================
parser.add_argument('--print-freq', '-p', default=20, type=int,
                    metavar='N', help='print frequency (default: 20)')
parser.add_argument('--eval-freq', '-ef', default=1, type=int,
                    metavar='N', help='evaluation frequency (default: 1)')

# ========================= Runtime Configs ==========================
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--snapshot_pref', type=str, default="")
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', default="", type=str)
parser.add_argument('--root_log', type=str, default='log')
parser.add_argument('--root_model', type=str, default='checkpoint')

# Note: with default=True and action="store_true", shift is always enabled here.
parser.add_argument('--shift', default=True, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')

parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')

parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')
```

Then modify main.py:

```python
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import os
import time
import shutil
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.nn.utils import clip_grad_norm_

from ops.dataset import TSNDataSet
from ops.models import TSN
from ops.transforms import *
from opts import parser
from ops import dataset_config
from ops.utils import AverageMeter, accuracy
from ops.temporal_shift import make_temporal_pool

from tensorboardX import SummaryWriter

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

best_prec1 = 0


def main():
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type,
         'segment%d' % args.num_segments, 'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)


def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if args.no_partialbn:
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        # topk=(1, 2) because this dataset has only two classes
        prec1, prec5 = accuracy(output.data, target, topk=(1, 2))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()

        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)

        optimizer.step()
        optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1, top5=top5,
                          lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output)
            log.write(output + '\n')
            log.flush()

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)


def validate(val_loader, model, criterion, epoch, log=None, tf_writer=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 2))

            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                output = ('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              i, len(val_loader), batch_time=batch_time, loss=losses,
                              top1=top1, top5=top5))
                print(output)
                if log is not None:
                    log.write(output + '\n')
                    log.flush()

    output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
              .format(top1=top1, top5=top5, loss=losses))
    print(output)
    if log is not None:
        log.write(output + '\n')
        log.flush()

    if tf_writer is not None:
        tf_writer.add_scalar('loss/test', losses.avg, epoch)
        tf_writer.add_scalar('acc/test_top1', top1.avg, epoch)
        tf_writer.add_scalar('acc/test_top5', top5.avg, epoch)

    return top1.avg


def save_checkpoint(state, is_best):
    filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))


def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
    """Sets the learning rate to the initial LR decayed by 10 at each step in lr_steps"""
    if lr_type == 'step':
        decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
        lr = args.lr * decay
        decay = args.weight_decay
    elif lr_type == 'cos':
        import math
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs))
        decay = args.weight_decay
    else:
        raise NotImplementedError
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = decay * param_group['decay_mult']


def check_rootfolders():
    """Create log and model folder"""
    folders_util = [args.root_log, args.root_model,
                    os.path.join(args.root_log, args.store_name),
                    os.path.join(args.root_model, args.store_name)]
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            os.mkdir(folder)


if __name__ == '__main__':
    main()
```
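Since every argument in opts.py now has a default, training can be launched from the repository root with just `python main.py`; add flags such as `--epochs 20 -b 4` only to override the defaults.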

The above covers training; with the defaults used here, the best weights are saved to checkpoint/TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e20/ckpt.best.pth.tar.
4. Demo test:
Create a new demo.py file:
```python
import os
import time
import cv2
from PIL import Image

from ops.models import TSN
from ops.transforms import *

arch = 'resnet50'
num_class = 2
num_segments = 8
modality = 'RGB'
base_model = 'resnet50'
consensus_type = 'avg'
dataset = 'ucf101'
dropout = 0.1
img_feature_dim = 256
no_partialbn = True
pretrain = 'imagenet'
shift = True
shift_div = 8
shift_place = 'blockres'
temporal_pool = False
non_local = False
tune_from = None


# load model
model = TSN(num_class, num_segments, modality,
            base_model=arch,
            consensus_type=consensus_type,
            dropout=dropout,
            img_feature_dim=img_feature_dim,
            partial_bn=not no_partialbn,
            pretrain=pretrain,
            is_shift=shift, shift_div=shift_div, shift_place=shift_place,
            fc_lr5=not (tune_from and dataset in tune_from),
            temporal_pool=temporal_pool,
            non_local=non_local)

model = torch.nn.DataParallel(model, device_ids=None).cuda()
resume = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\checkpoint\TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e20\ckpt.best.pth.tar'  # best weights from training
checkpoint = torch.load(resume)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# preprocessing for the sampled frames
input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]
normalize = GroupNormalize(input_mean, input_std)
transform_hyj = torchvision.transforms.Compose([
    GroupScale_hyj(input_size=320),
    Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
    ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
    normalize,
])

video_path = r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data\posvideo\sketch\videos\YoYo\v_YoYo_g08_c01.avi'

pil_img_list = list()

cls_text = ['Rowing', 'YoYo']
cls_color = [(0, 255, 0), (0, 0, 255)]

cap = cv2.VideoCapture(video_path)  # path of the input video
start_time = time.time()
counter = 0
frame_numbers = 0
training_fps = 30
training_time = 2.5
fps = cap.get(cv2.CAP_PROP_FPS)  # average frame rate of the video
if fps < 1:
    fps = 30
# sample one frame every `interval` frames so 8 segments cover ~2.5 s
interval = int(fps * training_time / num_segments)
print(interval)


state = 0
while cap.isOpened():
    ret, frame = cap.read()
    if ret:
        frame_numbers += 1
        print(frame_numbers)
        if frame_numbers % interval == 0 and len(pil_img_list) < 8:
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            pil_img_list.append(frame_pil)
        if frame_numbers % interval == 0 and len(pil_img_list) == 8:
            # sliding window of 8 frames: drop the oldest, append the newest
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            pil_img_list.pop(0)
            pil_img_list.append(frame_pil)
            input = transform_hyj(pil_img_list)
            input = input.unsqueeze(0).cuda()
            out = model(input)
            print(out)
            output_index = int(torch.argmax(out).cpu())
            state = output_index

        # press space to pause, q to quit
        key = cv2.waitKey(1) & 0xff
        if key == ord(" "):
            cv2.waitKey(0)
        if key == ord("q"):
            break
        counter += 1  # frame counter for the FPS overlay
        if (time.time() - start_time) != 0:
            # overlay the predicted class and the display frame rate
            cv2.putText(frame, "{0} {1}".format(cls_text[state],
                        float('%.1f' % (counter / (time.time() - start_time)))),
                        (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, cls_color[state], 3)
            cv2.imshow('frame', frame)

            counter = 0
            start_time = time.time()
            time.sleep(1 / fps)  # play back at roughly the original frame rate
    else:
        break

cap.release()
cv2.destroyAllWindows()
```
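Note that GroupScale_hyj is not part of the stock ops/transforms.py, so you need to add it yourself. Below is a minimal sketch under the assumption that it simply resizes every frame in the group to a fixed square size so Stack() can concatenate them; the behaviour in the original post may differ:

```python
# Hypothetical helper for ops/transforms.py (my own sketch, not from the repo):
# resize each PIL image in a frame group to input_size x input_size.
import torchvision


class GroupScale_hyj(object):
    def __init__(self, input_size=320):
        self.worker = torchvision.transforms.Resize((input_size, input_size))

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]
```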
