
[Image Recognition] A PyTorch Starter Demo: CIFAR-10 Classification and Visualization

Contents

Environment Setup

1. Dataset

2. Model Training

3. Training Results

4. The Role of Batch Size

5. References


 

        PyTorch uses dynamic (define-by-run) graph computation, which matches ordinary programming logic. It merged with Caffe2, is easy to pick up, flexible, and makes GPU acceleration and automatic differentiation convenient, so it is favored in academia. TensorFlow uses static graph computation: the graph must be defined ahead of time and then executed, which makes intermediate variables hard to inspect during a run, but the ecosystem is mature and deployment is easy, so it is more common in industry. PyTorch's natural-language-processing package is AllenNLP; its computer-vision package is Torchvision.
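To make the define-by-run point concrete, here is a minimal sketch of dynamic graphs and automatic differentiation (plain torch, independent of the demo below):

import torch

# The graph is built as the code runs, so intermediate values can be
# inspected like ordinary Python variables.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x      # the graph for y is created on this line
print(y.item())         # inspect the intermediate value directly: 10.0
y.backward()            # automatic differentiation
print(x.grad.item())    # dy/dx = 2x + 3 = 7.0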

Environment Setup

             Win10 + GTX 1660 Ti + Anaconda3 + Spyder + PyTorch 1.0

              Setting up PyTorch is simple and beginner-friendly: go to the official site, https://pytorch.org/, choose your configuration, and run the generated command.
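At the time of writing, the selector generated commands along these lines (the exact command depends on your OS and CUDA version, so treat this only as an example):

conda install pytorch torchvision cudatoolkit=10.0 -c pytorch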

     To set up OpenCV for Spyder, enter the following in the Anaconda Prompt:

conda install -c https://conda.binstar.org/menpo opencv
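To confirm that the binding is visible to the interpreter, a quick check in the Spyder console:

import cv2
print(cv2.__version__)   # prints the installed OpenCV version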

1. Dataset

     CIFAR-10 and CIFAR-100 are labeled subsets of the Tiny Images dataset (details: http://groups.csail.mit.edu/vision/TinyImages/).

     CIFAR-10 contains 60,000 color images of size 32×32×3, divided into 10 classes (see the figure below), with 6,000 images per class.

     Training set: 50,000 images, distributed as five training batches of 10,000 images each.

     Test set: 10,000 images in a single batch, consisting of 1,000 randomly selected images from each of the 10 classes (10 × 1,000 = 10,000).
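If you want to verify this structure yourself, the raw batch files can be inspected directly; a small sketch, assuming torchvision has already downloaded the data to ./cifar-10-batches-py:

import pickle

# Each raw batch is a pickled dict; b'data' holds the flattened images.
with open('./cifar-10-batches-py/data_batch_1', 'rb') as f:
    batch = pickle.load(f, encoding='bytes')
print(batch[b'data'].shape)    # (10000, 3072): 10000 images of 32*32*3
print(len(batch[b'labels']))   # 10000 labels, one per image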

                                                                                          (Figure: the 10 CIFAR-10 classes)

PyTorch also ships many built-in datasets in torchvision.datasets:

  class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
  class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)
  class torchvision.datasets.EMNIST(root, split, **kwargs)
  class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)
  class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)
  class torchvision.datasets.LSUN(root, classes='train', transform=None, target_transform=None)
  class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
  class torchvision.datasets.DatasetFolder(root, loader, extensions, transform=None, target_transform=None)
  class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
  class torchvision.datasets.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
  class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)
  class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)
  class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)
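All of these classes implement the same Dataset interface, which is why they are interchangeable behind a DataLoader; a small sketch with CIFAR10:

import torchvision
import torchvision.transforms as transforms

# Indexing any torchvision dataset returns an (image, label) pair.
dataset = torchvision.datasets.CIFAR10(root='./', train=True, download=True,
                                       transform=transforms.ToTensor())
img, label = dataset[0]
print(img.shape, label)    # torch.Size([3, 32, 32]) and an integer class index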

2. Model Training

     2.1 Model selection:

             You can define your own Net, or use one of the models provided by PyTorch's torchvision.models.

  import torchvision.models as models
  resnet18 = models.resnet18(pretrained=True)
  alexnet = models.alexnet(pretrained=True)
  squeezenet = models.squeezenet1_0(pretrained=True)
  vgg16 = models.vgg16(pretrained=True)
  densenet = models.densenet161(pretrained=True)
  inception = models.inception_v3(pretrained=True)
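These models are trained on ImageNet, so their final layer outputs 1000 classes; to reuse one for CIFAR-10 you would swap in a 10-way classifier. A sketch (note that the 32×32 inputs would normally be resized, e.g. with transforms.Resize(224), to match what the pretrained weights expect):

import torch.nn as nn
import torchvision.models as models

# Keep the pretrained features, replace the ImageNet classifier head.
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)   # 10 CIFAR-10 classes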

      In addition, PyTorch has just released its Hub feature; see https://pytorch.org/hub

model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)  # e.g. ResNet-18 from the vision repo

   2.2 Model visualization

   The code below was collected from around the web. Tip: you can visualize the model with the Netron tool by opening cifar10.pkl directly in it.

    Tool link: https://github.com/lutzroeder/Netron. (Figure: the model as rendered by Netron.)
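Netron gives the most detailed rendering for ONNX graphs; a minimal sketch of exporting the trained network, assuming net is the Net instance from section 2.4:

import torch

dummy = torch.randn(1, 3, 32, 32)                      # one CIFAR-10-sized input
torch.onnx.export(net.cpu(), dummy, 'cifar10.onnx')    # open this file in Netron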

 

   2.3 Training procedure:

                 1. Build the model
                 2. Iterate over the training data
                 3. Compute the forward loss
                 4. Backpropagate the error and update the network parameters (steps 2-4 are sketched as a loop below)
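A stripped-down sketch of that inner loop; net, trainloader, criterion, optimizer, and num_epochs are all defined in the full listing of section 2.4:

for epoch in range(num_epochs):
    for inputs, labels in trainloader:      # 2. iterate over the dataset
        optimizer.zero_grad()               # clear gradients from the last step
        outputs = net(inputs)               # forward pass through the model
        loss = criterion(outputs, labels)   # 3. compute the forward loss
        loss.backward()                     # 4. backpropagate the error...
        optimizer.step()                    #    ...and update the parameters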

   2.4 Parameter settings:

                  See the code below.

import torch                           # core PyTorch package
import torch.nn as nn
import torch.nn.functional as F
import torchvision                     # computer-vision toolkit built on torch
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import time
from visdom import Visdom

viz = Visdom(env='loss')               # visdom window for the live loss curve
x1, y1 = 0, 0
win = viz.line(
    X=np.array([x1]),
    Y=np.array([y1]),
    opts=dict(title='loss'))

# Parameter settings
batch_size = 50
start = time.time()

# 1. Preprocess the data.
# transforms.Compose chains several operations; here it combines the
# tensor conversion and the normalization into a single transform.
transform = transforms.Compose(
    [transforms.ToTensor(),                                    # convert to tensor
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # normalize

# 2. Load the data
# 2.1 Download the training set and preprocess it
trainset = torchvision.datasets.CIFAR10(root='./', train=True,
                                        download=True, transform=transform)
# 2.2 Load the training set, shuffling the image order
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
# 2.3 Download the test set and preprocess it
testset = torchvision.datasets.CIFAR10(root='./', train=False,
                                       download=True, transform=transform)
# 2.4 Load the test set; no shuffling is needed for testing
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
# 2.5 The labels, stored as an immutable tuple
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog',
           'frog', 'horse', 'ship', 'truck')
end = time.time()
print("Data preparation took %.2f s" % (end - start))

# 3. Build the network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)
        self.globalavgpool = nn.AvgPool2d(8, 8)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.dropout50 = nn.Dropout(0.5)
        self.dropout10 = nn.Dropout(0.1)
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = self.bn1(F.relu(self.conv1(x)))
        x = self.bn1(F.relu(self.conv2(x)))
        x = self.maxpool(x)
        x = self.dropout10(x)
        x = self.bn2(F.relu(self.conv3(x)))
        x = self.bn2(F.relu(self.conv4(x)))
        x = self.avgpool(x)
        x = self.dropout10(x)
        x = self.bn3(F.relu(self.conv5(x)))
        x = self.bn3(F.relu(self.conv6(x)))
        x = self.globalavgpool(x)
        x = self.dropout50(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

if __name__ == '__main__':
    net = Net()
    criterion = nn.CrossEntropyLoss()                   # cross-entropy loss
    optimizer = optim.Adam(net.parameters(), lr=0.001)  # 0.1 is far too high for Adam
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    for epoch in range(1):
        running_loss = 0.
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            print('[%d, %5d] loss: %.4f' % (epoch + 1, (i + 1) * batch_size, loss.item()))
            x1 += 1                        # one x-axis step per iteration
            viz.line(
                X=np.array([x1]),
                Y=np.array([loss.item()]),
                win=win,                   # win must stay the same to append
                update='append')
    print('Finished Training')
    torch.save(net, 'cifar10.pkl')
    # net = torch.load('cifar10.pkl')

    # Overall accuracy on the test set
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

    # Per-class accuracy on the test set
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(labels.size(0)):    # count every sample in the batch,
                label = labels[i]              # not just the first 4
                class_correct[label] += c[i].item()
                class_total[label] += 1
    for i in range(10):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))

2.5 Visualizing the training process

      Open an Anaconda Prompt and run the command below (installing via conda install visdom failed):

pip install visdom

     Start the server:

python -m visdom.server

     Then open the dashboard in a browser:

http://localhost:8097/
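If the page stays empty, you can check from Python that the client actually reaches the server:

from visdom import Visdom

viz = Visdom(env='loss')
print(viz.check_connection())   # True if http://localhost:8097 is reachable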

3. Training Results

         Training on the GPU really is fast: the run took about three and a half hours on a CPU (i3), while the GTX 1660 Ti produces a result in about three minutes.
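To confirm that the run is actually on the GPU rather than silently falling back to the CPU:

import torch

print(torch.cuda.is_available())          # True if a CUDA device is visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. 'GeForce GTX 1660 Ti'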

4. The Role of Batch Size

Batch_size = 100

Test results:

Accuracy of the network on the 10000 test images: 67 %

Accuracy of plane : 65 %

Accuracy of   car : 84 %

Accuracy of  bird : 52 %

Accuracy of   cat : 46 %

Accuracy of  deer : 44 %

Accuracy of   dog : 43 %

Accuracy of  frog : 79 %

Accuracy of horse : 78 %

Accuracy of  ship : 77 %

Accuracy of truck : 75 %

Batch_size = 50

Test results:

Accuracy of the network on the 10000 test images: 66 %

Accuracy of plane : 76 %

Accuracy of   car : 82 %

Accuracy of  bird : 37 %

Accuracy of   cat : 25 %

Accuracy of  deer : 56 %

Accuracy of   dog : 57 %

Accuracy of  frog : 72 %

Accuracy of horse : 67 %

Accuracy of  ship : 76 %

Accuracy of truck : 87 %

Batch_size = 10

Test results:

Accuracy of the network on the 10000 test images: 62 %

Accuracy of plane : 59 %

Accuracy of   car : 77 %

Accuracy of  bird : 49 %

Accuracy of   cat : 37 %

Accuracy of  deer : 50 %

Accuracy of   dog : 52 %

Accuracy of  frog : 69 %

Accuracy of horse : 73 %

Accuracy of  ship : 75 %

Accuracy of truck : 77 %

Conclusions and reflections (a sketch of the batch-size sweep follows the list):

  1. Within a certain range, a larger batch_size helps the model converge faster: a larger batch is a closer approximation of the overall structure of the training set, so the gradient direction at each iteration is more accurate and the network converges better.
  2. However, batch_size is not "the bigger the better": networks trained with large batches tend to generalize worse. The training and test sets have similar but not identical structure; a large batch helps convergence on the training set, but the model then fits the training distribution too closely, which inevitably weakens how well it captures the test distribution.
  3. As batch_size decreases, overall accuracy drops, but accuracy on some individual classes rises. A plausible explanation is that how closely each batch's distribution matches the training set's distribution changes the SGD descent direction; a smaller batch_size also means more iterations, which can make convergence more precise.
  4. The core of training is to build a sufficiently representative training set and to fit its structure with a model that still generalizes over non-salient features.
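The sweep behind the tables above can be scripted; a sketch, assuming trainset, testloader, and Net from section 2.4, with train_and_evaluate as a hypothetical helper wrapping the training and test loops there:

import torch

for bs in (100, 50, 10):
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs,
                                              shuffle=True, num_workers=2)
    net = Net()                    # fresh weights for a fair comparison
    train_and_evaluate(net, trainloader, testloader)   # hypothetical helper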

5. References

1. https://blog.csdn.net/Kansas_Jason/article/details/84503367

2. https://blog.csdn.net/shareviews/article/details/83094783 (recommended)

3. https://blog.csdn.net/leviopku/article/details/81980249 (the Netron visualization tool)

4. Morvan's tutorials: https://morvanzhou.github.io/

5. PyTorch Chinese site: https://ptorch.com/

6. PyTorch Chinese docs: https://ptorch.com/docs/1/

7. PyTorch Chinese forum: https://discuss.ptorch.com/

8. Deep-learning model visualization tools: https://blog.csdn.net/baidu_40840693/article/details/83006347

 
