
An Analysis of the Non-Local Network Module

The non-local module

The non-local block can be inserted into an existing CNN architecture. Its advantage is that it improves overall performance; its drawbacks are the extra parameters and the large amount of computation it adds.

I am writing this article only as a note for my own later review.

In this setup the base CNN is ResNet-50, and the block is inserted after block3. The module's output has exactly the same shape as its input, which is why it can be used in a plug-and-play fashion (a small usage sketch after the code listing below illustrates this).

The input has shape [B, H, W, C]. First, 1x1 convolutions are applied to reduce the number of channels and thus the amount of computation. The resulting [B, H, W, C] tensors are reshaped to [B, HW, C], and two of them are multiplied together (one of them transposed), giving a [B, HW, HW] matrix that measures the correlation between each pixel and every other position. A softmax is then applied to this matrix to emphasize the strongest responses, the softmax output is multiplied with the [B, HW, C] tensor to apply these weights to the input, and the result is finally added back to the original input. This completes the positional attention mechanism.
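Before the full listing, here is a compact sketch of just this tensor algebra. It is my own illustration rather than code from the post: it uses a single batched tf.matmul instead of the per-sample loop used further below, hypothetical sizes (B=8, H=W=14, C=1024, reduced depth 512), the paper's theta/phi/g naming, and it omits the max-pooling on one branch for clarity.

import tensorflow as tf
import tensorflow.contrib.slim as slim

x = tf.placeholder(tf.float32, [8, 14, 14, 1024])              # (B, H, W, C) input feature map
theta = slim.conv2d(x, 512, 1, activation_fn=None)             # 1x1 convs reduce C to 512
phi = slim.conv2d(x, 512, 1, activation_fn=None)
g = slim.conv2d(x, 512, 1, activation_fn=None)

theta = tf.reshape(theta, [8, -1, 512])                        # (B, HW, 512)
phi = tf.reshape(phi, [8, -1, 512])
g = tf.reshape(g, [8, -1, 512])

attn = tf.nn.softmax(tf.matmul(theta, phi, transpose_b=True))  # (B, HW, HW) pairwise similarities
y = tf.matmul(attn, g)                                         # (B, HW, 512) weighted sum over positions
y = tf.reshape(y, [8, 14, 14, 512])
y = slim.conv2d(y, 1024, 1, activation_fn=None)                # back up to the original C channels
out = x + y                                                    # residual connection: same shape as x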

This module lets the model focus its attention on the object it needs to recognize. The detailed code is as follows:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_utils

# `conv2d_same` is assumed to be the helper from slim's resnet_utils;
# for the 1x1 calls below it behaves like a plain slim.conv2d.
conv2d_same = resnet_utils.conv2d_same


def nonlocal_dot(net, depth, embed=True, softmax=False, maxpool=2, scope=None):
    """Implementation of the non-local block in its various forms.

    See "Non-local Neural Networks" by
    Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He
    https://arxiv.org/pdf/1711.07971.pdf

    Args:
      - `net`: The symbolic input into the block, a (B,H,W,C) Tensor.
      - `depth`: The number of channels in which to execute the non-local operation.
      - `embed`: Whether or not to use the "embedded version" as in Sec.3.2.
      - `softmax`: Whether or not to use the softmax operation which makes it
        equivalent to soft-attention.
      - `maxpool`: How large of a max-pooling (Sec.3.3) to use to help reduce
        the computational burden. Default is 2, use `False` for none.
      - `scope`: An optional scope for all created variables.

    Returns:
      The symbolic output of the non-local block operation.

    Note:
      The final BatchNorm's gamma is initialized to zero, so as to make this a
      no-op (skip) at initialization, as described in Sec.4.1.
    """
    with tf.variable_scope(scope, 'nonlocal', values=[net]) as sc:
        with slim.arg_scope([slim.conv2d], normalizer_fn=None):
            if embed:
                # 1x1 convolutions: reduce the input channels to `depth` (e.g. 512).
                a = conv2d_same(net, depth, 1, stride=1, scope='embA')
                b = conv2d_same(net, depth, 1, stride=1, scope='embB')
            else:
                a, b = net, net
            g_orig = g = conv2d_same(net, depth, 1, stride=1, scope='g')

        if maxpool is not False and maxpool > 1:
            b = slim.max_pool2d(b, [maxpool, maxpool], stride=maxpool, scope='pool')
            g = slim.max_pool2d(g, [maxpool, maxpool], stride=maxpool, scope='pool')

        # Flatten from (B,H,W,C) to (B,HW,C) or similar.
        a_flat = tf.reshape(a, [tf.shape(a)[0], -1, tf.shape(a)[-1]])
        b_flat = tf.reshape(b, [tf.shape(b)[0], -1, tf.shape(b)[-1]])
        g_flat = tf.reshape(g, [tf.shape(g)[0], -1, tf.shape(g)[-1]])
        # The batch size is hard-coded to 8 here so that the per-sample
        # gather/matmul unrolling below has fully static shapes.
        a_flat.set_shape([8, a.shape[1] * a.shape[2] if None not in a.shape[1:3] else None, a.shape[-1]])
        b_flat.set_shape([8, b.shape[1] * b.shape[2] if None not in b.shape[1:3] else None, b.shape[-1]])
        g_flat.set_shape([8, g.shape[1] * g.shape[2] if None not in g.shape[1:3] else None, g.shape[-1]])

        # Compute f(a, b) -> (B,HW,HW): the similarity between every position in
        # `a` and every (pooled) position in `b`, unrolled one sample at a time.
        a_flat_new = tf.gather(a_flat, 0)                              # e.g. (49, 512)
        b_flat_new = tf.gather(b_flat, 0)                              # e.g. (16, 512)
        f0 = tf.matmul(a_flat_new, tf.transpose(b_flat_new, [1, 0]))
        f0 = tf.reshape(f0, [-1, f0.shape[0], f0.shape[1]])            # e.g. (1, 49, 16)
        a_flat_new = tf.gather(a_flat, 1)
        b_flat_new = tf.gather(b_flat, 1)
        f1 = tf.matmul(a_flat_new, tf.transpose(b_flat_new, [1, 0]))
        f1 = tf.reshape(f1, [-1, f1.shape[0], f1.shape[1]])
        f = tf.concat([f0, f1], axis=0)
        for i in range(2, 8):                                          # remaining samples of the batch of 8
            a_flat_new = tf.gather(a_flat, i)
            b_flat_new = tf.gather(b_flat, i)
            f0 = tf.matmul(a_flat_new, tf.transpose(b_flat_new, [1, 0]))
            f0 = tf.reshape(f0, [-1, f0.shape[0], f0.shape[1]])
            f = tf.concat([f, f0], axis=0)                             # finally e.g. (8, 49, 16)

        if softmax:
            f = tf.nn.softmax(f)
        else:
            f = f / tf.cast(tf.shape(f)[-1], tf.float32)

        # Compute f * g ("self-attention") -> (B,HW,C), again sample by sample.
        f_flat_new = tf.gather(f, 0)                                   # e.g. (49, 16)
        g_flat_new = tf.gather(g_flat, 0)                              # e.g. (16, 512)
        f0 = tf.matmul(f_flat_new, g_flat_new)
        f0 = tf.reshape(f0, [-1, f0.shape[0], f0.shape[1]])            # e.g. (1, 49, 512)
        f_flat_new = tf.gather(f, 1)
        g_flat_new = tf.gather(g_flat, 1)
        f1 = tf.matmul(f_flat_new, g_flat_new)
        f1 = tf.reshape(f1, [-1, f1.shape[0], f1.shape[1]])
        f_new = tf.concat([f0, f1], axis=0)
        for i in range(2, 8):
            f_flat_new = tf.gather(f, i)
            g_flat_new = tf.gather(g_flat, i)
            f0 = tf.matmul(f_flat_new, g_flat_new)
            f0 = tf.reshape(f0, [-1, f0.shape[0], f0.shape[1]])
            f_new = tf.concat([f_new, f0], axis=0)                     # finally e.g. (8, 49, 512)

        # The unrolled loops above replace the original single batched call:
        # fg = tf.matmul(f, g_flat)
        # Expand back to (B,H,W,depth); g_orig has the un-pooled spatial size.
        fg = tf.reshape(f_new, tf.shape(g_orig))

        # Go back up to the original depth and add residually. The final BatchNorm's
        # gamma is zero-initialized so the block starts out as an identity mapping.
        with slim.arg_scope([slim.batch_norm],
                            param_initializers={'gamma': tf.zeros_initializer()}):
            fg = conv2d_same(fg, net.shape[-1], 1, stride=1, scope='fgup')
        net = net + fg

        return slim.utils.collect_named_outputs(None, sc.name, net)
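As a quick sanity check of the plug-and-play claim above, here is a minimal usage sketch of my own (the 14x14x1024 size is only a hypothetical block3-like feature map; the batch size of 8 matches the hard-coded set_shape calls). Note also that the eight gather/matmul pairs could in principle be collapsed back into the single batched calls tf.matmul(a_flat, b_flat, transpose_b=True) and tf.matmul(f, g_flat), which is what the commented-out line in the code did.

features = tf.placeholder(tf.float32, [8, 14, 14, 1024])   # e.g. a block3-sized feature map
out = nonlocal_dot(features, depth=512, embed=True, softmax=True, maxpool=2)
print(out.shape)                                            # (8, 14, 14, 1024), identical to the input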

 
