当前位置:   article > 正文




  • 知道GoogLeNet网络结构的特点
  • 能够利用GoogLeNet完成图像分类

GoogLeNet在2014年由Google团队提出, 斩获当年ImageNet(ILSVRC14)竞赛中Classification Task (分类任务) 第一名,VGG获得了第二名,为了向“LeNet”致敬,因此取名为“GoogLeNet”。



1.Inception 块






  • 实现跨通道的交互和信息整合

  • 卷积核通道数的降维和升维,减少网络参数


  # Inception block for the Keras version of GoogLeNet
import tensorflow as tf


class Inception(tf.keras.layers.Layer):
    """GoogLeNet Inception block with four parallel branches.

    Args:
        c1: filters of the 1x1 convolution in branch 1.
        c2: ``(reduce, out)`` for branch 2: a 1x1 convolution with ``reduce``
            filters followed by a 3x3 convolution with ``out`` filters.
        c3: ``(reduce, out)`` for branch 3: a 1x1 convolution with ``reduce``
            filters followed by a 5x5 convolution with ``out`` filters.
        c4: filters of the 1x1 convolution that follows the 3x3 max pooling
            in branch 4.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Every convolution uses ReLU and ``padding='same'`` and the pooling uses
    stride 1, so all branches keep the input's height and width; the output
    has ``c1 + c2[1] + c3[1] + c4`` channels.
    """

    def __init__(self, c1, c2, c3, c4, **kwargs):
        super().__init__(**kwargs)
        # Branch 1: 1x1 convolution
        self.p1_1 = tf.keras.layers.Conv2D(
            c1, kernel_size=1, activation='relu', padding='same')
        # Branch 2: 1x1 convolution (channel reduction) -> 3x3 convolution
        self.p2_1 = tf.keras.layers.Conv2D(
            c2[0], kernel_size=1, padding='same', activation='relu')
        self.p2_2 = tf.keras.layers.Conv2D(
            c2[1], kernel_size=3, padding='same', activation='relu')
        # Branch 3: 1x1 convolution (channel reduction) -> 5x5 convolution
        self.p3_1 = tf.keras.layers.Conv2D(
            c3[0], kernel_size=1, padding='same', activation='relu')
        self.p3_2 = tf.keras.layers.Conv2D(
            c3[1], kernel_size=5, padding='same', activation='relu')
        # Branch 4: 3x3 max pooling (stride 1, size preserved) -> 1x1 convolution
        self.p4_1 = tf.keras.layers.MaxPool2D(
            pool_size=3, padding='same', strides=1)
        self.p4_2 = tf.keras.layers.Conv2D(
            c4, kernel_size=1, padding='same', activation='relu')

    def call(self, x):
        """Run the four branches on ``x`` and concatenate them on the channel axis."""
        p1 = self.p1_1(x)
        p2 = self.p2_2(self.p2_1(x))
        p3 = self.p3_2(self.p3_1(x))
        p4 = self.p4_2(self.p4_1(x))
        # Channels-last layout: concatenate along the last axis
        return tf.concat([p1, p2, p3, p4], axis=-1)


# Example: the first Inception block of stage b3 -- 1x1: 64, 1x1->3x3: 96->128,
# 1x1->5x5: 16->32, pool->1x1: 32, i.e. 64 + 128 + 32 + 32 = 256 output channels.
Inception(64, (96, 128), (16, 32), 32)




2.1 B1模块


  # Model input and stage b1 (the stem): 7x7/2 convolution + 3x3/2 max pooling
inputs = tf.keras.Input(shape=(224, 224, 3), name="input")

# 7x7 convolution, 64 filters, stride 2, 'same' padding, ReLU: 224x224 -> 112x112
x = tf.keras.layers.Conv2D(
    64,
    kernel_size=7,
    strides=2,
    padding='same',
    activation='relu',
)(inputs)
# 3x3 max pooling, stride 2, 'same' padding: 112x112 -> 56x56
x = tf.keras.layers.MaxPool2D(
    pool_size=3,
    strides=2,
    padding='same',
)(x)

2.2 B2模块


  # Stage b2: 1x1 convolution -> 3x3 convolution -> 3x3/2 max pooling
