当前位置:   article > 正文

视频理解TSM的训练与使用_tsm视频

tsm视频

视频理解TSM的训练与使用

tsm的github地址

总体评价:tsm是一个理解不难但效果优秀的视频理解模型,在我的视频分类任务中,其效果基本达到了使用要求。相比我在github上跑通的其他模型,tsm是最好的。百度团队在不久前也推出了pp-tsm,精度相比tsm提升了几个百分点,我也克隆并调试了,不过非常惭愧,训练模型没有跑通,以后有时间的话会再进行研究。

训练

训练方面我也是借鉴了其他优秀作者的建议,这里给出链接,大家可以参考他的步骤开始自己的训练。链接地址
先要强调的是,本人modality选择的是“RGB”,没有flow之类,感兴趣的可能要自己研究下了。
可以大概说一下tsm的训练原理,对于一个属于某类的视频,我们通过ffmpeg,或者opencv对视频进行抽帧,将一个视频的每一帧的图片根据排序存储至一个文件夹,在训练的采样阶段,模型对一个文件夹一定随机抽取n张(n默认为8)图片,进行concat操作,将concat后的tentor张量作为输入,视频转图片文件夹的文类作为标签,放入网络进行训练。

训练技巧:
1.更改num_segements:
num_segments即为对每一个视频转图片文件夹的采样张数,对于更多的采样,输入可以包含更多的特征信息,所以一般来说将这个参数增大可以提升模型的性能。
2.更改对图片信息的采样压缩方式:
在原始的tsm的源码中,对训练数据进行采样的是

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

其中的train_augmention包括:

    def get_augmentation(self, flip=True):
        if self.modality == 'RGB':
            if flip:
                return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),
                                                       GroupRandomHorizontalFlip(is_flow=False)])
  • 1
  • 2
  • 3
  • 4
  • 5

对与图片的裁减与缩放操作就是在Compose的GroupMultiScaleCrop中实现的。这里展示以下关键代码:

    def __call__(self, img_group):

        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

我这里之贴出了部分,源码大家可以自己看一下,大概的内容就是,对一张原始的单张图片,在原图中以一定的偏移确定裁减的区域,裁减后再进行resize操作(默认为224x224)。
tsm的demo给出的是手势的识别,在这样的任务前提下,图像的大小以及图像的边缘信息似乎没有那么重要,然而,在更复杂任务的时候,一张图像的边缘也包含了重要的特征信息,且如果图像太小会损失重要的特征信息,根据这两点,我重新写了一个图片压缩的类(名字随意),其将图像resize到制定大小,并通过填充黑边保持原图像的形状。

class GroupScale_hyj(object):  
    def __init__(self,input_size):
        self.input_size = input_size
        self.interpolation = Image.BILINEAR

    # @classmethod
    def _black_resize_img(self,ori_img):

        new_size = self.input_size
        ori_img.thumbnail((new_size,new_size))
        w2,h2 = ori_img.size
        bg_img = Image.new('RGB',(new_size,new_size),(0,0,0))
        if w2 == new_size:
            bg_img.paste(ori_img, (0, int((new_size - h2) / 2)))
        elif h2 == new_size:
            bg_img.paste(ori_img, (int((new_size - w2) / 2), 0))
        else:
            bg_img.paste(ori_img, (int((new_size - w2) / 2), (int((new_size - h2) / 2))))

        return bg_img

    def __call__(self,img_group):

        ret_img_group = [self._black_resize_img(img) for img in img_group]

        return ret_img_group
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

train_loader与val_loader均可使用,为了保证我的显存能够带动,我选取了图片大小为320,替换后的代码为:

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       GroupScale_hyj(input_size=320),
                       GroupAugmentor(),  #img_augmentor for the train data
                       GroupRandomHorizontalFlip(is_flow=False),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale_hyj(input_size=320),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

再贴两张对比图(图一是tsm原始的图片缩放,图二是修改后的)
图1
图2修改后的可以保存更多的特征信息。
3.保证采样的一串图片其对应的时间一致(个人觉得,欢迎指正):
我们是对一个视频转图片再抽取一定数量的帧(默认为8),如果我们是对一批数据训练的话,那我们应该要保证每一个文件夹抽取的图片所代表的时间长度是固定的,比如我们规定以2s为基本的时间长,那么每个文件夹抽取的第一张到最后一张所经历的时间应该接近2s,意思就是,我在2s的时间里,对该视频的行为进行分类。由此,当我们的视频数据集有不同的fps时,我们就要调整,使得抽取的一串图片经历时间都接近2s。

测试/使用

