当前位置:   article > 正文

【图像分类】用最简短的代码复现SeNet,小白一定要收藏(keras,Tensorflow2.x)_resnet18 senet

resnet18 senet

目录

摘要

一、SENet概述

二、SENet 结构组成详解

三、详细的计算过程

                                                  ​

SENet 在具体网络中应用(代码实现SE_ResNet)

第一个残差模块

第二个残差模块

ResNet18、ResNet34模型的完整代码

ResNet50、ResNet101、ResNet152完整代码


摘要

一、SENet概述

           Squeeze-and-Excitation Networks(简称 SENet)是 Momenta 胡杰团队(WMW)提出的新的网络结构,利用SENet,一举取得最后一届 ImageNet 2017 竞赛 Image Classification 任务的冠军,在ImageNet数据集上将top-5 error降低到2.251%,原先的最好成绩是2.991%。

     作者在文中将SENet block插入到现有的多种分类网络中,都取得了不错的效果。作者的动机是希望显式地建模特征通道之间的相互依赖关系。另外,作者并未引入新的空间维度来进行特征通道间的融合,而是采用了一种全新的「特征重标定」策略。具体来说,就是通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征。

     通俗的来说SENet的核心思想在于通过网络根据loss去学习特征权重,使得有效的feature map权重大,无效或效果小的feature map权重小的方式训练模型达到更好的结果。SE block嵌在原有的一些分类网络中不可避免地增加了一些参数和计算量,但是在效果面前还是可以接受的 。Sequeeze-and-Excitation(SE) block并不是一个完整的网络结构,而是一个子结构,可以嵌到其他分类或检测模型中。

二、SENet 结构组成详解

    上述结构中,Squeeze 和 Excitation 是两个非常关键的操作,下面进行详细说明。

  

    上图是SE 模块的示意图。给定一个输入 x,其特征通道数为 {C}',通过一系列卷积等一般变换后得到一个特征通道数为C 的特征。通过下面的三个操作还重标前面得到的特征:

   1、Squeeze 操作,顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,而且使得靠近输入的层也可以获得全局的感受野,这一点在很多任务中都是非常有用的。

  2、 Excitation 操作,它是一个类似于循环神经网络中门的机制。通过参数 w 来为每个特征通道生成权重,其中参数 w 被学习用来显式地建模特征通道间的相关性。

  3、 Reweight 操作,将 Excitation 的输出的权重看做是进过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上,完成在通道维度上的对原始特征的重标定。

三、详细的计算过程

 首先F_{tr}这一步是转换操作(严格讲并不属于SENet,而是属于原网络,可以看后面SENet和Inception及ResNet网络的结合),在文中就是一个标准的卷积操作而已,输入输出的定义如下表示:

                                       

    那么这个F_{tr}的公式就是下面的公式1(卷积操作,V_{c}表示第c个卷积核,X^{s}表示第s个输入)。

                                                
    F_{tr}得到的U就是Figure1中的左边第二个三维矩阵,也叫tensor,或者叫C个大小为H*W的feature map。而uc表示U中第c个二维矩阵,下标c表示channel。
    接下来就是Squeeze操作,公式非常简单,就是一个global average pooling:
                              
    因此公式2就将H*W*C的输入转换成1*1*C的输出,对应Figure1中的Fsq操作。为什么会有这一步呢?这一步的结果相当于表明该层C个feature map的数值分布情况,或者叫全局信息。
    再接下来就是Excitation操作,如公式3。直接看最后一个等号,前面squeeze得到的结果是z,这里先用W1乘以z,就是一个全连接层操作,W1的维度是C/r * C,这个r是一个缩放参数,在文中取的是16,这个参数的目的是为了减少channel个数从而降低计算量。又因为z的维度是1*1*C,所以W1z的结果就是1*1*C/r;然后再经过一个ReLU层,输出的维度不变;然后再和W2相乘,和W2相乘也是一个全连接层的过程,W2的维度是C*C/r,因此输出的维度就是1*1*C;最后再经过sigmoid函数,得到s:
                                    
    也就是说最后得到的这个s的维度是1*1*C,C表示channel数目。这个s其实是本文的核心,它是用来刻画tensor U中C个feature map的权重。而且这个权重是通过前面这些全连接层和非线性层学习得到的,因此可以end-to-end训练。这两个全连接层的作用就是融合各通道的feature map信息,因为前面的squeeze都是在某个channel的feature map里面操作。
    在得到s之后,就可以对原来的tensor U操作了,就是下面的公式4。也很简单,就是channel-wise multiplication,什么意思呢?u_{c}是一个二维矩阵,s_{c}是一个数,也就是权重,因此相当于把u_{c}矩阵中的每个值都乘以s_{c}。对应Figure1中的Fscale。

                                                  

