
Classic CNNs in PyTorch (7): GoogLeNet (Inception V1) (Bottleneck, Global Average Pooling / GAP, 1×1 Convolution, Multi-Scale, Flower Dataset)


Proposed by Google in 2014.

GoogLeNet appeared in the same year as VGG and won ILSVRC (ImageNet) 2014, with VGG taking second place.

GoogLeNet is also called Inception V1. It is spelled GoogLeNet rather than GoogleNet as a tribute to LeCun's LeNet.

GoogLeNet: a network with parallel concatenation

It was enormously influential at the time because its structure was unlike anything before it: it broke with the fixed habit of simply stacking convolutional layers in series and introduced the highly effective Inception module. The result is a network deeper than VGG yet with far fewer parameters, mainly because the expensive fully connected layers at the end were removed; this also makes it computationally very efficient. GoogLeNet is 22 layers deep, with about 12× fewer parameters than AlexNet.

Although the name pays tribute to LeNet, the architecture bears little resemblance to it. GoogLeNet absorbed NiN's idea of chaining networks within networks and improved on it substantially.


Auxiliary classifiers (Auxiliary Classifier)

GoogLeNet uses auxiliary classifiers. The Inception network has 22 layers, and besides the final output, the features at intermediate layers can already be quite discriminative, so GoogLeNet attaches classifiers to intermediate layers and adds their losses to the final classification loss with a small weight (0.3). There are two such auxiliary classification nodes.

AlexNet and VGG each have a single output layer; GoogLeNet has three output layers, two of which are auxiliary.

As shown in the figure below, the two branches to the right of the network trunk are the auxiliary classifiers, and their structures are identical. During training, the losses of the two auxiliary classifiers are multiplied by a weight (0.3 in the paper) and added to the network's overall loss before backpropagation.

The inputs of the two auxiliary classifiers come from Inception(4a) and Inception(4d), respectively.
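A minimal sketch of the combined loss, with hypothetical stand-in tensors (in the real loop the three outputs come from the network in train mode, as in the flower-dataset code later in this article):

```python
import torch
from torch import nn

# Hypothetical stand-in tensors; in the real loop these are the three outputs
# of net(images) with aux_logits enabled and the network in train mode.
num_classes = 5
labels = torch.randint(0, num_classes, (8,))
main_out = torch.randn(8, num_classes, requires_grad=True)
aux1_out = torch.randn(8, num_classes, requires_grad=True)
aux2_out = torch.randn(8, num_classes, requires_grad=True)

criterion = nn.CrossEntropyLoss()
# Each auxiliary loss is weighted by 0.3 (the value from the paper) and added
# to the main loss; only this combined loss is backpropagated.
loss = (criterion(main_out, labels)
        + 0.3 * criterion(aux1_out, labels)
        + 0.3 * criterion(aux2_out, labels))
loss.backward()  # gradients flow through all three heads during training
```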

Is there a fully connected layer?

Before the classifier, GoogLeNet adopts the idea from Network in Network of replacing fully connected layers with average pooling. After the average pooling, however, it still keeps one fully connected layer, to make fine-tuning convenient.

VGG, LeNet and AlexNet, by contrast, all use three consecutive fully connected layers at the output, whose input is the reshaped output of the preceding convolutional layers. When people say GoogLeNet removed the expensive FC layers, they mean the first two computation-heavy FC layers.

It was found that replacing the fully connected layers with average pooling improved top-1 accuracy by about 0.6%; even with the fully connected layers removed, however, dropout remained essential.
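A minimal sketch of global average pooling (GAP) plus a single FC classifier, using nn.AdaptiveAvgPool2d (one of several ways to express GAP in PyTorch; the class count here is a hypothetical example):

```python
import torch
from torch import nn

# Global average pooling: each of the 1024 feature maps is averaged down to a
# single value, so the classifier input no longer depends on the spatial size.
gap = nn.AdaptiveAvgPool2d((1, 1))
fc = nn.Linear(1024, 5)              # 5 classes is a hypothetical example

feat = torch.randn(8, 1024, 7, 7)    # e.g. the output of Inception(5b) at 224x224 input
x = torch.flatten(gap(feat), 1)      # [8, 1024, 1, 1] -> [8, 1024]
logits = fc(x)                       # [8, 5]
```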

Parameter table (see the per-layer table in the GoogLeNet paper)

Later versions of GoogLeNet/Inception

v1: the original version
v2: adds batch normalization to speed up training
v3: adjusts the Inception module
v4: adds residual connections, following ResNet

The Inception module

GoogLeNet's basic convolutional block is called the Inception block, named after the movie Inception. Compared with the NiN block, it is structurally more complex.

In the usual side-by-side figure, the structure on the left is the original (naive) Inception; the one on the right, with the extra 1×1 convolutions, should properly be called Inception with dimensionality reduction.

An Inception block contains four parallel branches. The first three use convolutions with window sizes of 1×1, 3×3 and 5×5 to extract information at different spatial scales; the middle two of these first apply a 1×1 convolution to reduce the number of input channels and thus the model's complexity. The fourth branch uses a 3×3 max-pooling layer followed by a 1×1 convolution to change the channel count. All four branches use appropriate padding so that the output has the same height and width as the input. Finally, the outputs of the four branches are concatenated along the channel dimension and passed on to the next layer.

Note that the feature maps from the four branches are not summed element-wise; they are concatenated (torch.cat) along the channel dimension.

The core idea of the Inception module: use convolution kernels of different sizes to perceive the image at different scales, then fuse the results to obtain a better representation.

The main idea behind the Inception architecture is to approximate an optimal local sparse structure with readily available dense components.
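To make the channel bookkeeping concrete: the first Inception block used below, inception(192, 64, 96, 128, 16, 32, 32), takes 192 input channels and outputs 64 + 128 + 32 + 32 = 256 channels; the 96 and 16 are only the 1×1 reduction widths inside branches 2 and 3. A quick sketch of the concatenation:

```python
import torch

# Branch outputs of Inception(3a) on a 28x28 input (shapes from the network below):
f1 = torch.randn(1, 64, 28, 28)    # branch 1: 1x1 conv
f2 = torch.randn(1, 128, 28, 28)   # branch 2: 1x1 reduce (96) -> 3x3 conv
f3 = torch.randn(1, 32, 28, 28)    # branch 3: 1x1 reduce (16) -> 5x5 conv
f4 = torch.randn(1, 32, 28, 28)    # branch 4: 3x3 maxpool -> 1x1 conv
out = torch.cat((f1, f2, f3, f4), dim=1)
print(out.shape)                   # torch.Size([1, 256, 28, 28])
```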

Network implementation (simplified)

The Fashion-MNIST dataset

GoogLeNet can be viewed as many Inception modules connected in series.

GoogLeNet strings together multiple carefully designed Inception blocks and other layers. The channel allocation ratios inside the Inception blocks were arrived at through extensive experiments on ImageNet.

GoogLeNet and its successors were for a long time among the most efficient models on ImageNet: at comparable test accuracy, their computational complexity is often lower.

The code below implements the GoogLeNet network shown in the figure.

The original paper uses multiple outputs to combat vanishing gradients; here we define a simplified GoogLeNet with a single output.

The input size must be 96×96; the number of channels does not matter. Whether you use CIFAR-10 or Fashion-MNIST, resize the images to 96×96.

With the 1-channel Fashion-MNIST images at 96×96, a GPU with 4 GB of memory can handle a batch size of 32.

```python
import os
from datetime import datetime

import torch
from torch import nn, optim
import torchvision


class inception(nn.Module):
    def __init__(self, in_channel, out1_1, out2_1, out2_3, out3_1, out3_5, out4_1):
        super(inception, self).__init__()
        # nn.Conv2d defaults: stride=1, padding=0
        # branch 1: 1x1 conv
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channel, out1_1, kernel_size=1),
            nn.BatchNorm2d(out1_1, eps=1e-3),
            nn.ReLU(True)
        )
        # branch 2: 1x1 conv (channel reduction) -> 3x3 conv
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channel, out2_1, kernel_size=1),
            nn.BatchNorm2d(out2_1, eps=1e-3),
            nn.ReLU(True),
            nn.Conv2d(out2_1, out2_3, kernel_size=3, padding=1),
            nn.BatchNorm2d(out2_3, eps=1e-3),
            nn.ReLU(True),
        )
        # branch 3: 1x1 conv (channel reduction) -> 5x5 conv
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channel, out3_1, kernel_size=1),
            nn.BatchNorm2d(out3_1, eps=1e-3),
            nn.ReLU(True),
            nn.Conv2d(out3_1, out3_5, kernel_size=5, padding=2),
            nn.BatchNorm2d(out3_5, eps=1e-3),
            nn.ReLU(True),
        )
        # branch 4: 3x3 max pooling -> 1x1 conv
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channel, out4_1, kernel_size=1),
            nn.BatchNorm2d(out4_1, eps=1e-3),
            nn.ReLU(True),
        )

    def forward(self, x):
        f1 = self.branch1x1(x)
        f2 = self.branch3x3(x)
        f3 = self.branch5x5(x)
        f4 = self.branch_pool(x)
        # concatenate along the channel dimension (not element-wise addition)
        output = torch.cat((f1, f2, f3, f4), dim=1)
        return output


# test_net = inception(3, 64, 48, 64, 64, 96, 32)
# test_x = torch.zeros(1, 3, 96, 96)
# print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
# test_y = test_net(test_x)
# print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))


class googlenet(nn.Module):
    def __init__(self, in_channel, num_classes, verbose=False):
        super(googlenet, self).__init__()
        self.verbose = verbose
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64, eps=1e-3),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
            nn.BatchNorm2d(64, eps=1e-3),
            nn.ReLU(True),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192, eps=1e-3),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.block3 = nn.Sequential(
            inception(192, 64, 96, 128, 16, 32, 32),
            inception(256, 128, 128, 192, 32, 96, 64),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.block4 = nn.Sequential(
            inception(480, 192, 96, 208, 16, 48, 64),
            inception(512, 160, 112, 224, 24, 64, 64),
            inception(512, 128, 128, 256, 24, 64, 64),
            inception(512, 112, 144, 288, 32, 64, 64),
            inception(528, 256, 160, 320, 32, 128, 128),
            nn.MaxPool2d(3, 2)
        )
        self.block5 = nn.Sequential(
            inception(832, 256, 160, 320, 32, 128, 128),
            inception(832, 384, 192, 384, 48, 128, 128),  # 192 per the original paper (was a 182 typo)
            nn.AvgPool2d(2)
        )
        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.block1(x)
        if self.verbose:
            print('block 1 output: {}'.format(x.shape))
        x = self.block2(x)
        if self.verbose:
            print('block 2 output: {}'.format(x.shape))
        x = self.block3(x)
        if self.verbose:
            print('block 3 output: {}'.format(x.shape))
        x = self.block4(x)
        if self.verbose:
            print('block 4 output: {}'.format(x.shape))
        x = self.block5(x)
        if self.verbose:
            print('block 5 output: {}'.format(x.shape))
        # x is [b, 1024, 1, 1]
        x = x.view(x.shape[0], -1)
        # x is [b, 1024]
        x = self.classifier(x)
        return x


def get_acc(output, label):
    total = output.shape[0]
    # the arg-max of each row of class scores is the predicted class
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=96),
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.FashionMNIST(
    root='dataset/',
    train=True,
    download=True,
    transform=transform
)
# hold-out split: 50,000 for training, 10,000 for validation
train_set, val_set = torch.utils.data.random_split(train_set, [50000, 10000])
test_set = torchvision.datasets.FashionMNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False
)

net = googlenet(1, 10)
lr = 2e-3
optimizer = optim.Adam(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
net = net.to(device)

prev_time = datetime.now()
valid_data = val_loader
for epoch in range(3):
    train_loss = 0
    train_acc = 0
    net.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # forward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += get_acc(outputs, labels)  # averaged over batches when printed

    # epoch timing
    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = 'Time %02d:%02d:%02d(from %02d/%02d/%02d %02d:%02d:%02d to %02d/%02d/%02d %02d:%02d:%02d)' % (
        h, m, s, prev_time.year, prev_time.month, prev_time.day, prev_time.hour, prev_time.minute, prev_time.second,
        cur_time.year, cur_time.month, cur_time.day, cur_time.hour, cur_time.minute, cur_time.second)
    prev_time = cur_time

    # validation
    with torch.no_grad():
        net.eval()
        valid_loss = 0
        valid_acc = 0
        for inputs, labels in valid_data:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            valid_acc += get_acc(outputs, labels)
    print("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f,"
          % (epoch, train_loss / len(train_loader), train_acc / len(train_loader),
             valid_loss / len(valid_data), valid_acc / len(valid_data))
          + time_str)

os.makedirs('checkpoints', exist_ok=True)  # make sure the save directory exists
torch.save(net.state_dict(), 'checkpoints/params.pkl')

# test
with torch.no_grad():
    net.eval()
    correct = 0
    total = 0
    for (images, labels) in test_loader:
        images, labels = images.to(device), labels.to(device)
        output = net(images)
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print("The accuracy of total {} images: {}%".format(total, 100 * correct / total))
```

You can see that the spatial size keeps shrinking while the number of channels keeps growing.
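Continuing from the script above, a quick sanity check with verbose=True makes this concrete; the expected shapes below follow from the strides and pooling sizes:

```python
# Shape check for the simplified model defined above, with a 1x96x96 input.
net = googlenet(1, 10, verbose=True)
with torch.no_grad():
    net(torch.zeros(1, 1, 96, 96))
# block 1 output: torch.Size([1, 64, 23, 23])
# block 2 output: torch.Size([1, 192, 11, 11])
# block 3 output: torch.Size([1, 480, 5, 5])
# block 4 output: torch.Size([1, 832, 2, 2])
# block 5 output: torch.Size([1, 1024, 1, 1])
```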

The flower dataset

http://download.tensorflow.org/example_images/flower_photos.tgz

After downloading and extracting, you should have a flower_photos/ directory containing one subfolder per flower class.
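A minimal download-and-extract sketch, assuming the dataset should end up under dataset/flower/ as the scripts below expect:

```python
import os
import tarfile
import urllib.request

url = "http://download.tensorflow.org/example_images/flower_photos.tgz"
dest = "dataset/flower/"
os.makedirs(dest, exist_ok=True)

archive = os.path.join(dest, "flower_photos.tgz")
if not os.path.exists(archive):
    urllib.request.urlretrieve(url, archive)  # download the archive
with tarfile.open(archive) as tar:
    tar.extractall(dest)  # creates dataset/flower/flower_photos/<one folder per class>
```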

Splitting and preprocessing the dataset

```python
import os
import random
from shutil import copy


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


root_dir = 'dataset/flower/'
file_dir = root_dir + 'flower_photos/'
train_dir = root_dir + 'train/'
val_dir = root_dir + 'val/'

flower_class = [cla for cla in os.listdir(file_dir) if ".txt" not in cla]
mkfile(train_dir)
mkfile(val_dir)
for cla in flower_class:
    mkfile(train_dir + cla)
for cla in flower_class:
    mkfile(val_dir + cla)

split_rate = 0.1  # 10% of each class goes to the validation set
for cla in flower_class:
    cla_path = file_dir + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))
    for index, image in enumerate(images):
        image_path = cla_path + image
        if image in eval_index:
            copy(image_path, val_dir + cla)
        else:
            copy(image_path, train_dir + cla)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # progress bar
    print()
print("processing done!")
```

You also need a class_indices.json file.

This file is generated by the training script itself; it does not ship with the flower dataset.
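After one training run you can check it; the content is the inverse of ImageFolder's class_to_idx mapping (index to folder name), e.g.:

```python
import json

# class_indices.json maps "index" -> class name, e.g.
# {"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflowers", "4": "tulips"}
with open('dataset/flower/class_indices.json') as f:
    class_indict = json.load(f)
print(class_indict["0"])  # e.g. 'daisy'
```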

Network implementation

```python
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # padding keeps the spatial size unchanged
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)  # padding keeps the spatial size unchanged
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output [batch, 128, 4, 4]
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)
        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:  # skipped in eval mode
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:  # skipped in eval mode
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7
        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:  # in eval mode only the main output is returned
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}

root = 'dataset/flower/'  # flower data set path
train_dataset = datasets.ImageFolder(
    root=root + "train",
    transform=data_transform["train"]
)
train_num = len(train_dataset)

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write the index -> class-name dict into a json file
json_str = json.dumps(cla_dict, indent=4)
with open(root + 'class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size, shuffle=True,
    num_workers=0
)
validate_dataset = datasets.ImageFolder(
    root=root + "val",
    transform=data_transform["val"]
)
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(
    dataset=validate_dataset,
    batch_size=batch_size, shuffle=False,
    num_workers=0
)

net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=3e-4)

best_acc = 0.0
save_path = 'checkpoints/googleNet.pth'
os.makedirs('checkpoints', exist_ok=True)  # make sure the save directory exists
for epoch in range(6):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits, aux_logits2, aux_logits1 = net(images.to(device))
        loss0 = criterion(logits, labels.to(device))
        loss1 = criterion(aux_logits1, labels.to(device))
        loss2 = criterion(aux_logits2, labels.to(device))
        # auxiliary losses are weighted by 0.3, as in the paper
        loss = loss0 + loss1 * 0.3 + loss2 * 0.3
        loss.backward()
        optimizer.step()
        # statistics
        running_loss += loss.item()
        # progress bar
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # number of correct predictions accumulated over the epoch
    with torch.no_grad():
        for data_test in validate_loader:
            test_images, test_labels = data_test
            outputs = net(test_images.to(device))  # in eval mode only the main output is returned
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == test_labels.to(device)).sum().item()
        accurate_test = acc / val_num
        if accurate_test > best_acc:
            best_acc = accurate_test  # keep track of the best model (fix: was never updated)
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / (step + 1), accurate_test))
print('Finished Training')
```

Prediction

The prediction code is not part of the training script because at inference time the outputs of the auxiliary classifiers are no longer used, so the network structure changes.

So we wait until training has finished, then load the saved parameters into the aux-free model. The load still uses strict=False: the checkpoint was saved with the auxiliary classifiers, so it contains aux1/aux2 weights that the inference model does not define, and strict=False tells PyTorch to ignore those extra keys.

```python
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet  # the GoogLeNet class from the training script, saved as model.py

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
# img = Image.open("../tulip.jpg")
img = Image.open("dataset/flower/val/tulips/10791227_7168491604.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read the class index -> name mapping
try:
    with open('dataset/flower/class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create the model without auxiliary classifiers
model = GoogLeNet(num_classes=5, aux_logits=False)
# load the trained weights; strict=False skips the aux1/aux2 keys in the checkpoint
model_weight_path = "checkpoints/googleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)])
plt.show()
```
