当前位置:   article > 正文

PyTorch搭建卷积神经网络(CNN)进行视频行为识别(附源码和数据集)_视频行为分析数据集

视频行为分析数据集

需要数据集和源码请点赞关注收藏后评论区留下QQ邮箱~~~

一、行为识别简介

行为识别是视频理解中的一项基础任务,它可以从视频中提取语义信息,进而为行为检测、行为定位等其他任务提供通用的视频表征。

现有的视频行为数据集大致可以划分为以下两种类型:

1. 场景相关数据集:这类数据集的场景提供了较多的语义信息,仅通过单帧图像便能很好地判断对应的行为。

2. 时序相关数据集:这类数据集对时间关系要求很高,需要足够多的帧才能准确识别视频中的行为。

例如,“骑马”就与场景高度相关:马和草地已经提供了足够多的语义信息。

而“打开柜子”则与时序高度相关:如果将时序反转,甚至容易被误认为是在“关闭柜子”。

 如下图

 

 二、数据准备

数据的准备包括对视频进行抽帧处理,具体原理此处不再赘述。

大家可自行前往官网下载数据集

视频行为识别数据集

三、模型搭建与训练

在介绍模型的搭建与训练之前,需要先了解命令行参数,其中包括两个必填的位置参数 dataset 和 modality:前者用于选择数据集,后者用于确定输入模态,即 RGB 图像还是 Flow 光流图像。

过程比较繁琐,此处不再赘述。

效果如下图

最终会得到如下的热力图:颜色从红、黄、绿到蓝,表示网络的关注度由大到小。可以看到,该模块能够很好地定位到运动发生的时空区域。

四、代码 

项目结构如下

 