SENet 在具体网络中应用(代码实现SE_ResNet)

介绍完具体的公式实现,下面介绍下SE block怎么运用到具体的网络之中。


    上图是将 SE 模块嵌入到 Inception 结构的一个示例。方框旁边的维度信息代表该层的输出。

    这里我们使用 global average pooling 作为 Squeeze 操作。紧接着两个 Fully Connected 层组成一个 Bottleneck 结构去建模通道间的相关性,并输出和输入特征同样数目的权重。我们首先将特征维度降低到输入的 1/16,然后经过 ReLu 激活后再通过一个 Fully Connected 层升回到原来的维度。这样做比直接用一个 Fully Connected 层的好处在于:

    1)具有更多的非线性,可以更好地拟合通道间复杂的相关性;

    2)极大地减少了参数量和计算量。然后通过一个 Sigmoid 的门获得 0~1 之间归一化的权重,最后通过一个 Scale 的操作来将归一化后的权重加权到每个通道的特征上。

    除此之外,SE 模块还可以嵌入到含有 skip-connections 的模块中。上右图是将 SE 嵌入到 ResNet 模块中的一个例子,操作过程基本和 SE-Inception 一样,只不过是在 Addition 前对分支上 Residual 的特征进行了特征重标定。如果对 Addition 后主支上的特征进行重标定,由于在主干上存在 0~1 的 scale 操作,在网络较深 BP 优化时就会在靠近输入层容易出现梯度消散的情况,导致模型难以优化。

    目前大多数的主流网络都是基于这两种类似的单元通过 repeat 方式叠加来构造的。由此可见,SE 模块可以嵌入到现在几乎所有的网络结构中。通过在原始网络结构的 building block 单元中嵌入 SE 模块,我们可以获得不同种类的 SENet。如 SE-BN-Inception、SE-ResNet、SE-ReNeXt、SE-Inception-ResNet-v2 等等。

本例通过实现SE-ResNet,来显示如何将SE模块嵌入到ResNet网络中。SE-ResNet模型如下图:

第一个残差模块

第一个残差模块用于实现ResNet18、ResNet34模型,SENet嵌入到第二个卷积的后面。

  1. # 第一个残差模块
  2. class BasicBlock(layers.Layer):
  3. def __init__(self, filter_num, stride=1):
  4. super(BasicBlock, self).__init__()
  5. self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
  6. self.bn1 = layers.BatchNormalization()
  7. self.relu = layers.Activation('relu')
  8. self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
  9. self.bn2 = layers.BatchNormalization()
  10. # se-block
  11. self.se_globalpool = keras.layers.GlobalAveragePooling2D()
  12. self.se_resize = keras.layers.Reshape((1, 1, filter_num))
  13. self.se_fc1 = keras.layers.Dense(units=filter_num // 16, activation='relu',
  14. use_bias=False)
  15. self.se_fc2 = keras.layers.Dense(units=filter_num, activation='sigmoid',
  16. use_bias=False)
  17. if stride != 1:
  18. self.downsample = Sequential()
  19. self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
  20. else:
  21. self.downsample = lambda x: x
  22. def call(self, input, training=None):
  23. out = self.conv1(input)
  24. out = self.bn1(out)
  25. out = self.relu(out)
  26. out = self.conv2(out)
  27. out = self.bn2(out)
  28. # se_block
  29. b = out
  30. out = self.se_globalpool(out)
  31. out = self.se_resize(out)
  32. out = self.se_fc1(out)
  33. out = self.se_fc2(out)
  34. out = keras.layers.Multiply()([b, out])
  35. identity = self.downsample(input)
  36. output = layers.add([out, identity])
  37. output = tf.nn.relu(output)
  38. return output

第二个残差模块