官方给出了一个手势识别的demo,想要成功运行的话,可以参考我前面给出的作者的博客,亲测有效。
不过更多的,我们想将自己的视频分类任务进行测试,而这方面的参考代码比较少。经过了之前的训练,其实我们需要的就是对读入的视频进行抽帧采样,将图片放入dataset中,经模型输出一个分类向量,将分类向量对应的种类名称写在视频流上显示就行.所以关键其实就是:对视频抽帧采样;初始化模型并加载训练参数;采样图片转成model能接受的输入格式(效果等同于TSN_DATASET)以下是本人针对打架检测的使用代码:

import os
import time
from ops.models import TSN
from ops.transforms import *
import cv2
from PIL import Image

arch = 'mobilenetv2'
num_class = 2
num_segments = 8
modality = 'RGB'
base_model = 'mobilenetv2'
consensus_type='avg'
dataset = 'ucf101'
dropout = 0.1
img_feature_dim = 256
no_partialbn = True
pretrain = 'imagenet'
shift = True
shift_div = 8
shift_place = 'blockres'
temporal_pool = False
non_local = False
tune_from = None


#load model
model = TSN(num_class, num_segments, modality,
                base_model=arch,
                consensus_type=consensus_type,
                dropout=dropout,
                img_feature_dim=img_feature_dim,
                partial_bn=not no_partialbn,
                pretrain=pretrain,
                is_shift=shift, shift_div=shift_div, shift_place=shift_place,
                fc_lr5=not (tune_from and dataset in tune_from),
                temporal_pool=temporal_pool,
                non_local=non_local)

model = torch.nn.DataParallel(model, device_ids=None).cuda()
resume = '/home/hyj/桌面/master_projects/temporal-shift-module-master/best_weights/mobilenet_360_93.916/ckpt.pth.tar' #  the last weights
checkpoint = torch.load(resume)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

#how to deal with the pictures
input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]
normalize = GroupNormalize(input_mean, input_std)
transform_hyj = torchvision.transforms.Compose([
    GroupScale_hyj(input_size=320),
    Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
    ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
    normalize,
])

video_path = '/home/hyj/桌面/master_projects/temporal-shift-module-master/test_videos/fight5.mp4'

pil_img_list = list()

cls_text = ['nofight','fight']
cls_color = [(0,255,0),(0,0,255)]

import time

cap = cv2.VideoCapture(video_path) #导入的视频所在路径
start_time = time.time()
counter = 0
frame_numbers = 0
training_fps = 30
training_time = 2.5
fps = cap.get(cv2.CAP_PROP_FPS) #视频平均帧率
if fps < 1:
    fps = 30
duaring = int(fps * training_time / num_segments)
print(duaring)
# exit()


state = 0
while cap.isOpened():
    ret, frame = cap.read()
    if ret:
        frame_numbers+=1
        print(frame_numbers)
        # print(len(pil_img_list))
        if frame_numbers%duaring == 0 and len(pil_img_list)<8:
            frame_pil = Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB))
            pil_img_list.extend([frame_pil])
        if frame_numbers%duaring == 0 and  len(pil_img_list)==8:
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            pil_img_list.pop(0)
            pil_img_list.extend([frame_pil])
            input = transform_hyj(pil_img_list)
            input = input.unsqueeze(0).cuda()
            out = model(input)
            print(out)
            output_index = int(torch.argmax(out).cpu())
            state = output_index

        #键盘输入空格暂停,输入q退出
        key = cv2.waitKey(1) & 0xff
        if key == ord(" "):
            cv2.waitKey(0)
        if key == ord("q"):
            break
        counter += 1#计算帧数
        if (time.time() - start_time) != 0:#实时显示帧数
            cv2.putText(frame, "{0} {1}".format((cls_text[state]),float('%.1f' % (counter / (time.time() - start_time)))), (50, 50),cv2.FONT_HERSHEY_SIMPLEX, 2, cls_color[state],3)
            cv2.imshow('frame', frame)

            counter = 0
            start_time = time.time()
        time.sleep(1 / fps)#按原帧率播放
        # time.sleep(2/fps)# observe the output
    else:
        break

cap.release()
cv2.destroyAllWindows()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121

打架识别运行效果(我的基准时间为2.5s,即行为进行2.5s后判断其分类,所以从视觉上会感觉到一定的延迟,ubuntu录视频软件不好找,直接手机录了)
在这里插入图片描述
最后提醒一下,当你的分类数目少于5时,需要将main.py中top5的代码去掉,不然会报错。欢迎各位讨论与建议。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/985338
推荐阅读
相关标签
  

闽ICP备14008679号