当前位置:   article > 正文

temporal shift module(TSM)

temporal shift module

【官方】Paddle2.1实现视频理解经典模型 — TSM - 飞桨AI Studio本项目将带大家深入理解视频理解领域经典模型TSM。从模型理论讲解入手,深入到代码实践。实践部分基于TSM模型在UCF101数据集上从训练到推理全流程实现行为识别任务。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/2310889?channelType=0&channel=0视频理解:基于TSM实现UCF101视频理解 - 飞桨AI Studio基于飞桨开源框架构建TSM,并实现对数据集UCF101的视频理解。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/4114499?channelType=0&channel=0

最近一直在做视频相关的项目,后续会陆续出一些视频理解和视频场景运动的案例,视频这块主推paddlevideo,里面应用层面的东西很丰富,paddle在应用侧一直做的比较好,模型训练这块可以结合mmaction2来,其实从实际应用角度来说,我觉得用paddle和pytorch训练都无所谓,部署的话可能以往我的经验更多是onnx,tensort服务侧的,目前来看,主要也就是服务器,端侧和页面侧的部署这三块,我看paddle分别有paddle inference、lite、js,国产框架中确实是首屈一指的,但是我自己的感觉是从我以前训练gan的结果看,paddle貌似要比pytorch的结果,一样的数据,一样的参数配置,好像要差一点。本文主要介绍一下tsm模块,利用2dcnn来模拟时序信息。视频中核心是视频动作识别,本质就是视频分类,可以用作特征提取,视频时序提取是输入一段长视频获取其中的时序片段,时空定位是同时获取视频中的人物物体的空间位置,核心三大任务,除此之外视频特征提取embedding,这块主要是结合多模态去做,视频,音频和文本侧特征的综合利用和提取。

1.时序信息维度 

上述这个视频序列从左向右播放和从右向左播放表达的意思是不同的,视频理解对视频顺序是强依赖的。

2.temporal shift module

这个模块是核心,其实tsm是可插拔模块,是可以很好的嵌入到resnet等模型中,上述图中,一种颜色是一帧,按照时序T上,一共是四帧,同一帧横向是一个channel,在cnn中channel是统一做cnn的,在a图中是没有shift的,在b中是离线shift操作,可见将channel中第一个向下移动,第二个向上移动,其实至于上下移动几个channel并没有很严的的限制,通常是分成几等分去移动,这样上下移动之后,则第一个channel会向下突出一帧,第二个channel会向上突出一帧,突出帧直接截断,空缺帧直接补0,这样在横向做cnn时,统一channel维度变引入不同色的帧,tsm正是通过这种平移的方式,TSM在特征图中引入 temporal 维度上的上下文交互,通过通道移动操作可以使得在当前帧中包含了前后两帧的通道信息,这样再进2D卷积操作就能像3D卷积一样直接提取视频的时空信息,提高了模型在时间维度上的建模能力。而online模式用于对视频类型的实时预测,在这种情况下,无法预知下一秒的图像,因此只能将channel维度由过去向现在移动,而不能从未来向现在移动。

3.缺点和改进

虽然时间位移的原理很简单,但作者发现直接将空间位移策略应用于时间维度并不能提供高性能和效率。具体来说,如果简单的转移所有通道,则会带来两个问题:

  1. 由于大量数据移动而导致的效率下降问题。位移操作不需要计算但是会涉及数据移动,数据移动增加了硬件上的内存占用和推理延迟,作者观察到在视频理解网络中,当使用naive shift策略时,CPU延迟增加13.7%,GPU延迟增加12.4%,使整体推理变慢。
  2. 空间建模能力变差导致性能下降,由于部分通道被转移到相邻帧,当前帧不能再访问通道中包含的信息,这可能会损失2D CNN主干的空间建模能力。与TSN基线相比,使用naive shift会降低2.6%的准确率。

