Classic CNNs in PyTorch (7): GoogLeNet (Inception V1) (Bottleneck) (Global Average Pooling, GAP) (1×1 Convolution) (Multi-Scale) (Flower Dataset)

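GoogLeNet was proposed by Google in 2014.

It appeared in the same year as VGG and won the ILSVRC (ImageNet) 2014 classification challenge, with VGG coming second.

GoogLeNet is also known as Inception V1. It is spelled GoogLeNet rather than GoogleNet as a tribute to LeCun's LeNet.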


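It had a large impact at the time because its structure was unprecedented. It overturned the prevailing picture of a convolutional network as a plain serial stack of layers and introduced the highly effective Inception module. The result is a network that is deeper than VGG yet has far fewer parameters, largely because the big fully connected layers at the end are removed, which also makes it computationally efficient. GoogLeNet has 22 layers and about 12 times fewer parameters than AlexNet.

Although the name pays tribute to LeNet, the architecture bears almost no resemblance to it. GoogLeNet takes the idea of nesting small networks inside a network from NiN (Network in Network) and improves on it substantially.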

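Auxiliary Classifier

GoogLeNet uses auxiliary classifiers. The network has 22 layers, and the features at intermediate depths can already classify reasonably well, so GoogLeNet attaches a classifier to two intermediate layers. During training, their losses are added to the total loss with a small weight (0.3); at inference time the auxiliary classifiers are discarded.

AlexNet and VGG each have a single output layer. GoogLeNet has three, two of which are auxiliary classifiers.

In the architecture diagram, the two branches to the right of the main trunk are the auxiliary classifiers, and they share exactly the same structure. When training the model, the losses of the two auxiliary classifiers are multiplied by a weight (0.3 in the paper), added to the overall loss of the network, and then backpropagated.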
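The loss weighting described above can be sketched as follows. This is only a minimal illustration: random logits stand in for the three network outputs, and `combined_loss` and `aux_weight` are names chosen here for the example, not taken from the paper.

```python
import torch
import torch.nn.functional as F


def combined_loss(main_logits, aux1_logits, aux2_logits, labels, aux_weight=0.3):
    """Main cross-entropy loss plus the two down-weighted auxiliary losses."""
    main = F.cross_entropy(main_logits, labels)
    aux1 = F.cross_entropy(aux1_logits, labels)
    aux2 = F.cross_entropy(aux2_logits, labels)
    return main + aux_weight * (aux1 + aux2)


torch.manual_seed(0)
labels = torch.randint(0, 5, (8,))  # 8 samples, 5 classes (arbitrary sizes)
# Random tensors play the role of the main and auxiliary classifier outputs.
main_logits = torch.randn(8, 5, requires_grad=True)
aux1_logits = torch.randn(8, 5, requires_grad=True)
aux2_logits = torch.randn(8, 5, requires_grad=True)

loss = combined_loss(main_logits, aux1_logits, aux2_logits, labels)
loss.backward()  # gradients flow into all three heads in a single backward pass
print(loss.item())
```

Because the auxiliary terms only enter the loss, nothing about them is needed at inference time, which is why they can simply be dropped after training.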



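Before the classifier, GoogLeNet follows Network in Network and uses average pooling in place of fully connected layers. One fully connected layer is still kept after the average pooling, mainly to make fine-tuning on other label sets convenient.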
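A small sketch of such a head, under the assumption that the last Inception stage outputs a 1024-channel 7×7 feature map (the shape it has for 224×224 input) and that there are 1000 classes. The name `gap_head` is illustrative only.

```python
import torch
from torch import nn

# Assumed shape of the final feature map: [batch, 1024, 7, 7]
features = torch.randn(4, 1024, 7, 7)

gap_head = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling: one value per channel
    nn.Flatten(),                  # [N, 1024, 1, 1] -> [N, 1024]
    nn.Linear(1024, 1000),         # the single remaining fully connected layer
)
logits = gap_head(features)

gap_params = sum(p.numel() for p in gap_head.parameters())
# For comparison: a linear layer fed with the flattened 7x7 map instead (weights + biases)
flat_params = 1024 * 7 * 7 * 1000 + 1000

print(logits.shape)             # torch.Size([4, 1000])
print(gap_params, flat_params)  # 1025000 50177000
```

Pooling first makes the linear layer independent of the spatial size and shrinks it by a factor of 49 in this setting.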

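VGG, LeNet, and AlexNet all end with three consecutive fully connected layers, whose input is the reshaped output of the last convolutional stage. When people say that GoogLeNet removed the "expensive FC layers", they mean the first two of these, which carry most of the computation and parameters.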
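To get a feeling for how expensive those layers are, the parameter counts can be worked out by hand. The sketch below assumes the standard VGG-16 head (512×7×7 → 4096 → 4096 → 1000) and GoogLeNet's single 1024 → 1000 layer; `linear_params` is a helper defined here for the example.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


# Assumed VGG-16 classifier head: three fully connected layers
vgg16_fc = [(512 * 7 * 7, 4096), (4096, 4096), (4096, 1000)]
vgg16_head = sum(linear_params(i, o) for i, o in vgg16_fc)
vgg16_first_two = sum(linear_params(i, o) for i, o in vgg16_fc[:2])

# GoogLeNet: global average pooling followed by one 1024 -> 1000 layer
googlenet_head = linear_params(1024, 1000)

print(vgg16_head)       # 123642856
print(vgg16_first_two)  # 119545856
print(googlenet_head)   # 1025000
```

Under these assumptions the first two fully connected layers account for roughly 97% of the parameters in the VGG-16 head.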

According to the paper, replacing the fully connected layers with average pooling improved top-1 accuracy by about 0.6%. Even after the fully connected layers were removed, however, dropout remained essential.


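Later versions of GoogLeNet/Inception

v1: the original version
v2: adds batch normalization to speed up training
v3: reworks the Inception module (for example, factorizing large convolutions into smaller ones)
v4: a deeper, more uniform Inception design; the same paper also introduces the Inception-ResNet variants, which add ResNet-style residual connections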




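An Inception block has 4 parallel branches. The first 3 branches use convolutions with window sizes of 1×1, 3×3, and 5×5 to extract information at different spatial scales. The two middle branches first apply a 1×1 convolution to the input to reduce the number of input channels and so lower the model complexity. The fourth branch uses a 3×3 max pooling layer followed by a 1×1 convolution that changes the number of channels. All 4 branches use suitable padding so that the output height and width match the input. Finally, the outputs of the branches are concatenated along the channel dimension and fed to the next layer.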
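The saving from the 1×1 reduction can be checked with simple arithmetic. The sketch below uses the 5×5 branch of the first Inception block in the code further down (192 input channels, reduced to 16, then expanded to 32) and counts weights only, ignoring biases; `conv_weights` is a helper made up for this example.

```python
def conv_weights(in_ch, out_ch, k):
    """Number of weights in a k x k convolution (biases ignored)."""
    return in_ch * out_ch * k * k


in_ch, reduce_ch, out_ch = 192, 16, 32

# 5x5 convolution applied directly to all 192 input channels
direct = conv_weights(in_ch, out_ch, 5)

# 1x1 bottleneck down to 16 channels, then the 5x5 convolution
bottleneck = conv_weights(in_ch, reduce_ch, 1) + conv_weights(reduce_ch, out_ch, 5)

print(direct)      # 153600
print(bottleneck)  # 15872
```

For this branch the bottleneck needs roughly a tenth of the weights, and since each weight is applied at every spatial position, the multiply-add count shrinks by the same factor.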


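The core idea of the Inception module: use convolution kernels of different sizes to perceive the input at different scales, then fuse the results to obtain a better representation of the image.

The main question behind the Inception structure is how an optimal local sparse structure can be approximated with readily available dense components.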
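Fusing by channel concatenation only works if every branch keeps the same height and width. A quick check with the usual output-size formula, assuming stride 1 and the padding values used in the code further down (the 28×28 input size and the channel numbers of the first Inception block are taken as an example):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28
# branch name -> (kernel size, padding), all with stride 1
branches = {
    "1x1 conv": (1, 0),
    "3x3 conv": (3, 1),
    "5x5 conv": (5, 2),
    "3x3 max pool": (3, 1),
}
out_sizes = {name: conv_out(size, k, padding=p) for name, (k, p) in branches.items()}
print(out_sizes)  # every branch stays at 28

# Output channels of the four branches of the first Inception block
branch_channels = [64, 128, 32, 32]
total_channels = sum(branch_channels)
print(total_channels)  # 256
```

With padding of (kernel − 1) / 2 each branch preserves the spatial size, so the concatenated output simply has the sum of the branch channels.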



GoogLeNet can be seen as many Inception modules connected in series.
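Chaining the modules means the input channels of each block must equal the concatenated output channels of the previous one. The sketch below checks this for the nine Inception configurations used in the code further down; the tuple layout and the helper `out_channels` are conventions chosen for this example.

```python
# (in, 1x1, 3x3 reduce, 3x3, 5x5 reduce, 5x5, pool proj)
configs = [
    (192, 64, 96, 128, 16, 32, 32),     # 3a
    (256, 128, 128, 192, 32, 96, 64),   # 3b
    (480, 192, 96, 208, 16, 48, 64),    # 4a
    (512, 160, 112, 224, 24, 64, 64),   # 4b
    (512, 128, 128, 256, 24, 64, 64),   # 4c
    (512, 112, 144, 288, 32, 64, 64),   # 4d
    (528, 256, 160, 320, 32, 128, 128), # 4e
    (832, 256, 160, 320, 32, 128, 128), # 5a
    (832, 384, 192, 384, 48, 128, 128), # 5b
]


def out_channels(cfg):
    """Channels after concatenating the four branches (reduce layers do not count)."""
    _, c1, _, c3, _, c5, pool = cfg
    return c1 + c3 + c5 + pool


# Max pooling between stages does not change the channel count,
# so each block's output must match the next block's input.
for prev, nxt in zip(configs, configs[1:]):
    assert out_channels(prev) == nxt[0]

final_channels = out_channels(configs[-1])
print(final_channels)  # 1024
```

The final value of 1024 is why the classifier in both implementations is a linear layer with 1024 input features.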




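The original paper uses multiple outputs to mitigate the vanishing gradient problem. Here we define only a simplified GoogLeNet with a single output. Note that this version also places BatchNorm after every convolution, which the original Inception V1 does not have.

The input should be resized to 96×96; the number of channels does not matter, since it is passed in as in_channel. Whether you use CIFAR10 or Fashion-MNIST, resize the images to 96×96. With this size the feature map entering the final AvgPool2d(2) is 2×2, so the classifier receives exactly 1024 features.

For the single-channel Fashion-MNIST dataset at 96×96, my GPU with 4 GB of memory can handle a batch_size of 32.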
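The 96×96 requirement can be traced through the downsampling layers of the simplified network with the standard output-size formula. The sketch below lists only the layers that change the spatial size (the Inception blocks and the padded 1×1/3×3 convolutions preserve it); `layer_out` and `trace` are names made up for this example.

```python
def layer_out(size, kernel, stride, padding=0):
    """Output size of a conv/pool layer along one dimension (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of every size-changing layer, in order
downsampling = [
    (7, 2, 3),  # block1: 7x7 conv, stride 2, padding 3
    (3, 2, 0),  # block1: max pool
    (3, 2, 0),  # block2: max pool
    (3, 2, 0),  # block3: max pool
    (3, 2, 0),  # block4: max pool
    (2, 2, 0),  # block5: AvgPool2d(2)
]


def trace(size):
    sizes = [size]
    for kernel, stride, padding in downsampling:
        size = layer_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(trace(96))   # [96, 48, 23, 11, 5, 2, 1]
print(trace(224))  # [224, 112, 55, 27, 13, 6, 3]
```

With 224×224 input the final map would be 3×3, so flattening gives 9 × 1024 features and the Linear(1024, num_classes) layer no longer fits. Using nn.AdaptiveAvgPool2d((1, 1)) instead of a fixed AvgPool2d(2), as the second implementation later in this article does, is one way to lift that restriction.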

import os
from datetime import datetime

import torch
import torchvision
from torch import nn, optim


class inception(nn.Module):
    def __init__(self, in_channel, out1_1, out2_1, out2_3, out3_1, out3_5, out4_1):
        super(inception, self).__init__()
        # nn.Conv2d defaults: stride=1, padding=0
        # Branch 1: 1x1 conv
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channel, out1_1, kernel_size=1),
            nn.BatchNorm2d(out1_1, eps=1e-3),
            nn.ReLU(True)
        )
        # Branch 2: 1x1 conv (channel reduction) -> 3x3 conv
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channel, out2_1, kernel_size=1),
            nn.BatchNorm2d(out2_1, eps=1e-3),
            nn.ReLU(True),
            nn.Conv2d(out2_1, out2_3, kernel_size=3, padding=1),
            nn.BatchNorm2d(out2_3, eps=1e-3),
            nn.ReLU(True),
        )
        # Branch 3: 1x1 conv (channel reduction) -> 5x5 conv
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channel, out3_1, kernel_size=1),
            nn.BatchNorm2d(out3_1, eps=1e-3),
            nn.ReLU(True),
            nn.Conv2d(out3_1, out3_5, kernel_size=5, padding=2),
            nn.BatchNorm2d(out3_5, eps=1e-3),
            nn.ReLU(True),
        )
        # Branch 4: 3x3 max pool -> 1x1 conv
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channel, out4_1, kernel_size=1),
            nn.BatchNorm2d(out4_1, eps=1e-3),
            nn.ReLU(True),
        )

    def forward(self, x):
        f1 = self.branch1x1(x)
        f2 = self.branch3x3(x)
        f3 = self.branch5x5(x)
        f4 = self.branch_pool(x)
        # Concatenate the four branches along the channel dimension
        output = torch.cat((f1, f2, f3, f4), dim=1)
        return output


# test_net = inception(3, 64, 48, 64, 64, 96, 32)
# test_x = torch.zeros(1, 3, 96, 96)
# print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
# test_y = test_net(test_x)
# print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))


class googlenet(nn.Module):
    def __init__(self, in_channel, num_classes, verbose=False):
        super(googlenet, self).__init__()
        self.verbose = verbose
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64, eps=1e-3),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
            nn.BatchNorm2d(64, eps=1e-3),
            nn.ReLU(True),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192, eps=1e-3),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.block3 = nn.Sequential(
            inception(192, 64, 96, 128, 16, 32, 32),
            inception(256, 128, 128, 192, 32, 96, 64),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.block4 = nn.Sequential(
            inception(480, 192, 96, 208, 16, 48, 64),
            inception(512, 160, 112, 224, 24, 64, 64),
            inception(512, 128, 128, 256, 24, 64, 64),
            inception(512, 112, 144, 288, 32, 64, 64),
            inception(528, 256, 160, 320, 32, 128, 128),
            nn.MaxPool2d(3, 2)
        )
        self.block5 = nn.Sequential(
            inception(832, 256, 160, 320, 32, 128, 128),
            # 3x3 reduce is 192 in the paper (the original listing had 182, a typo)
            inception(832, 384, 192, 384, 48, 128, 128),
            nn.AvgPool2d(2)
        )
        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.block1(x)
        if self.verbose:
            print('block 1 output: {}'.format(x.shape))
        x = self.block2(x)
        if self.verbose:
            print('block 2 output: {}'.format(x.shape))
        x = self.block3(x)
        if self.verbose:
            print('block 3 output: {}'.format(x.shape))
        x = self.block4(x)
        if self.verbose:
            print('block 4 output: {}'.format(x.shape))
        x = self.block5(x)
        if self.verbose:
            print('block 5 output: {}'.format(x.shape))
        # x is [b, 1024, 1, 1] for 96x96 input
        x = x.view(x.shape[0], -1)
        # x is [b, 1024]
        x = self.classifier(x)
        return x


def get_acc(output, label):
    total = output.shape[0]
    # output holds logits; the largest value in each row is the predicted class
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=96),
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.FashionMNIST(
    root='dataset/',
    train=True,
    download=True,
    transform=transform
)
# hold-out split: 50000 for training, 10000 for validation
train_set, val_set = torch.utils.data.random_split(train_set, [50000, 10000])
test_set = torchvision.datasets.FashionMNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False
)

net = googlenet(1, 10)
lr = 2e-3
optimizer = optim.Adam(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
net = net.to(device)

prev_time = datetime.now()
valid_data = val_loader
for epoch in range(3):
    train_loss = 0
    train_acc = 0
    net.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # forward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # accumulated per batch; averaged when printed below
        train_loss += loss.item()
        train_acc += get_acc(outputs, labels)

    # elapsed training time for this epoch
    cur_time = datetime.now()
    h, remainder = divmod(int((cur_time - prev_time).total_seconds()), 3600)
    m, s = divmod(remainder, 60)
    time_str = 'Time %02d:%02d:%02d(from %02d/%02d/%02d %02d:%02d:%02d to %02d/%02d/%02d %02d:%02d:%02d)' % (
        h, m, s, prev_time.year, prev_time.month, prev_time.day, prev_time.hour, prev_time.minute, prev_time.second,
        cur_time.year, cur_time.month, cur_time.day, cur_time.hour, cur_time.minute, cur_time.second)
    prev_time = cur_time

    # validation
    net.eval()
    with torch.no_grad():
        valid_loss = 0
        valid_acc = 0
        for inputs, labels in valid_data:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            valid_acc += get_acc(outputs, labels)
    print("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f,"
          % (epoch, train_loss / len(train_loader), train_acc / len(train_loader),
             valid_loss / len(valid_data), valid_acc / len(valid_data))
          + time_str)

# make sure the output directory exists before saving
os.makedirs('checkpoints', exist_ok=True)
torch.save(net.state_dict(), 'checkpoints/params.pkl')

# test
net.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        output = net(images)
        _, predicted = torch.max(output, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print("The accuracy of total {} images: {}%".format(total, 100 * correct / total))

With verbose=True you can see that the spatial size keeps shrinking while the number of channels keeps growing.





import os
import random
from shutil import copy


def mkfile(path):
    # Create the directory if it does not exist yet
    os.makedirs(path, exist_ok=True)


root_dir = 'dataset/flower/'
file_dir = root_dir + 'flower_photos/'
train_dir = root_dir + 'train/'
val_dir = root_dir + 'val/'

# Every sub-directory of flower_photos is one class; skip stray .txt files
flower_class = [cla for cla in os.listdir(file_dir) if ".txt" not in cla]

mkfile(train_dir)
mkfile(val_dir)
for cla in flower_class:
    mkfile(train_dir + cla)
    mkfile(val_dir + cla)

split_rate = 0.1  # fraction of each class held out for validation
for cla in flower_class:
    cla_path = file_dir + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    # randomly pick the validation images for this class
    eval_index = set(random.sample(images, k=int(num * split_rate)))
    for index, image in enumerate(images):
        image_path = cla_path + image
        if image in eval_index:
            new_path = val_dir + cla
        else:
            new_path = train_dir + cla
        copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # progress bar
    print()
print("processing done!")




import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets


class BasicConv2d(nn.Module):
    # Convolution followed by ReLU
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # keeps output size equal to input size
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)  # keeps output size equal to input size
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = [branch1, branch2, branch3, branch4]
        # concatenate along the channel dimension
        return torch.cat(outputs, 1)


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)
        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:  # auxiliary heads are skipped in eval mode
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:  # auxiliary heads are skipped in eval mode
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7
        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:  # training mode returns all three outputs
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}

root = 'dataset/flower/'  # flower data set path
train_dataset = datasets.ImageFolder(
    root=root + "train",
    transform=data_transform["train"]
)
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# invert the mapping: index -> class name
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open(root + 'class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size, shuffle=True,
    num_workers=0
)
validate_dataset = datasets.ImageFolder(
    root=root + "val",
    transform=data_transform["val"]
)
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(
    dataset=validate_dataset,
    batch_size=batch_size, shuffle=False,
    num_workers=0
)

net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=3e-4)

best_acc = 0.0
save_path = 'checkpoints/googleNet.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the directory exists

for epoch in range(6):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, aux_logits2, aux_logits1 = net(images)
        loss0 = criterion(logits, labels)
        loss1 = criterion(aux_logits1, labels)
        loss2 = criterion(aux_logits2, labels)
        # auxiliary losses are weighted by 0.3, as in the paper
        loss = loss0 + loss1 * 0.3 + loss2 * 0.3
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # number of correct predictions in this epoch
    with torch.no_grad():
        for data_test in validate_loader:
            test_images, test_labels = data_test
            outputs = net(test_images.to(device))  # eval mode returns only the main output
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == test_labels.to(device)).sum().item()
    accurate_test = acc / val_num
    if accurate_test > best_acc:
        # remember the best accuracy so that only improved models overwrite the checkpoint
        best_acc = accurate_test
        torch.save(net.state_dict(), save_path)
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
          (epoch + 1, running_loss / len(train_loader), accurate_test))

print('Finished Training')




import json

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from model import GoogLeNet  # assumes the GoogLeNet class above is saved as model.py

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
# img = Image.open("../tulip.jpg")
img = Image.open("dataset/flower/val/tulips/10791227_7168491604.jpg")
plt.imshow(img)
# [C, H, W]
img = data_transform(img)
# expand batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# read the index -> class name mapping
try:
    with open('dataset/flower/class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model (no auxiliary classifiers needed for inference)
model = GoogLeNet(num_classes=5, aux_logits=False)
# load model weights; strict=False because the checkpoint still contains
# the auxiliary classifier weights, which this model does not have
model_weight_path = "checkpoints/googleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(
    torch.load(model_weight_path, map_location='cpu'), strict=False)
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()
print(class_indict[str(predict_cla)])
plt.show()