# 1x1 convolution, 64 filters, default stride 1, 'same' padding, ReLU (stays 56x56)
x = tf.keras.layers.Conv2D(64, kernel_size=1, padding='same', activation='relu')(x)
# 3x3 convolution, 192 filters, default stride 1, 'same' padding, ReLU (stays 56x56)
x = tf.keras.layers.Conv2D(192, kernel_size=3, padding='same', activation='relu')(x)
# 3x3 max pooling, stride 2, 'same' padding: 56x56 -> 28x28
x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

2.3 B3模块


  # Stage b3: Inception 3a and 3b, then 3x3/2 max pooling
# Inception 3a: 64 + 128 + 32 + 32 = 256 output channels
x = Inception(64, (96, 128), (16, 32), 32)(x)
# Inception 3b: 128 + 192 + 96 + 64 = 480 output channels
x = Inception(128, (128, 192), (32, 96), 64)(x)
# 3x3 max pooling, stride 2, 'same' padding: 28x28 -> 14x14
x = tf.keras.layers.MaxPool2D(
    pool_size=3,
    strides=2,
    padding='same',
)(x)

2.4 B4模块



  # Auxiliary classifier head attached to intermediate Inception outputs
def aux_classifier(x, filter_size, num_classes=10):
    """Build an auxiliary softmax classifier on top of the feature tensor ``x``.

    Args:
        x: output tensor of an Inception block.
        filter_size: pair ``(conv_filters, dense_units)``: the number of 1x1
            convolution filters and the width of the hidden dense layer.
        num_classes: number of classes of the softmax output layer
            (defaults to 10, the value that used to be hard-coded).

    Returns:
        The softmax output tensor of shape ``(batch, num_classes)``.
    """
    # 5x5 average pooling with stride 3.
    # NOTE(review): with 'same' padding a 14x14 input gives 5x5, whereas an
    # unpadded 5x5/3 pool (as in the PyTorch InceptionAux) gives 4x4. Kept
    # as-is so existing model shapes do not change.
    x = tf.keras.layers.AveragePooling2D(
        pool_size=5, strides=3, padding='same')(x)
    # 1x1 convolution to reduce the channel count
    x = tf.keras.layers.Conv2D(filters=filter_size[0], kernel_size=1, strides=1,
                               padding='valid', activation='relu')(x)
    # Flatten the feature map for the dense layers
    x = tf.keras.layers.Flatten()(x)
    # Hidden fully connected layer
    x = tf.keras.layers.Dense(units=filter_size[1], activation='relu')(x)
    # Softmax output layer
    return tf.keras.layers.Dense(units=num_classes, activation='softmax')(x)


  # Stage b4: Inception 4a-4e with two auxiliary heads, then 3x3/2 max pooling
# Inception 4a: 192 + 208 + 48 + 64 = 512 channels
x = Inception(192, (96, 208), (16, 48), 64)(x)
# First auxiliary head, attached after 4a
aux_output_1 = aux_classifier(x, [128, 1024])
# Inception 4b: 160 + 224 + 64 + 64 = 512 channels
x = Inception(160, (112, 224), (24, 64), 64)(x)
# Inception 4c: 128 + 256 + 64 + 64 = 512 channels
x = Inception(128, (128, 256), (24, 64), 64)(x)
# Inception 4d: 112 + 288 + 64 + 64 = 528 channels
x = Inception(112, (144, 288), (32, 64), 64)(x)
# Second auxiliary head, attached after 4d
aux_output_2 = aux_classifier(x, [128, 1024])
# Inception 4e: 256 + 320 + 128 + 128 = 832 channels
x = Inception(256, (160, 320), (32, 128), 128)(x)
# 3x3 max pooling, stride 2, 'same' padding: 14x14 -> 7x7
x = tf.keras.layers.MaxPool2D(
    pool_size=3,
    strides=2,
    padding='same',
)(x)

2.5 B5模块





  # Stage b5: Inception 5a and 5b, global average pooling, softmax classifier
# Inception 5a: 256 + 320 + 128 + 128 = 832 channels
x = Inception(256, (160, 320), (32, 128), 128)(x)
# Inception 5b: 384 + 384 + 128 + 128 = 1024 channels
x = Inception(384, (192, 384), (48, 128), 128)(x)
# Global average pooling: (7, 7, 1024) -> (1024,)
x = tf.keras.layers.GlobalAvgPool2D()(x)
# Main 10-way softmax output
main_outputs = tf.keras.layers.Dense(
    10,
    activation='softmax',
)(x)


  # Build the functional model: main output plus the two auxiliary outputs