main函数代码

  # main.py: training / validation entry point for TSM video action recognition.
  # NOTE(review): this listing was scraped from a web page -- indentation is lost
  # and a few lines are truncated (flagged below), so it does not run as shown.
  1. import os
  2. import time
  3. import shutil
  4. import torch.nn.parallel
  # NOTE(review): the next line is garbled. `cudnn` (used for cudnn.benchmark) and
  # `clip_grad_norm_` (used in train) are not imported anywhere visible; this line
  # presumably held `import torch.backends.cudnn as cudnn` and
  # `from torch.nn.utils import clip_grad_norm_` -- confirm against the source repo.
  5. imd_norm_
  6. from ops.dataset import TSNDataSet
  7. from ops.models import TSN
  # Star import; `torchvision` and `np` are used below without an explicit import,
  # presumably arriving through this line -- verify.
  8. from ops.transforms import *
  9. from opts import parser
  10. from ops import dataset_config
  11. from ops.utils import AverageMeter, accuracy
  12. from ops.temporal_shift import make_temporal_pool
  13. from tensorboardX import SummaryWriter
  # Best validation Prec@1 seen so far; main() updates it through `global`.
  14. best_prec1 = 0
  15. def main():
  # Parse options, build the TSN model, optimizer and data loaders, then either
  # evaluate once (--evaluate) or train, validating every `eval_freq` epochs (and on
  # the last epoch) and saving a checkpoint plus a copy of the best one.
  16. global args, best_prec1
  17. args = parser.parse_args()
  # Resolve class count, split list files, frame root and image filename template.
  18. num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
  19. args.modality)
  # Build a descriptive run name; it names the log and checkpoint sub-folders.
  20. full_arch_name = args.arch
  21. if args.shift:
  22. full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
  23. if args.temporal_pool:
  24. full_arch_name += '_tpool'
  25. args.store_name = '_'.join(
  26. ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
  27. 'e{}'.format(args.epochs)])
  # NOTE(review): as shown, '_nl' is appended unconditionally, yet test_models decides
  # `non_local` from "'_nl' in this_weights"; presumably this line was guarded by
  # `if args.non_local:` -- confirm.
  28. args.store_name += '_nl'
  29. if args.suffix is not None:
  30. args.store_name += '_{}'.format(args.suffix)
  31. print('storing name: ' + args.store_name)
  32. check_rootfolders()
  33. model = TSN(num_class, args.num_segments, args.modality,
  34. base_model=args.arch,
  35. consensus_type=args.consensus_type,
  36. dropout=args.dropout,
  37. img_feature_dim=args.img_feature_dim,
  38. partial_bn=not args.no_partialbn,
  39. pretrain=args.pretrain,
  40. is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
  41. fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
  42. temporal_pool=args.temporal_pool,
  43. non_local=args.non_local)
  44. crop_size = model.crop_size
  45. scale_size = model.scale_size
  46. input_mean = model.input_mean
  # NOTE(review): the next line is garbled. `input_std`, `policies` and
  # `train_augmentation` are used later but never assigned in the visible code;
  # they were presumably defined here from the model -- confirm.
  47. in else True)
  48. model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
  49. optimizer = torch.optim.SGD(policies,
  50. args.lr,
  51. momentum=args.momentum,
  52. weight_decay=args.weight_decay)
  # Optionally resume model/optimizer state, start epoch and best accuracy.
  53. if args.resume:
  54. if args.temporal_pool: # early temporal pool so that we can load the state_dict
  55. make_temporal_pool(model.module.base_model, args.num_segments)
  56. if os.path.isfile(args.resume):
  57. print(("=> loading checkpoint '{}'".format(args.resume)))
  58. checkpoint = torch.load(args.resume)
  59. args.start_epoch = checkpoint['epoch']
  60. best_prec1 = checkpoint['best_prec1']
  61. model.load_state_dict(checkpoint['state_dict'])
  62. optimizer.load_state_dict(checkpoint['optimizer'])
  # NOTE(review): the message below formats args.evaluate (a boolean flag) where
  # args.resume (the checkpoint path) was presumably intended.
  63. print(("=> loaded checkpoint '{}' (epoch {})"
  64. .format(args.evaluate, checkpoint['epoch'])))
  65. else:
  66. print(("=> no checkpoint found at '{}'".format(args.resume)))
  # NOTE(review): the next line is garbled. `sd` is used below but never assigned
  # visibly; presumably a fine-tuning branch (`if args.tune_from:` ... load the
  # checkpoint ... `sd = sd['state_dict']`) was lost here. As shown, the remapping
  # below is not guarded by --tune_from.
  67. ate_dict']
  # Reconcile key names between checkpoint and model by removing/adding '.net',
  # then report the keys that still do not match.
  68. model_dict = model.state_dict()
  69. replace_dict = []
  70. for k, v in sd.items():
  71. if k not in model_dict and k.replace('.net', '') in model_dict:
  72. print('=> Load after remove .net: ', k)
  73. replace_dict.append((k, k.replace('.net', '')))
  74. for k, v in model_dict.items():
  75. if k not in sd and k.replace('.net', '') in sd:
  76. print('=> Load after adding .net: ', k)
  77. replace_dict.append((k.replace('.net', ''), k))
  78. for k, k_new in replace_dict:
  79. sd[k_new] = sd.pop(k)
  80. keys1 = set(list(sd.keys()))
  81. keys2 = set(list(model_dict.keys()))
  82. set_diff = (keys1 - keys2) | (keys2 - keys1)
  83. print('#### Notice: keys that failed to load: {}'.format(set_diff))
  # Fine-tuning onto a different dataset: drop the classifier ('fc') weights.
  84. if args.dataset not in args.tune_from: # new dataset
  85. print('=> New dataset, do not load fc weights')
  86. sd = {k: v for k, v in sd.items() if 'fc' not in k}
  # NOTE(review): the next line is garbled; presumably it merged `sd` into
  # `model_dict` and called `model.load_state_dict(model_dict)` -- confirm.
  87. if te_dict(model_dict)
  88. if args.temporal_pool and not args.resume:
  89. make_temporal_pool(model.module.base_model, args.num_segments)
  90. cudnn.benchmark = True
  91. # Data loading code
  # RGBDiff inputs are not mean/std normalized.
  92. if args.modality != 'RGBDiff':
  93. normalize = GroupNormalize(input_mean, input_std)
  94. else:
  95. normalize = IdentityTransform()
  # Frames stacked per segment: 1 for RGB, 5 for Flow/RGBDiff.
  96. if args.modality == 'RGB':
  97. data_length = 1
  98. elif args.modality in ['Flow', 'RGBDiff']:
  99. data_length = 5
  # Training loader: shuffled, augmented, and drops the last incomplete batch.
  100. train_loader = torch.utils.data.DataLoader(
  101. TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
  102. new_length=data_length,
  103. modality=args.modality,
  104. image_tmpl=prefix,
  105. transform=torchvision.transforms.Compose([
  106. train_augmentation,
  107. Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
  108. ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
  109. normalize,
  110. ]), dense_sample=args.dense_sample),
  111. batch_size=args.batch_size, shuffle=True,
  112. num_workers=args.workers, pin_memory=True,
  113. drop_last=True) # prevent something not % n_GPU
  # Validation loader: random_shift=False, scale followed by a center crop.
  114. val_loader = torch.utils.data.DataLoader(
  115. TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
  116. new_length=data_length,
  117. modality=args.modality,
  118. image_tmpl=prefix,
  119. random_shift=False,
  120. transform=torchvision.transforms.Compose([
  121. GroupScale(int(scale_size)),
  122. GroupCenterCrop(crop_size),
  123. Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
  124. ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
  125. normalize,
  126. ]), dense_sample=args.dense_sample),
  127. batch_size=args.batch_size, shuffle=False,
  128. num_workers=args.workers, pin_memory=True)
  129. # define loss function (criterion) and optimizer
  # The 'nll' option is implemented with CrossEntropyLoss on the model outputs.
  130. if args.loss_type == 'nll':
  131. criterion = torch.nn.CrossEntropyLoss().cuda()
  132. else:
  133. raise ValueError("Unknown loss type")
  # Log each optimizer parameter group with its lr / weight-decay multipliers.
  134. for group in policies:
  135. print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
  136. group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))
  137. if args.evaluate:
  138. validate(val_loader, model, criterion, 0)
  139. return
  # NOTE(review): log_training is opened here and never closed.
  140. log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
  141. with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
  142. f.write(str(args))
  143. tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))
  144. for epoch in range(args.start_epoch, args.epochs):
  145. adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
  146. # train for one epoch
  147. train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)
  148. # evaluate on validation set
  149. if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
  150. prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)
  151. # remember best prec@1 and save checkpoint
  152. is_best = prec1 > best_prec1
  153. best_prec1 = max(prec1, best_prec1)
  154. tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)
  155. output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
  156. print(output_best)
  157. log_training.write(output_best + '\n')
  158. log_training.flush()
  159. save_checkpoint({
  160. 'epoch': epoch + 1,
  161. 'arch': args.arch,
  162. 'state_dict': model.state_dict(),
  163. 'optimizer': optimizer.state_dict(),
  164. 'best_prec1': best_prec1,
  165. }, is_best)
  # One training epoch (main.py).