为了解决naive shift的两个问题,TSM给出了相应的解决方法。

  1. 减少数据移动。 为了研究数据移动的影响,作者测量了TSM模型在不同硬件设备上的推理延迟,作者移动了不同比例的通道数并测量了延迟,位移方式分为无位移、部分位移(位移1/8、1/4、1/2的通道)和全部位移,使用ResNet-50主干和8帧输入测量模型。作者观察到,如果移动所有的通道,那么延迟开销将占CPU推理时间的13.7%,如果只移动一小部分通道,如1/8,则可将开销限制在3%左右。       
  2. 保持空间特征学习能力。 一种简单的TSM使用方法是将其直接插入到每个卷基层或残差模块前,如 所示,这种方法被称为 in-place shift,但是它会损失主干模型的空间特征学习能力,尤其当我们移动大量通道时,存储在通道中的当前帧信息会随着通道移动而丢失。为解决这个问题,作者提出了另一种方法,即将TSM放在残差模块的残差分支中,这种方法被称为 residual TSM,如所示,它可以解决退化的空间特征学习问题,因为原始的激活信息在时间转移后仍可通过identity映射访问。

 4.mmaction2中的代码

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from mmcv.cnn import NonLocal3d
  5. from torch.nn.modules.utils import _ntuple
  6. from ..builder import BACKBONES
  7. from .resnet import ResNet
  8. class NL3DWrapper(nn.Module):
  9. """3D Non-local wrapper for ResNet50.
  10. Wrap ResNet layers with 3D NonLocal modules.
  11. Args:
  12. block (nn.Module): Residual blocks to be built.
  13. num_segments (int): Number of frame segments.
  14. non_local_cfg (dict): Config for non-local layers. Default: ``dict()``.
  15. """
  16. def __init__(self, block, num_segments, non_local_cfg=dict()):
  17. super(NL3DWrapper, self).__init__()
  18. self.block = block
  19. self.non_local_cfg = non_local_cfg
  20. self.non_local_block = NonLocal3d(self.block.conv3.norm.num_features,
  21. **self.non_local_cfg)
  22. self.num_segments = num_segments
  23. def forward(self, x):
  24. x = self.block(x)
  25. n, c, h, w = x.size()
  26. x = x.view(n // self.num_segments, self.num_segments, c, h,
  27. w).transpose(1, 2).contiguous()
  28. x = self.non_local_block(x)
  29. x = x.transpose(1, 2).contiguous().view(n, c, h, w)
  30. return x
  31. class TemporalShift(nn.Module):
  32. """Temporal shift module.
  33. This module is proposed in
  34. `TSM: Temporal Shift Module for Efficient Video Understanding
  35. <https://arxiv.org/abs/1811.08383>`_
  36. Args:
  37. net (nn.module): Module to make temporal shift.
  38. num_segments (int): Number of frame segments. Default: 3.
  39. shift_div (int): Number of divisions for shift. Default: 8.
  40. """
  41. def __init__(self, net, num_segments=3, shift_div=8):
  42. super().__init__()
  43. self.net = net
  44. self.num_segments = num_segments
  45. self.shift_div = shift_div
  46. def forward(self, x):
  47. """Defines the computation performed at every call.
  48. Args:
  49. x (torch.Tensor): The input data.
  50. Returns:
  51. torch.Tensor: The output of the module.
  52. """
  53. x = self.shift(x, self.num_segments, shift_div=self.shift_div)
  54. return self.net(x)
  55. @staticmethod
  56. def shift(x, num_segments, shift_div=3):
  57. """Perform temporal shift operation on the feature.
  58. Args:
  59. x (torch.Tensor): The input feature to be shifted.
  60. num_segments (int): Number of frame segments.
  61. shift_div (int): Number of divisions for shift. Default: 3.
  62. Returns:
  63. torch.Tensor: The shifted feature.
  64. """
  65. # 假设当前feature map的通道是256,shift_div=3,
  66. # 那么就有256/3的特征进行shift left,256/3的特征进行shift right,其他一部分特征不动
  67. # num_segments每个视频采样的帧数
  68. # 每帧有c个通道,
  69. # [
  70. # [0_1,0_2,0_3,1_1,1_2,3_5,3_6,3_7] 第一帧,8个通道,但是shift_div表示这个通道维度被切分成3个等分
  71. # [] 第二帧
  72. # [] 第三帧
  73. # ]
  74. # [N, C, H, W]
  75. n, c, h, w = x.size()
  76. # [N // num_segments, num_segments, C, H*W]
  77. # can't use 5 dimensional array on PPL2D backend for caffe
  78. x = x.view(-1, num_segments, c, h * w)
  79. # get shift fold
  80. fold = c // shift_div
  81. # split c channel into three parts:
  82. # left_split, mid_split, right_split
  83. left_split = x[:, :, :fold, :]
  84. mid_split = x[:, :, fold:2 * fold, :]
  85. right_split = x[:, :, 2 * fold:, :]
  86. # can't use torch.zeros(*A.shape) or torch.zeros_like(A)
  87. # because array on caffe inference must be got by computing
  88. # shift left on num_segments channel in `left_split`
  89. zeros = left_split - left_split
  90. blank = zeros[:, :1, :, :]
  91. left_split = left_split[:, 1:, :, :]
  92. left_split = torch.cat((left_split, blank), 1)
  93. # shift right on num_segments channel in `mid_split`
  94. zeros = mid_split - mid_split
  95. blank = zeros[:, :1, :, :]
  96. mid_split = mid_split[:, :-1, :, :]
  97. mid_split = torch.cat((blank, mid_split), 1)
  98. # right_split: no shift
  99. # concatenate
  100. out = torch.cat((left_split, mid_split, right_split), 2)
  101. # [N, C, H, W]
  102. # restore the original dimension
  103. return out.view(n, c, h, w)
  104. @BACKBONES.register_module()
  105. class ResNetTSM(ResNet):
  106. """ResNet backbone for TSM.
  107. Args:
  108. num_segments (int): Number of frame segments. Default: 8.
  109. is_shift (bool): Whether to make temporal shift in reset layers.
  110. Default: True.
  111. non_local (Sequence[int]): Determine whether to apply non-local module
  112. in the corresponding block of each stages. Default: (0, 0, 0, 0).
  113. non_local_cfg (dict): Config for non-local module. Default: ``dict()``.
  114. shift_div (int): Number of div for shift. Default: 8.
  115. shift_place (str): Places in resnet layers for shift, which is chosen
  116. from ['block', 'blockres'].
  117. If set to 'block', it will apply temporal shift to all child blocks
  118. in each resnet layer.
  119. If set to 'blockres', it will apply temporal shift to each `conv1`
  120. layer of all child blocks in each resnet layer.
  121. Default: 'blockres'.
  122. temporal_pool (bool): Whether to add temporal pooling. Default: False.
  123. **kwargs (keyword arguments, optional): Arguments for ResNet.
  124. """
  125. def __init__(self,
  126. depth,
  127. num_segments=8,
  128. is_shift=True,
  129. non_local=(0, 0, 0, 0),
  130. non_local_cfg=dict(),
  131. shift_div=8,
  132. shift_place='blockres',
  133. temporal_pool=False,
  134. **kwargs):
  135. super().__init__(depth, **kwargs)
  136. self.num_segments = num_segments
  137. self.is_shift = is_shift
  138. self.shift_div = shift_div
  139. self.shift_place = shift_place
  140. self.temporal_pool = temporal_pool
  141. self.non_local = non_local
  142. self.non_local_stages = _ntuple(self.num_stages)(non_local)
  143. self.non_local_cfg = non_local_cfg
  144. def make_temporal_shift(self):
  145. """Make temporal shift for some layers."""
  146. if self.temporal_pool:
  147. num_segment_list = [
  148. self.num_segments, self.num_segments // 2,
  149. self.num_segments // 2, self.num_segments // 2
  150. ]
  151. else:
  152. num_segment_list = [self.num_segments] * 4
  153. if num_segment_list[-1] <= 0:
  154. raise ValueError('num_segment_list[-1] must be positive')
  155. if self.shift_place == 'block':
  156. def make_block_temporal(stage, num_segments):
  157. """Make temporal shift on some blocks.
  158. Args:
  159. stage (nn.Module): Model layers to be shifted.
  160. num_segments (int): Number of frame segments.
  161. Returns:
  162. nn.Module: The shifted blocks.
  163. """
  164. blocks = list(stage.children())
  165. for i, b in enumerate(blocks):
  166. blocks[i] = TemporalShift(
  167. b, num_segments=num_segments, shift_div=self.shift_div)
  168. return nn.Sequential(*blocks)
  169. self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
  170. self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
  171. self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
  172. self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
  173. elif 'blockres' in self.shift_place:
  174. n_round = 1
  175. if len(list(self.layer3.children())) >= 23:
  176. n_round = 2
  177. def make_block_temporal(stage, num_segments):
  178. """Make temporal shift on some blocks.
  179. Args:
  180. stage (nn.Module): Model layers to be shifted.
  181. num_segments (int): Number of frame segments.
  182. Returns:
  183. nn.Module: The shifted blocks.
  184. """
  185. blocks = list(stage.children())
  186. for i, b in enumerate(blocks):
  187. if i % n_round == 0:
  188. blocks[i].conv1.conv = TemporalShift(
  189. b.conv1.conv,
  190. num_segments=num_segments,
  191. shift_div=self.shift_div)
  192. return nn.Sequential(*blocks)
  193. self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
  194. self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
  195. self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
  196. self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
  197. else:
  198. raise NotImplementedError
  199. def make_temporal_pool(self):
  200. """Make temporal pooling between layer1 and layer2, using a 3D max
  201. pooling layer."""
  202. class TemporalPool(nn.Module):
  203. """Temporal pool module.
  204. Wrap layer2 in ResNet50 with a 3D max pooling layer.
  205. Args:
  206. net (nn.Module): Module to make temporal pool.
  207. num_segments (int): Number of frame segments.
  208. """
  209. def __init__(self, net, num_segments):
  210. super().__init__()
  211. self.net = net
  212. self.num_segments = num_segments
  213. self.max_pool3d = nn.MaxPool3d(
  214. kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  215. def forward(self, x):
  216. # [N, C, H, W]
  217. n, c, h, w = x.size()
  218. # [N // num_segments, C, num_segments, H, W]
  219. x = x.view(n // self.num_segments, self.num_segments, c, h,
  220. w).transpose(1, 2)
  221. # [N // num_segmnets, C, num_segments // 2, H, W]
  222. x = self.max_pool3d(x)
  223. # [N // 2, C, H, W]
  224. x = x.transpose(1, 2).contiguous().view(n // 2, c, h, w)
  225. return self.net(x)
  226. self.layer2 = TemporalPool(self.layer2, self.num_segments)
  227. def make_non_local(self):
  228. # This part is for ResNet50
  229. for i in range(self.num_stages):
  230. non_local_stage = self.non_local_stages[i]
  231. if sum(non_local_stage) == 0:
  232. continue
  233. layer_name = f'layer{i + 1}'
  234. res_layer = getattr(self, layer_name)
  235. for idx, non_local in enumerate(non_local_stage):
  236. if non_local:
  237. res_layer[idx] = NL3DWrapper(res_layer[idx],
  238. self.num_segments,
  239. self.non_local_cfg)
  240. def init_weights(self):
  241. """Initiate the parameters either from existing checkpoint or from
  242. scratch."""
  243. super().init_weights()
  244. if self.is_shift:
  245. self.make_temporal_shift()
  246. if len(self.non_local_cfg) != 0:
  247. self.make_non_local()
  248. if self.temporal_pool:
  249. self.make_temporal_pool()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/985386
推荐阅读
相关标签
  

闽ICP备14008679号