model = tf.keras.Model(inputs=inputs, outputs=[main_outputs, aux_output_1, aux_output_2])
# The compile / fit / evaluate snippets refer to this model as `net`;
# bind that name as well so they do not fail with NameError.
net = model
model.summary()

Model: "functional_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) [(None, 224, 224, 3)] 0 _________________________________________________________________ conv2d_122 (Conv2D) (None, 112, 112, 64) 9472 _________________________________________________________________ max_pooling2d_27 (MaxPooling (None, 56, 56, 64) 0 _________________________________________________________________ conv2d_123 (Conv2D) (None, 56, 56, 64) 4160 _________________________________________________________________ conv2d_124 (Conv2D) (None, 56, 56, 192) 110784 _________________________________________________________________ max_pooling2d_28 (MaxPooling (None, 28, 28, 192) 0 _________________________________________________________________ inception_19 (Inception) (None, 28, 28, 256) 163696 _________________________________________________________________ inception_20 (Inception) (None, 28, 28, 480) 388736 _________________________________________________________________ max_pooling2d_31 (MaxPooling (None, 14, 14, 480) 0 _________________________________________________________________ inception_21 (Inception) (None, 14, 14, 512) 376176 _________________________________________________________________ inception_22 (Inception) (None, 14, 14, 512) 449160 _________________________________________________________________ inception_23 (Inception) (None, 14, 14, 512) 510104 _________________________________________________________________ inception_24 (Inception) (None, 14, 14, 528) 605376 _________________________________________________________________ inception_25 (Inception) (None, 14, 14, 832) 868352 _________________________________________________________________ max_pooling2d_37 (MaxPooling (None, 7, 7, 832) 0 _________________________________________________________________ inception_26 (Inception) (None, 7, 7, 832) 1043456 
_________________________________________________________________ inception_27 (Inception) (None, 7, 7, 1024) 1444080 _________________________________________________________________ global_average_pooling2d_2 ( (None, 1024) 0 _________________________________________________________________ dense_10 (Dense) (None, 10) 10250 ================================================================= Total params: 5,983,802 Trainable params: 5,983,802 Non-trainable params: 0 ___________________________________________________________



3.1 数据读取


  # Load MNIST and add an explicit channel axis: (N, H, W) -> (N, H, W, 1)
import numpy as np
import tensorflow as tf

# Handwritten-digit dataset (`mnist` was previously used without being imported)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# Training images: N H W -> N H W C with a single channel
train_images = np.reshape(train_images, (*train_images.shape[:3], 1))
# Test images: N H W -> N H W C with a single channel
test_images = np.reshape(test_images, (*test_images.shape[:3], 1))


  # Helpers that randomly sample part of MNIST for a quick demo
def _sample_resized(images, labels, size):
    """Draw ``size`` random samples (with replacement) resized to 224x224x3.

    Args:
        images: array of shape (N, H, W, 1).
        labels: array of N labels aligned with ``images``.
        size: number of samples to draw.

    Returns:
        ``(images, labels)`` as NumPy arrays; images have shape (size, 224, 224, 3).
    """
    # Random indices of the samples to draw
    index = np.random.randint(0, np.shape(images)[0], size)
    # Resize to 224x224, padding if needed to keep the aspect ratio
    resized_images = tf.image.resize_with_pad(images[index], 224, 224)
    # MNIST is single-channel but the model input is (224, 224, 3):
    # replicate the gray channel so the shapes match.
    resized_images = tf.image.grayscale_to_rgb(resized_images)
    return resized_images.numpy(), labels[index]


def get_train(size):
    """Return ``size`` random training samples as ``(images, labels)``."""
    return _sample_resized(train_images, train_labels, size)


def get_test(size):
    """Return ``size`` random test samples as ``(images, labels)``."""
    return _sample_resized(test_images, test_labels, size)


  # Draw a small random subset so the demo trains quickly
# (this rebinds the full arrays to the sampled ones).
train_images, train_labels = get_train(256)  # 256 training samples
test_images, test_labels = get_test(128)  # 128 test samples

3.2 模型编译

  # Compile: plain SGD, sparse categorical cross-entropy, accuracy metric
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0)
# The model has three outputs (main, aux1, aux2); the auxiliary losses
# contribute with weight 0.3 each.
net.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    loss_weights=[1, 0.3, 0.3],
)