def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    """Train ``model`` for one epoch over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(input, target)`` batches.
        model: ``torch.nn.DataParallel``-wrapped model whose ``module``
            exposes ``partialBN``.
        criterion: loss applied to ``(output, target)``.
        optimizer: stepped once per batch after backprop (and optional
            gradient-norm clipping to ``args.clip_gradient``).
        epoch: epoch index, used for logging.
        log: open text file; progress lines are appended every
            ``args.print_freq`` batches.
        tf_writer: TensorBoard writer receiving epoch-level loss, accuracy
            and learning rate.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Partial BN stays enabled unless --no_partialbn was given.
    model.module.partialBN(not args.no_partialbn)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # Time spent waiting for the data loader.
        data_time.update(time.time() - end)
        target = target.cuda()

        # Plain tensors are passed directly: torch.autograd.Variable has been a
        # deprecated no-op wrapper since PyTorch 0.4.
        output = model(inputs)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        batch = inputs.size(0)
        losses.update(loss.item(), batch)
        top1.update(prec1.item(), batch)
        top5.update(prec5.item(), batch)

        # compute gradient and do SGD step
        loss.backward()
        if args.clip_gradient is not None:
            clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()
        optimizer.zero_grad()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            # NOTE(review): reports the last param group's lr scaled by 0.1,
            # presumably because that group carries a 10x lr_mult; confirm
            # against the optimizer policies.
            message = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                           epoch, i, len(train_loader), batch_time=batch_time,
                           data_time=data_time, loss=losses, top1=top1, top5=top5,
                           lr=optimizer.param_groups[-1]['lr'] * 0.1))
            print(message)
            log.write(message + '\n')
            log.flush()

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
  # Validation pass over the held-out split (main.py).
def validate(val_loader, model, criterion, epoch, log=None, tf_writer=None):
    """Evaluate ``model`` on ``val_loader`` and return the average Prec@1.

    Every ``args.print_freq`` batches a progress line is printed (and written
    to ``log`` when one is given); a summary line follows the loop. When
    ``tf_writer`` is supplied, the epoch's loss and top-1/top-5 accuracy are
    recorded under the ``loss/test`` and ``acc/test_*`` tags.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    def report(text):
        # Echo to stdout and, if available, to the log file.
        print(text)
        if log is not None:
            log.write(text + '\n')
            log.flush()

    model.eval()
    tic = time.time()
    with torch.no_grad():
        for step, (clips, labels) in enumerate(val_loader):
            labels = labels.cuda()
            logits = model(clips)
            loss = criterion(logits, labels)

            prec1, prec5 = accuracy(logits.data, labels, topk=(1, 5))
            count = clips.size(0)
            losses.update(loss.item(), count)
            top1.update(prec1.item(), count)
            top5.update(prec5.item(), count)

            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                report(('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            step, len(val_loader), batch_time=batch_time, loss=losses,
                            top1=top1, top5=top5)))

    report('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
           .format(top1=top1, top5=top5, loss=losses))

    if tf_writer is not None:
        tf_writer.add_scalar('loss/test', losses.avg, epoch)
        tf_writer.add_scalar('acc/test_top1', top1.avg, epoch)
        tf_writer.add_scalar('acc/test_top5', top5.avg, epoch)
    return top1.avg
  # Checkpoint writer (main.py).