第二个残差模块用于实现ResNet50、ResNet101、ResNet152模型,SENet模块嵌入到第三个卷积后面。

  1. # 第二个残差模块
  2. class Block(layers.Layer):
  3. def __init__(self, filters, downsample=False, stride=1):
  4. super(Block, self).__init__()
  5. self.downsample = downsample
  6. self.conv1 = layers.Conv2D(filters, (1, 1), strides=stride, padding='same')
  7. self.bn1 = layers.BatchNormalization()
  8. self.relu = layers.Activation('relu')
  9. self.conv2 = layers.Conv2D(filters, (3, 3), strides=1, padding='same')
  10. self.bn2 = layers.BatchNormalization()
  11. self.conv3 = layers.Conv2D(4 * filters, (1, 1), strides=1, padding='same')
  12. self.bn3 = layers.BatchNormalization()
  13. # se-block
  14. self.se_globalpool = keras.layers.GlobalAveragePooling2D()
  15. self.se_resize = keras.layers.Reshape((1, 1, 4 * filters))
  16. self.se_fc1 = keras.layers.Dense(units=4 * filters // 16, activation='relu',
  17. use_bias=False)
  18. self.se_fc2 = keras.layers.Dense(units=4 * filters, activation='sigmoid',
  19. use_bias=False)
  20. if self.downsample:
  21. self.shortcut = Sequential()
  22. self.shortcut.add(layers.Conv2D(4 * filters, (1, 1), strides=stride))
  23. self.shortcut.add(layers.BatchNormalization(axis=3))
  24. def call(self, input, training=None):
  25. out = self.conv1(input)
  26. out = self.bn1(out)
  27. out = self.relu(out)
  28. out = self.conv2(out)
  29. out = self.bn2(out)
  30. out = self.relu(out)
  31. out = self.conv3(out)
  32. out = self.bn3(out)
  33. b = out
  34. out = self.se_globalpool(out)
  35. out = self.se_resize(out)
  36. out = self.se_fc1(out)
  37. out = self.se_fc2(out)
  38. out = keras.layers.Multiply()([b, out])
  39. if self.downsample:
  40. shortcut = self.shortcut(input)
  41. else:
  42. shortcut = input
  43. output = layers.add([out, shortcut])
  44. output = tf.nn.relu(output)
  45. return output

ResNet18、ResNet34模型的完整代码

  1. import tensorflow as tf
  2. from tensorflow import keras
  3. from tensorflow.keras import layers, Sequential
  4. # 第一个残差模块
  5. class BasicBlock(layers.Layer):
  6. def __init__(self, filter_num, stride=1):
  7. super(BasicBlock, self).__init__()
  8. self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
  9. self.bn1 = layers.BatchNormalization()
  10. self.relu = layers.Activation('relu')
  11. self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
  12. self.bn2 = layers.BatchNormalization()
  13. # se-block
  14. self.se_globalpool = keras.layers.GlobalAveragePooling2D()
  15. self.se_resize = keras.layers.Reshape((1, 1, filter_num))
  16. self.se_fc1 = keras.layers.Dense(units=filter_num // 16, activation='relu',
  17. use_bias=False)
  18. self.se_fc2 = keras.layers.Dense(units=filter_num, activation='sigmoid',
  19. use_bias=False)
  20. if stride != 1:
  21. self.downsample = Sequential()
  22. self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
  23. else:
  24. self.downsample = lambda x: x
  25. def call(self, input, training=None):
  26. out = self.conv1(input)
  27. out = self.bn1(out)
  28. out = self.relu(out)
  29. out = self.conv2(out)
  30. out = self.bn2(out)
  31. # se_block
  32. b = out
  33. out = self.se_globalpool(out)
  34. out = self.se_resize(out)
  35. out = self.se_fc1(out)
  36. out = self.se_fc2(out)
  37. out = keras.layers.Multiply()([b, out])
  38. identity = self.downsample(input)
  39. output = layers.add([out, identity])
  40. output = tf.nn.relu(output)
  41. return output
  42. class ResNet(keras.Model):
  43. def __init__(self, layer_dims, num_classes=10):
  44. super(ResNet, self).__init__()
  45. # 预处理层
  46. self.padding = keras.layers.ZeroPadding2D((3, 3))
  47. self.stem = Sequential([
  48. layers.Conv2D(64, (7, 7), strides=(2, 2)),
  49. layers.BatchNormalization(),
  50. layers.Activation('relu'),
  51. layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')
  52. ])
  53. # resblock
  54. self.layer1 = self.build_resblock(64, layer_dims[0])
  55. self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
  56. self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
  57. self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
  58. # 全局池化
  59. self.avgpool = layers.GlobalAveragePooling2D()
  60. # 全连接层
  61. self.fc = layers.Dense(num_classes, activation=tf.keras.activations.softmax)
  62. def call(self, input, training=None):
  63. x= self.padding(input)
  64. x = self.stem(x)
  65. x = self.layer1(x)
  66. x = self.layer2(x)
  67. x = self.layer3(x)
  68. x = self.layer4(x)
  69. # [b,c]
  70. x = self.avgpool(x)
  71. x = self.fc(x)
  72. return x
  73. def build_resblock(self, filter_num, blocks, stride=1):
  74. res_blocks = Sequential()
  75. res_blocks.add(BasicBlock(filter_num, stride))
  76. for pre in range(1, blocks):
  77. res_blocks.add(BasicBlock(filter_num, stride=1))
  78. return res_blocks
  79. def ResNet34(num_classes=10):
  80. return ResNet([2, 2, 2, 2], num_classes=num_classes)
  81. def ResNet34(num_classes=10):
  82. return ResNet([3, 4, 6, 3], num_classes=num_classes)
  83. model = ResNet34(num_classes=1000)
  84. model.build(input_shape=(1, 224, 224, 3))
  85. print(model.summary()) # 统计网络参数