3.3 模型训练

  # Train for 3 epochs with batch size 128, holding out 10% for validation
net.fit(
    x=train_images,
    y=train_labels,
    batch_size=128,
    epochs=3,
    verbose=1,
    validation_split=0.1,
)


Epoch 1/3 2/2 [==============================] - 8s 4s/step - loss: 2.9527 - accuracy: 0.1174 - val_loss: 3.3254 - val_accuracy: 0.1154 Epoch 2/3 2/2 [==============================] - 7s 4s/step - loss: 2.8111 - accuracy: 0.0957 - val_loss: 2.2718 - val_accuracy: 0.2308 Epoch 3/3 2/2 [==============================] - 7s 4s/step - loss: 2.3055 - accuracy: 0.0957 - val_loss: 2.2669 - val_accuracy: 0.2308

3.4 模型评估

  # Evaluate on the sampled test set
net.evaluate(x=test_images, y=test_labels, verbose=1)


4/4 [==============================] - 1s 338ms/step - loss: 2.3110 - accuracy: 0.0781 [2.310971260070801, 0.078125]


GoogLeNet是以InceptionV1为基础进行构建的,所以GoogLeNet也叫做InceptionNet。在随后的几年里,研究人员对GoogLeNet进行了数次改进,由此又产生了InceptionV2、V3、V4等版本。

4.1 InceptionV2


4.2 InceptionV3



  • 知道GoogLeNet的网络架构:由基础模块Inception构成
  • 能够利用GoogLeNet完成图像分类





  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  # conv + ReLU building block
class BasicConv2d(nn.Module):
    """``nn.Conv2d`` followed by a ReLU activation.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))
  # Stem ("front") of the network: everything before the first Inception block
class Front(nn.Module):
    """GoogLeNet stem: conv7x7/2 -> maxpool3x3/2 -> conv1x1 -> conv3x3 -> maxpool3x3/2.

    Maps an input of shape (N, 3, 224, 224) to (N, 192, 28, 28).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before, which fixes
        # the state_dict key order.
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

    def forward(self, input):
        # (N,3,224,224) -> (N,64,112,112) -> (N,64,56,56) -> (N,64,56,56)
        # -> (N,192,56,56) -> (N,192,28,28)
        stages = (self.conv1, self.maxpool1, self.conv2, self.conv3, self.maxpool2)
        out = input
        for stage in stages:
            out = stage(out)
        return out
  # Inception block: four parallel branches concatenated along the channel axis
class Inception(nn.Module):
    """GoogLeNet Inception block.

    Args:
        in_channels: channels of the input feature map.
        ch1x1: output channels of branch 1 (1x1 convolution).
        ch3x3_1_1: 1x1 reduction channels of branch 2.
        ch3x3_1: output channels of branch 2's 3x3 convolution.
        ch3x3_2_1: 1x1 reduction channels of branch 3.
        ch3x3_2: output channels of branch 3's 5x5 convolution.
        pool_ch: output channels of the 1x1 projection after 3x3 max pooling.

    All branches preserve H and W, so the output is
    (N, ch1x1 + ch3x3_1 + ch3x3_2 + pool_ch, H, W).
    """

    def __init__(self, in_channels, ch1x1, ch3x3_1_1, ch3x3_1, ch3x3_2_1, ch3x3_2, pool_ch):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3_1_1, kernel_size=1),
            BasicConv2d(ch3x3_1_1, ch3x3_1, kernel_size=3, padding=1),
        )
        # Branch 3 is GoogLeNet's 5x5 branch (matching the Keras Inception and
        # the model.py version); padding=2 keeps H and W unchanged.
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3_2_1, kernel_size=1),
            BasicConv2d(ch3x3_2_1, ch3x3_2, kernel_size=5, padding=2),
        )
        # 3x3 max pooling with stride 1 / padding 1 preserves H and W
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_ch, kernel_size=1),
        )

    def forward(self, x):
        # (N, Cin, H, W) -> (N, C1 + C2 + C3 + C4, H, W)
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)
  # Auxiliary classifier