def save_checkpoint(state, is_best):
    """Save ``state`` to ``<args.root_model>/<args.store_name>/ckpt.pth.tar``.

    When ``is_best`` is true, the file is also copied to
    ``ckpt.best.pth.tar`` in the same folder.
    """
    folder = os.path.join(args.root_model, args.store_name)
    filename = os.path.join(folder, 'ckpt.pth.tar')
    torch.save(state, filename)
    if is_best:
        # Build the best-model path from the folder instead of calling
        # str.replace on the full path, which would also rewrite any
        # 'pth.tar' that happens to occur in root_model or store_name.
        shutil.copyfile(filename, os.path.join(folder, 'ckpt.best.pth.tar'))
  # Learning-rate schedule (main.py).
def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
    """Set every param group's lr and weight decay for ``epoch``.

    ``lr_type == 'step'``: the base lr ``args.lr`` is multiplied by 0.1 once
    for each milestone in ``lr_steps`` that ``epoch`` has reached.
    ``lr_type == 'cos'``: cosine annealing from ``args.lr`` towards 0 over
    ``args.epochs`` epochs.

    Each group then receives ``lr * lr_mult`` and
    ``args.weight_decay * decay_mult``.

    Raises:
        NotImplementedError: for any other ``lr_type``.
    """
    import math

    if lr_type == 'step':
        # Number of milestones already passed; plain Python so the schedule
        # does not depend on numpy leaking in through a star import.
        passed = sum(1 for step in lr_steps if epoch >= step)
        lr = args.lr * 0.1 ** passed
    elif lr_type == 'cos':
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs))
    else:
        raise NotImplementedError('Unknown lr_type: {}'.format(lr_type))
    decay = args.weight_decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = decay * param_group['decay_mult']
  # Output-folder setup (main.py).
def check_rootfolders():
    """Create the log and model folders for this run if they are missing.

    Ensures ``args.root_log``, ``args.root_model`` and the per-run
    ``<root>/<args.store_name>`` folder under each of them exist.
    """
    folders_util = [args.root_log, args.root_model,
                    os.path.join(args.root_log, args.store_name),
                    os.path.join(args.root_model, args.store_name)]
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            # makedirs (not mkdir) so nested roots such as 'runs/log' work;
            # exist_ok tolerates a folder created in the meantime.
            os.makedirs(folder, exist_ok=True)
  # Script entry point: run main() only when this file is executed directly.
  292. if __name__ == '__main__':
  293. main()

opts 模块代码如下

  # opts.py: command-line options for the training script.