ResNet50、ResNet101、ResNet152完整代码

  1. import tensorflow as tf
  2. from tensorflow import keras
  3. from tensorflow.keras import layers, Sequential
  4. # 第二个残差模块
  5. class Block(layers.Layer):
  6. def __init__(self, filters, downsample=False, stride=1):
  7. super(Block, self).__init__()
  8. self.downsample = downsample
  9. self.conv1 = layers.Conv2D(filters, (1, 1), strides=stride, padding='same')
  10. self.bn1 = layers.BatchNormalization()
  11. self.relu = layers.Activation('relu')
  12. self.conv2 = layers.Conv2D(filters, (3, 3), strides=1, padding='same')
  13. self.bn2 = layers.BatchNormalization()
  14. self.conv3 = layers.Conv2D(4 * filters, (1, 1), strides=1, padding='same')
  15. self.bn3 = layers.BatchNormalization()
  16. # se-block
  17. self.se_globalpool = keras.layers.GlobalAveragePooling2D()
  18. self.se_resize = keras.layers.Reshape((1, 1, 4 * filters))
  19. self.se_fc1 = keras.layers.Dense(units=4 * filters // 16, activation='relu',
  20. use_bias=False)
  21. self.se_fc2 = keras.layers.Dense(units=4 * filters, activation='sigmoid',
  22. use_bias=False)
  23. if self.downsample:
  24. self.shortcut = Sequential()
  25. self.shortcut.add(layers.Conv2D(4 * filters, (1, 1), strides=stride))
  26. self.shortcut.add(layers.BatchNormalization(axis=3))
  27. def call(self, input, training=None):
  28. out = self.conv1(input)
  29. out = self.bn1(out)
  30. out = self.relu(out)
  31. out = self.conv2(out)
  32. out = self.bn2(out)
  33. out = self.relu(out)
  34. out = self.conv3(out)
  35. out = self.bn3(out)
  36. b = out
  37. out = self.se_globalpool(out)
  38. out = self.se_resize(out)
  39. out = self.se_fc1(out)
  40. out = self.se_fc2(out)
  41. out = keras.layers.Multiply()([b, out])
  42. if self.downsample:
  43. shortcut = self.shortcut(input)
  44. else:
  45. shortcut = input
  46. output = layers.add([out, shortcut])
  47. output = tf.nn.relu(output)
  48. return output
  49. class ResNet(keras.Model):
  50. def __init__(self, layer_dims, num_classes=10):
  51. super(ResNet, self).__init__()
  52. # 预处理层
  53. self.padding = keras.layers.ZeroPadding2D((3, 3))
  54. self.stem = Sequential([
  55. layers.Conv2D(64, (7, 7), strides=(2, 2)),
  56. layers.BatchNormalization(),
  57. layers.Activation('relu'),
  58. layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')
  59. ])
  60. # resblock
  61. self.layer1 = self.build_resblock(64, layer_dims[0],stride=1)
  62. self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
  63. self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
  64. self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
  65. # 全局池化
  66. self.avgpool = layers.GlobalAveragePooling2D()
  67. # 全连接层
  68. self.fc = layers.Dense(num_classes, activation=tf.keras.activations.softmax)
  69. def call(self, input, training=None):
  70. x = self.padding(input)
  71. x = self.stem(x)
  72. x = self.layer1(x)
  73. x = self.layer2(x)
  74. x = self.layer3(x)
  75. x = self.layer4(x)
  76. # [b,c]
  77. x = self.avgpool(x)
  78. x = self.fc(x)
  79. return x
  80. def build_resblock(self, filter_num, blocks, stride=1):
  81. res_blocks = Sequential()
  82. if stride != 1 or filter_num * 4 != 64:
  83. res_blocks.add(Block(filter_num, downsample=True,stride=stride))
  84. for pre in range(1, blocks):
  85. res_blocks.add(Block(filter_num, stride=1))
  86. return res_blocks
  87. def ResNet50(num_classes=10):
  88. return ResNet([3, 4, 6, 3], num_classes=num_classes)
  89. def ResNet101(num_classes=10):
  90. return ResNet([3, 4, 23, 3], num_classes=num_classes)
  91. def ResNet152(num_classes=10):
  92. return ResNet([3, 8, 36, 3], num_classes=num_classes)
  93. model = ResNet50(num_classes=1000)
  94. model.build(input_shape=(1, 224, 224, 3))
  95. print(model.summary()) # 统计网络参数


 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/290974
推荐阅读
相关标签
  

闽ICP备14008679号