class InceptionAux(nn.Module):
    """Auxiliary classifier head, used on the outputs of inception4a and inception4d.

    avgpool 5x5/3 -> 1x1 conv (128) -> flatten -> dropout -> fc(1024) + ReLU
    -> dropout -> fc(num_classes). ``fc1``'s input size 2048 = 128 * 4 * 4
    assumes a 14x14 input feature map.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # (N, C, 14, 14) -> (N, C, 4, 4) -> (N, 128, 4, 4) -> (N, 2048)
        features = torch.flatten(self.conv(self.averagePool(x)), 1)
        # (N, 2048) -> (N, 1024)
        hidden = F.relu(self.fc1(F.dropout(features, 0.5, training=self.training)))
        # (N, 1024) -> (N, num_classes)
        return self.fc2(F.dropout(hidden, 0.5, training=self.training))
  # GoogLeNet network body
class GoogLeNet(nn.Module):
    """GoogLeNet built from ``Front``, nine ``Inception`` blocks and a linear classifier.

    Args:
        num_classes: number of output classes.
        aux_logits: build the two auxiliary classifiers; they are evaluated
            only in training mode.

    In training mode with ``aux_logits`` the forward pass returns
    ``(logits, aux2, aux1)``; otherwise it returns ``logits`` only.
    """

    def __init__(self, num_classes=1000, aux_logits=True):
        super().__init__()
        self.aux_logits = aux_logits
        # Submodules are registered in the original order (fixes state_dict key order).
        self.front = Front()
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        use_aux = self.training and self.aux_logits
        # (N,3,224,224) -> (N,192,28,28)
        x = self.front(x)
        # -> 3a (N,256,28,28) -> 3b (N,480,28,28) -> pool (N,480,14,14) -> 4a (N,512,14,14)
        for layer in (self.inception3a, self.inception3b, self.maxpool3, self.inception4a):
            x = layer(x)
        aux1 = self.aux1(x) if use_aux else None
        # -> 4b (N,512,14,14) -> 4c (N,512,14,14) -> 4d (N,528,14,14)
        for layer in (self.inception4b, self.inception4c, self.inception4d):
            x = layer(x)
        aux2 = self.aux2(x) if use_aux else None
        # -> 4e (N,832,14,14) -> pool (N,832,7,7) -> 5a (N,832,7,7)
        # -> 5b (N,1024,7,7) -> avgpool (N,1024,1,1)
        for layer in (self.inception4e, self.maxpool4, self.inception5a,
                      self.inception5b, self.avgpool):
            x = layer(x)
        # (N,1024,1,1) -> (N,1024) -> (N,num_classes)
        x = self.fc(self.dropout(torch.flatten(x, 1)))
        if use_aux:
            return x, aux2, aux1
        return x

使用 Pytorch 搭建 GoogleNet 网络

本代码使用的数据集来自 “花分类” 数据集,→ 传送门 ←(具体内容看 data_set文件夹下的 README.md)

  • model.py ( 搭建 GoogleNet 网络模型 )
  1. import torch.nn as nn
  2. import torch
  3. import torch.nn.functional as F
  # GoogLeNet network definition (model.py)
class GoogleNet(nn.Module):
    """GoogLeNet (Inception v1) classifier.

    Args:
        num_classes: number of output classes.
        aux_logits: whether to build the two auxiliary classifiers. They are
            evaluated only in training mode (use True for training and False
            for inference).
        init_weight: if true, re-initialise conv/linear layers with
            ``_initialize_weight``.
        init_weights: alias for ``init_weight`` (the spelling train.py uses).

    In training mode with ``aux_logits`` the forward pass returns
    ``(logits, aux1, aux2)``; otherwise it returns ``logits`` only.
    """

    def __init__(self, num_classes=1000, aux_logits=True, init_weight=False,
                 init_weights=False):
        super(GoogleNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        # ceil_mode=True rounds fractional output sizes up instead of down
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        # nn.LocalResponseNorm (omitted here)
        self.conv2 = nn.Sequential(
            BasicConv2d(64, 64, kernel_size=1),
            BasicConv2d(64, 192, kernel_size=3, padding=1)
        )
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        # 2x2 window, stride 2: 14x14 -> 7x7
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:  # auxiliary classifiers
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # Global average pooling over H and W. This must be the 2-D variant:
        # the input is a 4-D (N, C, H, W) tensor, which AdaptiveAvgPool1d rejects.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weight or init_weights:
            self._initialize_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        if self.training and self.aux_logits:
            return x, aux1, aux2
        return x

    def _initialize_weight(self):
        """Kaiming-normal init for convolutions, N(0, 0.01) for linear layers, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Every convolution here is followed by a ReLU (BasicConv2d);
                # an empty nonlinearity string is rejected by kaiming_normal_.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
  # Inception block template used by GoogleNet