# The two positional arguments (dataset, modality) must be supplied on the
# command line; their `default=` values have no effect because argparse treats
# plain positionals as required. All other options have defaults.
import argparse

parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
parser.add_argument('dataset', default="")
parser.add_argument('modality', default="RGB", choices=['RGB', 'Flow'])
parser.add_argument('--train_list', type=str, default="")
parser.add_argument('--val_list', type=str, default="")
parser.add_argument('--root_path', type=str, default="")
parser.add_argument('--store_name', type=str, default="")
# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="BNInception")
parser.add_argument('--num_segments', type=int, default=3)
parser.add_argument('--consensus_type', type=str, default='avg')
parser.add_argument('--k', type=int, default=3)
parser.add_argument('--dropout', '--do', default=0.5, type=float,
                    metavar='DO', help='dropout ratio (default: 0.5)')
parser.add_argument('--loss_type', type=str, default="nll",
                    choices=['nll'])
parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
parser.add_argument('--suffix', type=str, default=None)
parser.add_argument('--pretrain', type=str, default='imagenet')
parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint')
# ========================= Learning Configs ==========================
parser.add_argument('--epochs', default=120, type=int, metavar='N',
                    help='number of total epochs to run')
# Help texts below state the actual defaults (they previously claimed
# 256 / 10 / stageres while the defaults were 128 / 20 / blockres).
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--lr_type', default='step', type=str,
                    metavar='LRtype', help='learning rate type')
parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
                    metavar='LRSteps', help='epochs to decay learning rate by 10')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
                    metavar='W', help='gradient norm clipping (default: disabled)')
parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
# ========================= Monitor Configs ==========================
parser.add_argument('--print-freq', '-p', default=20, type=int,
                    metavar='N', help='print frequency (default: 20)')
parser.add_argument('--eval-freq', '-ef', default=5, type=int,
                    metavar='N', help='evaluation frequency (default: 5)')
# ========================= Runtime Configs ==========================
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--snapshot_pref', type=str, default="")
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', default="", type=str)
parser.add_argument('--root_log', type=str, default='log')
parser.add_argument('--root_model', type=str, default='checkpoint')
parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')
parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')
parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')

test_models 脚本代码如下

  1. # Notice that this file has been modified to support ensemble testing
  # NOTE(review): this listing is scraped and partly garbled. `argparse`, `time`,
  # `torch`, `torchvision`, `np`, `TSN`, `TSNDataSet` and `confusion_matrix` are
  # used below but not imported visibly; some may arrive via the star import -- verify.
  2. from ops.transforms import *
  3. from ops import dataset_config
  4. from torch.nn import functional as F
  5. # options
  6. parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
  7. parser.add_argument('dataset', type=str)
  8. # may contain splits
  # NOTE(review): the next line is truncated. args.weights, args.test_segments,
  # args.full_res, args.dense_sample and args.twice_sample are read later but are
  # not defined by any visible add_argument call; presumably they were declared here.
  9. pars
  10. parser.add_argument('--test_crops', type=int, default=1)
  # Comma-separated per-model weights for the score ensemble (default: all 1).
  11. parser.add_argument('--coeff', type=str, default=None)
  12. parser.add_argument('--batch_size', type=int, default=1)
  13. parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
  14. help='number of data loading workers (default: 8)')
  15. # for true test
  # Comma-separated list files, one per model; also needed when --csv_file is set.
  16. parser.add_argument('--test_list', type=str, default=None)
  17. parser.add_argument('--csv_file', type=str, default=None)
  18. parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')
  # Cap on evaluation-loop iterations (batches); <= 0 means evaluate everything.
  19. parser.add_argument('--max_num', type=int, default=-1)
  # NOTE(review): --input_size is not read in the visible script (crop sizes come
  # from net.input_size / net.scale_size).
  20. parser.add_argument('--input_size', type=int, default=224)
  # Passed to TSN as its consensus_type.
  21. parser.add_argument('--crop_fusion_type', type=str, default='avg')
  # NOTE(review): --gpus only feeds a `devices` list that is never used.
  22. parser.add_argument('--gpus', nargs='+', type=int, default=None)
  23. parser.add_argument('--img_feature_dim',type=int, default=256)
  # NOTE(review): --num_set_segments is not read in the visible script.
  24. parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video')
  25. parser.add_argument('--pretrain', type=str, default='imagenet')
  26. args = parser.parse_args()
  # Running-average helper used for the Prec@1 / Prec@5 meters (test_models).
