
Understanding Non-Local Networks and Their Implementation in MindSpore

For a runnable example, see the Non-Local Notebook.

I. The Non-Local Network Algorithm


The word "non-local" is defined in opposition to "local", so let us start there. In the paper Non-local Neural Networks (Xiaolong Wang et al., CVPR 2018), "local" refers to the receptive field. In a convolution, for instance, the receptive field is exactly the kernel size, and kernels are usually small (3×3, 5×5, and so on), so each output position only perceives a small local region; hence the name "local". A non-local operation, by contrast, has a receptive field that covers far more than a local neighborhood.

The non-local operation in the paper computes the response at a position as a weighted sum of the features at all positions of the feature map, where "positions" can index space, time, or spacetime. Non-local is therefore closely related to the self-attention mechanism. So that the proposed non-local block can be freely plugged into arbitrary networks as a component, the authors design the operation to keep the output size identical to the input size. The generic formula is:

$$y_i = \frac{1}{\mathcal{C}(x)} \sum_{\forall j} f(x_i, x_j)\, g(x_j)$$

In this formula, x is the input and y the output; i and j index positions of the input, and x_i is a vector whose dimension equals the channel count of x. f is a function computing the similarity between any two positions, g is an embedding that maps a position to a feature vector, and C(x) is a normalization factor. Computing one output position requires visiting every input position, exactly as in attention: f supplies the mask (the attention weights), which is multiplied with the features given by g and summed, producing that output position's attention over the whole input. Carrying this out for every position yields a non-local "attention map".

Figure 1

In Figure 1, θ and φ belong to the function f, and g is the function g. In the paper, g is implemented as a 1×1×1 convolution. For f, four similarity functions are available: Gaussian, Embedded Gaussian, Dot Product, and Concatenation.

The Gaussian function:

$$f(x_i, x_j) = e^{x_i^{T} x_j}$$

The Embedded Gaussian function:

$$f(x_i, x_j) = e^{\theta(x_i)^{T} \phi(x_j)}$$

The Dot Product function:

$$f(x_i, x_j) = \theta(x_i)^{T} \phi(x_j)$$

The Concatenation function:

$$f(x_i, x_j) = \mathrm{ReLU}\!\left(w_f^{T}\,[\theta(x_i), \phi(x_j)]\right)$$

For the two Gaussian variants, the normalization factor $\mathcal{C}(x)$ is the softmax denominator $\sum_{\forall j} f(x_i, x_j)$; for Dot Product and Concatenation it is simply the number of positions $N$.
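
To make the computation concrete, here is a minimal NumPy sketch of the Embedded Gaussian variant on a flattened feature map. It is an illustration only: theta_w, phi_w, and g_w stand in for the 1×1×1 convolutions of the real block, and all names here are hypothetical.

import numpy as np

def embedded_gaussian_nonlocal(x, theta_w, phi_w, g_w):
    # x: (N, C) features for N positions; *_w: (C, C_inter) projection matrices
    theta = x @ theta_w                       # (N, C_inter)
    phi = x @ phi_w                           # (N, C_inter)
    g = x @ g_w                               # (N, C_inter)
    f = theta @ phi.T                         # (N, N) pairwise similarities
    f = np.exp(f - f.max(axis=-1, keepdims=True))
    attn = f / f.sum(axis=-1, keepdims=True)  # softmax over j: the C(x) of the paper
    return attn @ g                           # weighted sum over all positions

rng = np.random.default_rng(0)
x = rng.standard_normal((8 * 20 * 20, 32)).astype(np.float32)  # flattened D*H*W positions
w = [rng.standard_normal((32, 16)).astype(np.float32) for _ in range(3)]
print(embedded_gaussian_nonlocal(x, *w).shape)  # (3200, 16); a final 1x1x1 conv maps back to 32 channels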

That concludes the overview of the non-local algorithm.

II. Implementation in MindSpore

1. NonLocalBlockND(nn.Cell)

The core structure of non-local is NonLocalBlockND(nn.Cell), which implements the four pairwise similarity functions above. Taking dot_product as an example, the linear transforms are carried out by three Conv3d layers. NonLocalBlockND only requires common operators such as convolution, matrix multiplication, addition, and softmax, so users can assemble it into a model very conveniently.

import mindspore
from mindspore import nn, ops

# MaxPool3D below refers to the 3D max-pooling cell shipped with the
# mindvision video package; ops.MaxPool3D provides the same operation.

class NonLocalBlockND(nn.Cell):
    r"""
    Classification backbone for nonlocal.
    Implementation of Non-Local Block with 4 different pairwise functions.
    Applies Non-Local Block over 5D input (a mini-batch of 3D inputs with
    additional channel dimension).

    .. math::
        embedded\_gaussian: f(x_i, x_j) = e^{\theta(x_i)^{T} \phi(x_j)}
        gaussian: f(x_i, x_j) = e^{x_i^{T} x_j}
        concatenation: f(x_i, x_j) = \text{ReLU}(w_f^{T}[\theta(x_i), \phi(x_j)])
        dot\_product: f(x_i, x_j) = \theta(x_i)^{T} \phi(x_j)

    Args:
        in_channels (int): Original channel size.
        inter_channels (int): Channel size inside the block; if not specified,
            reduced to half of in_channels.
        mode (str): One of 'gaussian', 'embedded', 'dot' or 'concatenation'.
        sub_sample (bool): Whether to sub-sample g and phi with max pooling.
        bn_layer (bool): Whether to add batch norm after the output projection.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of the same shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Examples:
        >>> net = NonLocalBlockND(in_channels=3)
        >>> x = ops.Zeros()((2, 3, 8, 20, 20), mindspore.float32)
        >>> output = net(x).shape
        >>> print(output)
        (2, 3, 8, 20, 20)
    """

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 mode='embedded',
                 sub_sample=True,
                 bn_layer=True):
        super(NonLocalBlockND, self).__init__()
        if mode not in ['gaussian', 'embedded', 'dot', 'concatenation']:
            raise ValueError(
                '`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenation`')
        self.mode = mode
        self.transpose = ops.Transpose()
        self.batmatmul = ops.BatchMatMul()
        self.tile = ops.Tile()
        self.concat_op = ops.Concat(1)
        self.zeros = ops.Zeros()
        self.softmax = ops.Softmax(axis=-1)
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # reduce the inner channel size to half of the input by default
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        self.g = nn.Conv3d(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           has_bias=True)
        if bn_layer:
            self.w = nn.SequentialCell(
                nn.Conv3d(in_channels=self.inter_channels,
                          out_channels=self.in_channels,
                          kernel_size=1),
                nn.BatchNorm3d(self.in_channels))
        else:
            self.w = nn.Conv3d(in_channels=self.inter_channels,
                               out_channels=self.in_channels,
                               kernel_size=1)
        if self.mode in ["embedded", "dot", "concatenation"]:
            self.theta = nn.Conv3d(in_channels=self.in_channels,
                                   out_channels=self.inter_channels,
                                   kernel_size=1,
                                   has_bias=True)
            self.phi = nn.Conv3d(in_channels=self.in_channels,
                                 out_channels=self.inter_channels,
                                 kernel_size=1,
                                 has_bias=True)
        if self.mode == "concatenation":
            self.concat_project = nn.SequentialCell(
                nn.Conv2d(self.inter_channels * 2,
                          out_channels=1,
                          kernel_size=1,
                          pad_mode='same',
                          has_bias=False),
                nn.ReLU())
        if sub_sample:
            # sub-sample phi and g on the spatial dimensions only
            max_pool_layer = MaxPool3D(kernel_size=(1, 2, 2), strides=(1, 2, 2))
            self.g = nn.SequentialCell(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.SequentialCell(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer

    def construct(self, x):
        """NonLocalBlockND construct."""
        batch_size = x.shape[0]
        # g(x): (N, C_inter, D*H*W) -> (N, D*H*W, C_inter)
        g_x = self.g(x).view((batch_size, self.inter_channels, -1))
        input_perm = (0, 2, 1)
        g_x = self.transpose(g_x, input_perm)
        f = self.zeros((1, 1, 1), mindspore.float32)
        if self.mode == "gaussian":
            theta_x = x.view((batch_size, self.in_channels, -1))
            theta_x = self.transpose(theta_x, input_perm)
            phi_x = x.view((batch_size, self.in_channels, -1))
            f = self.batmatmul(theta_x, phi_x)
        elif self.mode in ["embedded", "dot"]:
            theta_x = self.theta(x).view((batch_size, self.inter_channels, -1))
            theta_x = self.transpose(theta_x, input_perm)
            phi_x = self.phi(x).view((batch_size, self.inter_channels, -1))
            f = self.batmatmul(theta_x, phi_x)
        elif self.mode == "concatenation":
            # broadcast theta over columns and phi over rows, then concatenate
            theta_x = self.theta(x).view((batch_size, self.inter_channels, -1, 1))
            phi_x = self.phi(x).view((batch_size, self.inter_channels, 1, -1))
            h = theta_x.shape[2]
            w = phi_x.shape[3]
            theta_x = self.tile(theta_x, (1, 1, 1, w))
            phi_x = self.tile(phi_x, (1, 1, h, 1))
            concat_feature = self.concat_op((theta_x, phi_x))
            f = self.concat_project(concat_feature)
            b, _, h, w = f.shape
            f = f.view((b, h, w))
        f_div_c = self.zeros((1, 1, 1), mindspore.float32)
        if self.mode in ["gaussian", "embedded"]:
            f_div_c = self.softmax(f)
        elif self.mode in ["dot", "concatenation"]:
            n = f.shape[-1]
            f_div_c = f / n
        # weighted sum over all positions, then restore the 5D shape
        y = self.batmatmul(f_div_c, g_x)
        y = self.transpose(y, input_perm)
        y = y.view((batch_size,
                    self.inter_channels,
                    x.shape[2],
                    x.shape[3],
                    x.shape[4]))
        w_y = self.w(y)
        # residual connection keeps input and output shapes identical
        z = x + w_y
        return z
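
As a quick sanity check, the block can be instantiated in each of the four modes and verified to be shape-preserving thanks to the residual connection. This is a minimal sketch, assuming NonLocalBlockND and its pooling dependency are defined as above; the input sizes are arbitrary:

import numpy as np
import mindspore as ms

x = ms.Tensor(np.random.randn(2, 16, 2, 8, 8).astype(np.float32))
for mode in ['gaussian', 'embedded', 'dot', 'concatenation']:
    block = NonLocalBlockND(in_channels=16, mode=mode)
    print(mode, block(x).shape)  # every mode returns (2, 16, 2, 8, 8)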

2. Nonlocal3d

Nonlocal3d is composed of a backbone, an average-pooling layer, a flatten, and a classification head. It breaks down roughly as follows.

Part one: the backbone is NLResInflate3D50 (an instance of the NLInflateResNet3D class), which realizes the [3, 4, 6, 3] stage configuration of NLInflateResNet3D. NLInflateResNet3D in turn inherits from the ResNet3d50 structure: among the 10 blocks of the second and third stages of ResNet3d50's [3, 4, 6, 3] layout, a NonLocalBlockND is inserted after every other block (see the sketch after this list).

Part two: the output of NLResInflate3D50 is fed through an average pooling layer and flattened.

Part three: the classification head. The flattened tensor is fed into a DropoutDense head for classification, producing a tensor of shape (N, NUM_CLASSES).
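
The insertion pattern can be read directly off the non_local flags passed to NLInflateResNet3D below, where a 1 means "append a NonLocalBlockND after this residual block". This snippet simply spells out those defaults:

# Stage layout of ResNet3d50: [3, 4, 6, 3] residual blocks per stage.
# 1 = follow this block with a NonLocalBlockND, 0 = plain block.
non_local = ((0, 0, 0),           # stage 1: no non-local blocks
             (0, 1, 0, 1),        # stage 2: 2 of 4 blocks
             (0, 1, 0, 1, 0, 1),  # stage 3: 3 of 6 blocks
             (0, 0, 0))           # stage 4: no non-local blocks
# In total, 5 NonLocalBlockND modules are inserted into the 10 blocks of stages 2 and 3.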

First, the implementation of NLResInflate3D50 (the NLInflateResNet3D class):

from typing import Optional, Tuple

from mindspore import nn, ops

# ResidualBlock3D, Inflate3D, ResNet3D, Unit3D and Maxpool3DwithPad are
# building blocks from the mindvision video package.

class NLInflateBlock3D(ResidualBlock3D):
    """
    ResNet3D residual block definition with an optional non-local block.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        conv12 (nn.Cell, optional): Module for the first two convolution layers. Default: Inflate3D.
        group (int): Number of group convolutions. Default: 1.
        base_width (int): Width per group. Default: 64.
        norm (nn.Cell, optional): Module specifying the normalization layer to use. Default: None.
        down_sample (nn.Cell, optional): Downsample structure. Default: None.
        non_local (bool): Whether to append a non-local block. Default: False.
        non_local_mode (str): Pairwise function of the non-local block. Default: 'dot'.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> NLInflateBlock3D(3, 256)
    """
    expansion: int = 4

    def __init__(self,
                 in_channel: int,
                 out_channel: int,
                 conv12: Optional[nn.Cell] = Inflate3D,
                 group: int = 1,
                 base_width: int = 64,
                 norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None,
                 non_local: bool = False,
                 non_local_mode: str = 'dot',
                 **kwargs
                 ) -> None:
        super(NLInflateBlock3D, self).__init__(in_channel=in_channel,
                                               out_channel=out_channel,
                                               mid_channel=out_channel,
                                               conv12=conv12,
                                               group=group,
                                               norm=norm,
                                               activation=[nn.ReLU, nn.ReLU],
                                               down_sample=down_sample,
                                               **kwargs)
        # conv3d doesn't support group > 1 as of MindSpore 1.6.1
        out_channel = int(out_channel * (base_width / 64.0)) * group
        self.non_local = non_local
        if self.non_local:
            in_channels = out_channel * self.expansion
            self.non_local_block = NonLocalBlockND(
                in_channels, mode=non_local_mode)

    def construct(self, x):
        """NLInflateBlock3D construct."""
        identity = x
        out = self.conv12(x)
        out = self.conv3(out)
        if self.down_sample:
            identity = self.down_sample(x)
        out += identity
        out = self.relu(out)
        if self.non_local:
            out = self.non_local_block(out)
        return out


class NLInflateResNet3D(ResNet3D):
    """Inflate3D with ResNet3D backbone and non-local blocks.

    Args:
        block (Optional[nn.Cell]): The block for the network.
        layer_nums (list): The numbers of blocks in the different stages.
        norm (nn.Cell, optional): The module specifying the normalization layer to use. Default: None.
        stage_strides: Stride sizes for the ResNet3D convolutional layers.
        non_local: Flags determining where to insert non-local blocks, per stage.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Returns:
        Tensor, output tensor.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> net = NLInflateResNet3D(NLInflateBlock3D, [3, 4, 6, 3])
        >>> x = ms.Tensor(np.ones([1, 3, 32, 224, 224]), ms.float32)
        >>> output = net(x)
        >>> print(output.shape)
        (1, 2048, 16, 7, 7)
    """

    def __init__(self,
                 block: Optional[nn.Cell],
                 layer_nums: Tuple[int],
                 stage_channels: Tuple[int] = (64, 128, 256, 512),
                 stage_strides: Tuple[int] = ((1, 1, 1),
                                              (1, 2, 2),
                                              (1, 2, 2),
                                              (1, 2, 2)),
                 down_sample: Optional[nn.Cell] = Unit3D,
                 inflate: Tuple[Tuple[int]] = ((1, 1, 1),
                                               (1, 0, 1, 0),
                                               (1, 0, 1, 0, 1, 0),
                                               (0, 1, 0)),
                 non_local: Tuple[Tuple[int]] = ((0, 0, 0),
                                                 (0, 1, 0, 1),
                                                 (0, 1, 0, 1, 0, 1),
                                                 (0, 0, 0)),
                 **kwargs
                 ):
        super(NLInflateResNet3D, self).__init__(block=block,
                                                layer_nums=layer_nums,
                                                stage_channels=stage_channels,
                                                stage_strides=stage_strides,
                                                down_sample=down_sample)
        self.in_channels = stage_channels[0]
        self.conv1 = Unit3D(3, stage_channels[0], kernel_size=(5, 7, 7),
                            stride=(1, 2, 2), norm=self.norm)
        self.maxpool = Maxpool3DwithPad(kernel_size=(1, 3, 3),
                                        padding=(0, 0, 1, 1, 1, 1),
                                        strides=(1, 2, 2))
        self.pool2 = ops.MaxPool3D(kernel_size=(2, 1, 1), strides=(2, 1, 1))
        self.layer1 = self._make_layer(
            block,
            stage_channels[0],
            layer_nums[0],
            stride=tuple(stage_strides[0]),
            norm=self.norm,
            inflate=inflate[0],
            non_local=non_local[0],
            **kwargs)
        self.layer2 = self._make_layer(
            block,
            stage_channels[1],
            layer_nums[1],
            stride=tuple(stage_strides[1]),
            norm=self.norm,
            inflate=inflate[1],
            non_local=non_local[1],
            **kwargs)
        self.layer3 = self._make_layer(
            block,
            stage_channels[2],
            layer_nums[2],
            stride=tuple(stage_strides[2]),
            norm=self.norm,
            inflate=inflate[2],
            non_local=non_local[2],
            **kwargs)
        self.layer4 = self._make_layer(
            block,
            stage_channels[3],
            layer_nums[3],
            stride=tuple(stage_strides[3]),
            norm=self.norm,
            inflate=inflate[3],
            non_local=non_local[3],
            **kwargs)

    def construct(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.pool2(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


class NLResInflate3D50(NLInflateResNet3D):
    """
    ResNet50 variant registered via the registration mechanism; call it
    through a yaml configuration file.
    """

    def __init__(self, **kwargs):
        super(NLResInflate3D50, self).__init__(
            NLInflateBlock3D, [3, 4, 6, 3], **kwargs)

Next, the implementation of nonlocal3d:

import math
from typing import Optional

from mindspore import nn

# NLResInflate3D50 is defined above; AdaptiveAvgPool3D and DropoutDense
# come from the mindvision video package.

class nonlocal3d(nn.Cell):
    """
    nonlocal3d model.

    Xiaolong Wang. "Non-local Neural Networks."
    https://arxiv.org/pdf/1711.07971v3

    Args:
        in_d: Depth of input data; can be considered the frame number of a video. Default: 32.
        in_h: Height of input frames. Default: 224.
        in_w: Width of input frames. Default: 224.
        num_classes (int): Number of classes, i.e. the size of the classification
            score for every sample, :math:`CLASSES_{out}`. Default: 400.
        keep_prob (float): Dropout keep probability for the multi-dense-layer head;
            the number of probabilities equals the number of dense layers.
        backbone: Backbone of nonlocal3d.
        avg_pool: Average pooling layer.
        flatten: Flatten layer.
        head: DropoutDense classification head.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, CLASSES_{out})`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindvision.msvideo.models import nonlocal3d
        >>>
        >>> net = nonlocal3d()
        >>> x = Tensor(np.random.randn(1, 3, 32, 224, 224).astype(np.float32))
        >>> output = net(x)
        >>> print(output.shape)
        (1, 400)
    """

    def __init__(self,
                 in_d: int = 32,
                 in_h: int = 224,
                 in_w: int = 224,
                 num_classes: int = 400,
                 keep_prob: float = 0.5,
                 backbone: Optional[nn.Cell] = NLResInflate3D50,
                 avg_pool: Optional[nn.Cell] = AdaptiveAvgPool3D,
                 flatten: Optional[nn.Cell] = nn.Flatten,
                 head: Optional[nn.Cell] = DropoutDense
                 ):
        super(nonlocal3d, self).__init__()
        last_d = math.ceil(in_d / 32)
        last_h = math.ceil((math.ceil(in_h / 32) + 1) / 4)
        last_w = math.ceil((math.ceil(in_w / 32) + 1) / 4)
        backbone_output_channel = 512 * last_d * last_h * last_w
        self.backbone = backbone()
        self.avg_pool = avg_pool((1, 1, 1))
        self.flatten = flatten()
        self.head = head(input_channel=backbone_output_channel,
                         out_channel=num_classes,
                         keep_prob=keep_prob)

    def construct(self, x):
        x = self.backbone(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x

3. Experimental results

Evaluating accuracy under the MindSpore framework gives the following results, close to those reported in the original paper:

[Start eval `nonlocal`]
eval: 1/19877
eval: 2/19877
eval: 3/19877
eval: 4/19877
eval: 5/19877
eval: 6/19877
eval: 7/19877
eval: 8/19877
eval: 9/19877
eval: 10/19877
...
eval: 19874/19877
eval: 19875/19877
eval: 19876/19877
eval: 19877/19877
{'Top_1_Accuracy': 0.7248, 'Top_5_Accuracy': 0.9072}

III. Code Repository

Readers interested in the Non-Local Network under the MindSpore framework can find it in the following repository:

nonlocal_mindspore