class Inception(nn.Module):
    """Four parallel branches whose outputs are concatenated on the channel axis.

    Args:
        in_channels: channels of the input feature map.
        ch1x1: output channels of the 1x1 branch.
        ch3x3red, ch3x3: 1x1 reduction / 3x3 output channels of branch 2.
        ch5x5red, ch5x5: 1x1 reduction / 5x5 output channels of branch 3.
        pool_proj: output channels of the 1x1 projection after max pooling.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()
        # Branch 1: 1x1 convolution
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        # Branch 2: 1x1 reduction -> 3x3 convolution (padding 1 keeps H, W)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )
        # Branch 3: 1x1 reduction -> 5x5 convolution (padding 2 keeps H, W)
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
        )
        # Branch 4: 3x3/1 max pooling -> 1x1 projection
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        return torch.cat(
            [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)],
            1,
        )
  # Auxiliary classifier template
class InceptionAux(nn.Module):
    """Auxiliary head: avgpool 5x5/3 -> 1x1 conv (128) -> fc(1024) -> fc(num_classes).

    ``fc1`` expects 2048 = 128 * 4 * 4 inputs, i.e. a 14x14 feature map
    (4 = (14 - 5) / 3 + 1 after pooling).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.avgPool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        pooled = self.avgPool(x)
        # -> N x 128 x 4 x 4 -> N x 2048
        flat = torch.flatten(self.conv(pooled), 1)
        hidden = F.relu(
            self.fc1(F.dropout(flat, 0.5, training=self.training)),
            inplace=True,
        )
        return self.fc2(F.dropout(hidden, 0.5, training=self.training))
  # Convolution + in-place ReLU template
class BasicConv2d(nn.Module):
    """``nn.Conv2d`` followed by an in-place ReLU.

    Keyword arguments such as ``kernel_size``, ``stride`` and ``padding`` are
    passed straight through to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))
  • train.py ( 训练网络 )
  1. import os
  2. import json
  3. import torch
  4. import torch.nn as nn
  5. from torchvision import transforms, datasets
  6. import torch.optim as optim
  7. from tqdm import tqdm
  8. from model import GoogleNet
  # Training entry point (train.py)
def main():
    """Train ``GoogleNet`` on the flower dataset and keep the best checkpoint.

    Expects ``<cwd>/../../data_set/flower_data/{train,val}`` in ImageFolder
    layout. Writes ``class_indices.json`` (index -> class name) and saves the
    state dict with the best validation accuracy to ``./googleNet.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # e.g. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            labels = labels.to(device)
            optimizer.zero_grad()
            # In training mode with aux_logits=True, GoogleNet.forward returns
            # (main, aux1, aux2) in that order -- unpack to matching names.
            logits, aux_logits1, aux_logits2 = net(images.to(device))
            loss0 = loss_function(logits, labels)
            loss1 = loss_function(aux_logits1, labels)
            loss2 = loss_function(aux_logits2, labels)
            # Auxiliary losses are down-weighted by 0.3
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval mode: only the main output
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
  • predict.py ( 使用训练好的模型网络对图像分类 )
  1. import os
  2. import json
  3. import torch
  4. from PIL import Image
  5. from torchvision import transforms
  6. import matplotlib.pyplot as plt
  7. from model import GoogleNet
  # Inference entry point (predict.py)
def main():
    """Classify ``../tulip.jpg`` with the weights in ``./googleNet.pth`` and plot the result."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # The 3-channel Normalize requires RGB input; convert so grayscale, RGBA
    # or palette images do not break the transform. Result: [C, H, W]
    img = data_transform(img.convert("RGB"))
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict (the file is closed once loaded)
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model without the auxiliary classifiers (inference only)
    model = GoogleNet(num_classes=5, aux_logits=False).to(device)

    # load model weights; strict=False because the checkpoint written by
    # train.py (built with aux_logits=True) also holds aux1/aux2 weights
    # that this model does not have.
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()

参考文章:【学习笔记】GoogleNet 网络结构_googlenet特点-CSDN博客