class AverageMeter(object):
    """Track the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the current value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
  # Top-k accuracy used by the evaluation loop (test_models).
def accuracy(output, target, topk=(1,)):
    """Return the top-k precision (in percent) for each k in ``topk``.

    Args:
        output: score tensor of shape (batch, num_classes); top-k is taken
            along dim 1.
        target: tensor holding one integer class label per sample.
        topk: the k values to evaluate; each must not exceed num_classes.

    Returns:
        List of 0-d tensors, one per k, each the percentage of samples whose
        label is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` is laid out like the transposed `pred`, so `.view(-1)` on a
        # multi-row slice raises on current PyTorch; reshape copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
  # Recover TSM shift settings encoded in a checkpoint file name (test_models).
def parse_shift_option_from_log_name(log_name):
    """Return ``(is_shift, shift_div, shift_place)`` parsed from ``log_name``.

    The name is split on ``'_'``; the first token containing ``'shift'`` is
    read as ``shift<div>`` and the following token as the shift place.
    Names without ``'shift'`` yield ``(False, None, None)``.
    """
    if 'shift' not in log_name:
        return False, None, None
    parts = log_name.split('_')
    idx = next(j for j, part in enumerate(parts) if 'shift' in part)
    return True, int(parts[idx].replace('shift', '')), parts[idx + 1]
  # Parse the ensemble spec: one checkpoint path and one segment count per model.
  62. weights_list = args.weights.split(',')
  63. test_segments_list = [int(s) for s in args.test_segments.split(',')]
  64. assert len(weights_list) == len(test_segments_list)
  65. if args.coeff is None:
  66. coeff_list = [1] * len(weights_list)
  67. else:
  68. coeff_list = [float(c) for c in args.coeff.split(',')]
  69. if args.test_list is not None:
  70. test_file_list = args.test_list.split(',')
  71. else:
  72. test_file_list = [None] * len(weights_list)
  73. data_iter_list = []
  74. net_list = []
  75. modality_list = []
  76. total_num = None
  # Build one network and one data loader per ensemble member. Model settings are
  # recovered from the checkpoint file name (the store_name built in main.py).
  77. for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
  78. is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
  # Any file name without 'RGB' is treated as a Flow model.
  79. if 'RGB' in this_weights:
  80. modality = 'RGB'
  81. else:
  82. modality = 'Flow'
  # Arch is the third '_' field after 'TSM_' (TSM_<dataset>_<modality>_<arch>_...).
  # NOTE(review): this breaks if the dataset name itself contains '_'.
  83. this_arch = this_weights.split('TSM_')[1].split('_')[2]
  84. modality_list.append(modality)
  85. num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
  86. modality)
  87. print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
  # Non-shift checkpoints are built with a single segment.
  88. net = TSN(num_class, this_test_segments if is_shift else 1, modality,
  89. base_model=this_arch,
  90. consensus_type=args.crop_fusion_type,
  91. img_feature_dim=args.img_feature_dim,
  92. pretrain=args.pretrain,
  93. is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
  94. non_local='_nl' in this_weights,
  95. )
  96. if 'tpool' in this_weights:
  97. from ops.temporal_shift import make_temporal_pool
  98. make_temporal_pool(net.base_model, this_test_segments) # since DataParallel
  # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
  99. checkpoint = torch.load(this_weights)
  100. checkpoint = checkpoint['state_dict']
  101. # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
  # Drop the first dotted component of every key (the wrapper prefix added when
  # main.py saved its DataParallel model's state_dict).
  102. base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
  103. replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
  104. 'base_model.classifier.bias': 'new_fc.bias',
  105. }
  106. for k, v in replace_dict.items():
  107. if k in base_dict:
  108. base_dict[v] = base_dict.pop(k)
  109. net.load_state_dict(base_dict)
  # Test-time cropping: 1 = scale + center crop; 3 / 5 / 10 = multi-crop variants.
  110. input_size = net.scale_size if args.full_res else net.input_size
  111. if args.test_crops == 1:
  112. cropping = torchvision.transforms.Compose([
  113. GroupScale(net.scale_size),
  114. GroupCenterCrop(input_size),
  115. ])
  # NOTE(review): the trailing comment on the next line is stale -- this is the 3-crop branch.
  116. elif args.test_crops == 3: # do not flip, so only 5 crops
  117. cropping = torchvision.transforms.Compose([
  118. GroupFullResSample(input_size, net.scale_size, flip=False)
  119. ])
  120. elif args.test_crops == 5: # do not flip, so only 5 crops
  121. cropping = torchvision.transforms.Compose([
  122. GroupOverSample(input_size, net.scale_size, flip=False)
  123. ])
  124. elif args.test_crops == 10:
  125. cropping = torchvision.transforms.Compose([
  126. GroupOverSample(input_size, net.scale_size)
  127. ])
  # NOTE(review): the error message below omits the supported 3-crop option.
  128. else:
  129. raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))
  # Deterministic test loader; a single-model run drops missing videos.
  130. data_loader = torch.utils.data.DataLoader(
  131. TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
  132. new_length=1 if modality == "RGB" else 5,
  133. modality=modality,
  134. image_tmpl=prefix,
  135. test_mode=True,
  136. remove_missing=len(weights_list) == 1,
  137. transform=torchvision.transforms.Compose([
  138. cropping,
  139. Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
  140. ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
  141. GroupNormalize(net.input_mean, net.input_std),
  142. ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
  143. batch_size=args.batch_size, shuffle=False,
  144. num_workers=args.workers, pin_memory=True,
  145. )
  # NOTE(review): `devices` is never used (DataParallel below gets no device_ids),
  # and indexing args.gpus with range(args.workers) raises IndexError whenever
  # fewer GPUs than workers are listed.
  146. if args.gpus is not None:
  147. devices = [args.gpus[i] for i in range(args.workers)]
  148. else:
  149. devices = list(range(args.workers))
  150. net = torch.nn.DataParallel(net.cuda())
  151. net.eval()
  152. data_gen = enumerate(data_loader)
  # All loaders must cover the same number of videos so they can be zipped later.
  153. if total_num is None:
  154. total_num = len(data_loader.dataset)
  155. else:
  156. assert total_num == len(data_loader.dataset)
  157. data_iter_list.append(data_gen)
  158. net_list.append(net)
  # Accumulates [scores, label] per video during the evaluation loop.
  159. output = []
  # Score one batch with one ensemble member (test_models).
def eval_video(video_data, net, this_test_segments, modality):
    """Run ``net`` on one batch and return per-video class scores.

    Args:
        video_data: ``(i, data, label)`` -- batch index, the batch tensor from
            the test loader, and its labels.
        net: a ``torch.nn.DataParallel``-wrapped TSN (``net.module.is_shift``
            is read).
        this_test_segments: number of segments sampled per video.
        modality: 'RGB', 'Flow' or 'RGBDiff'.

    Returns:
        ``(i, scores, label)`` with ``scores`` a numpy array reshaped to
        ``(batch_size, num_class)``.

    Raises:
        ValueError: for an unknown modality.
    """
    net.eval()
    with torch.no_grad():
        i, data, label = video_data
        batch_size = label.numel()
        num_crop = args.test_crops
        if args.dense_sample:
            num_crop *= 10  # 10 clips for testing when using dense sample
        if args.twice_sample:
            num_crop *= 2

        # Input channels per sampled frame group, by modality.
        channels = {'RGB': 3, 'Flow': 10, 'RGBDiff': 18}
        if modality not in channels:
            raise ValueError("Unknown modality "+ modality)
        length = channels[modality]

        data_in = data.view(-1, length, data.size(2), data.size(3))
        # Use this network's own shift flag rather than the module-level
        # `is_shift`, which holds the value from whichever model the loading
        # loop built last and is wrong for mixed shift / non-shift ensembles.
        if net.module.is_shift:
            data_in = data_in.view(batch_size * num_crop, this_test_segments,
                                   length, data_in.size(2), data_in.size(3))
        rst = net(data_in)
        # Average over crops (and clips).
        rst = rst.reshape(batch_size, num_crop, -1).mean(1)

        if args.softmax:
            # take the softmax to normalize the output to probability
            rst = F.softmax(rst, dim=1)

        rst = rst.data.cpu().numpy().copy()

        if net.module.is_shift:
            rst = rst.reshape(batch_size, num_class)
        else:
            # Non-shift models score each segment separately; average them.
            rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

        return i, rst, label
  # Evaluate all ensemble members in lockstep: each step pulls one batch from every
  # loader, scores it with the matching network, and averages the weighted scores.
  192. proc_start_time = time.time()
  # NOTE(review): max_num defaults to the number of videos (total_num) but is
  # compared with the batch index `i` below, so with batch_size > 1 the cap is
  # effectively in batches, not videos.
  193. max_num = args.max_num if args.max_num > 0 else total_num
  194. top1 = AverageMeter()
  195. top5 = AverageMeter()
  196. for i, data_label_pairs in enumerate(zip(*data_iter_list)):
  197. with torch.no_grad():
  198. if i >= max_num:
  199. break
  200. this_rst_list = []
  201. this_label = None
  202. for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
  203. rst = eval_video((i, data, label), net, n_seg, modality)
  204. this_rst_list.append(rst[1])
  205. this_label = label
  206. assert len(this_rst_list) == len(coeff_list)
  # Weight each model's scores by its --coeff entry, then average over models.
  207. for i_coeff in range(len(this_rst_list)):
  208. this_rst_list[i_coeff] *= coeff_list[i_coeff]
  209. ensembled_predict = sum(this_rst_list) / len(this_rst_list)
  # Keep each video's scores (with a leading axis of size 1) and its label.
  210. for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
  211. output.append([p[None, ...], g])
  212. cnt_time = time.time() - proc_start_time
  213. prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
  214. top1.update(prec1.item(), this_label.numel())
  215. top5.update(prec5.item(), this_label.numel())
  # NOTE(review): the first two placeholders both receive i * args.batch_size.
  216. if i % 20 == 0:
  217. print('video {} done, total {}/{}, average {:.3f} sec/video, '
  218. 'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
  219. float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))
  220. video_pred = [np.argmax(x[0]) for x in output]
  221. video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
  222. video_labels = [x[1] for x in output]
  # Optional result file. Category names are read from a file derived from the
  # first test list; without --test_list, test_file_list[0] is None and this fails.
  223. if args.csv_file is not None:
  224. print('=> Writing result to csv file: {}'.format(args.csv_file))
  225. with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
  226. categories = f.readlines()
  227. categories = [f.strip() for f in categories]
  228. with open(test_file_list[0]) as f:
  229. vid_names = f.readlines()
  230. vid_names = [n.split(' ')[0] for n in vid_names]
  231. assert len(vid_names) == len(video_pred)
  # Other datasets: one "<video>;<top-1 category name>" line per video.
  232. if args.dataset != 'somethingv2': # only output top1
  233. with open(args.csv_file, 'w') as f:
  234. for n, pred in zip(vid_names, video_pred):
  235. f.write('{};{}\n'.format(n, categories[pred]))
  236. else:
  # somethingv2: "<video>;" followed by the five top-5 class indices.
  237. with open(args.csv_file, 'w') as f:
  238. for n, pred5 in zip(vid_names, video_pred_top5):
  239. fill = [n]
  240. for p in list(pred5):
  241. fill.append(p)
  242. f.write('{};{};{};{};{};{}\n'.format(*fill))
  # Per-class accuracy from the confusion matrix (presumably sklearn's
  # confusion_matrix; not imported visibly). The matrix is also saved to cm.npy.
  243. cf = confusion_matrix(video_labels, video_pred).astype(float)
  244. np.save('cm.npy', cf)
  245. cls_cnt = cf.sum(axis=1)
  246. cls_hit = np.diag(cf)
  # NOTE(review): a row with no true samples gives cls_cnt == 0 and a nan here.
  247. cls_acc = cls_hit / cls_cnt
  248. print(cls_acc)
  # Mean over classes of (largest entry in the row / row total).
  249. upper = np.mean(np.max(cf, axis=1) / cls_cnt)
  250. print('upper bound: {}'.format(upper))
  251. print('-----Evaluation is finished------')
  252. print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
  253. print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))

创作不易 觉得有帮助请点赞关注收藏~~~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/726964
推荐阅读
相关标签
  

闽ICP备14008679号