
Training VGG11 on the CIFAR-10 Dataset: a PyTorch Implementation

The code ran for a little over 60 minutes on Kaggle and reached about 85% validation accuracy. Feeding a dummy 1×3×224×224 tensor through the network prints the following stage-by-stage output shapes:

Sequential output shape:	 torch.Size([1, 64, 112, 112])
Sequential output shape:	 torch.Size([1, 128, 56, 56])
Sequential output shape:	 torch.Size([1, 256, 28, 28])
Sequential output shape:	 torch.Size([1, 512, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
Flatten output shape:	 torch.Size([1, 25088])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])
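These shapes are easy to verify by hand: each of the five VGG blocks ends in a 2×2 max-pool that halves the spatial resolution, so the 224×224 input shrinks to 224 / 2^5 = 7, and flattening the final 512×7×7 feature map gives the 512 * 7 * 7 = 25088 features consumed by the first fully connected layer.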

import os
import copy
import time

import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
def try_gpu(i=0):
    """Return gpu(i) if it exists, otherwise fall back to cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
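# try_gpu() lets the same script run unchanged on a Kaggle GPU kernel or
# a CPU-only machine, e.g.:
#   device = try_gpu()     # cuda:0 if a GPU is visible, else cpu
#   device = try_gpu(1)    # a second GPU if present, else cpu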
batch_size = 64
path = './'

# CIFAR-10 images are 32x32; they are cropped/resized up to 224x224,
# the input size the VGG architecture expects.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),   # random crop, resized to 224x224
    transforms.RandomHorizontalFlip(),   # random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

traindir = os.path.join(path, 'train')
valdir = os.path.join(path, 'val')
train_set = torchvision.datasets.CIFAR10(
    traindir, train=True, transform=train_transform, download=True)
valid_set = torchvision.datasets.CIFAR10(
    valdir, train=False, transform=val_transform, download=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

dataloaders = {
    'train': train_loader,
    'valid': valid_loader,
    # 'test': dataloader_test
}
dataset_sizes = {
    'train': len(train_set),
    'valid': len(valid_set),
    # 'test': len(test_set)
}
print(dataset_sizes)
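# Optional sanity check: pull one batch and confirm the shapes match what
# the network expects -- inputs of [64, 3, 224, 224] and 64 integer labels.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)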
class Vgg11(nn.Module):
    def __init__(self):
        super().__init__()
        # VGG-11 channel progression: (num_convs, out_channels) per block
        conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
        conv_blks = []
        in_channels = 3
        # Convolutional part
        for (num_convs, out_channels) in conv_arch:
            conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
            in_channels = out_channels
        self.conv = nn.Sequential(*conv_blks)
        self.fc = nn.Sequential(
            nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096), nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 10))
        self.fn = nn.Flatten()

    def forward(self, x):
        out = self.conv(x)
        out = self.fn(out)
        out = self.fc(out)
        return out
# A VGG block: num_convs 3x3 conv layers followed by one 2x2 max-pooling layer
def vgg_block(num_convs, in_channels, out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)
# Walk a dummy input through the network and print each stage's output shape
X = torch.randn(1, 3, 224, 224)
net = Vgg11()
for layer in net.conv:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
X = net.fn(X)
print(net.fn.__class__.__name__, 'output shape:\t', X.shape)
for layer in net.fc:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
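# Optional: total parameter count. With this 10-way head, VGG-11 comes to
# roughly 129M parameters, the bulk of them in the first Linear(25088, 4096).
print('parameters:', sum(p.numel() for p in net.parameters()))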
def train(model, criterion, optimizer, scheduler, device, num_epochs,
          dataloaders, dataset_sizes):
    model = model.to(device)
    print('training on ', device)
    since = time.time()
    best_model_wts = None
    best_acc = 0.0
    for epoch in range(num_epochs):
        # Train for one epoch
        s = time.time()
        model, train_epoch_acc, train_epoch_loss = train_model(
            model, criterion, optimizer, dataloaders['train'],
            dataset_sizes['train'], device)
        print('Epoch {}/{} - train Loss: {:.4f} Acc: {:.4f} Time: {:.1f}s'
              .format(epoch + 1, num_epochs, train_epoch_loss,
                      train_epoch_acc, time.time() - s))
        # Evaluate on the validation set
        s = time.time()
        val_epoch_acc, val_epoch_loss = val_model(
            model, criterion, dataloaders['valid'],
            dataset_sizes['valid'], device)
        print('Epoch {}/{} - valid Loss: {:.4f} Acc: {:.4f} Time: {:.1f}s'
              .format(epoch + 1, num_epochs, val_epoch_loss,
                      val_epoch_acc, time.time() - s))
        # Snapshot the best weights so far (deepcopy, so later training
        # steps don't mutate the saved state dict in place)
        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
        # Learning-rate scheduler (None is passed in below, so this stays off)
        # if scheduler is not None:
        #     scheduler.step()
        # Record the curves for plotting
        train_losses.append(train_epoch_loss)
        train_acc.append(train_epoch_acc)
        val_losses.append(val_epoch_loss)
        val_acc.append(val_epoch_acc)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    # model.load_state_dict(best_model_wts)  # uncomment to keep the best epoch
    return model
def train_model(model, criterion, optimizer, dataloader, dataset_size, device):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        # Move inputs and labels to the training device
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backpropagate and update the weights
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so scale by the batch size
        # before dividing by the dataset size
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()
    epoch_loss = running_loss / dataset_size
    epoch_acc = running_corrects / dataset_size
    return model, epoch_acc, epoch_loss
def val_model(model, criterion, dataloader, dataset_size, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            # Compute the loss
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
    epoch_loss = running_loss / dataset_size
    epoch_acc = running_corrects / dataset_size
    return epoch_acc, epoch_loss
val_losses, val_acc = [], []
train_losses, train_acc = [], []
lr, num_epochs = 0.01, 10
model = Vgg11()
criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=lr)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
model = train(model, criterion, optimizer, None,
              try_gpu(), num_epochs, dataloaders, dataset_sizes)
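# Optional: persist the trained weights so the model can be reloaded later
# without retraining. The filename is an assumption, not part of the
# original script.
torch.save(model.state_dict(), 'vgg11_cifar10.pt')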
# Plot the loss curves
plt.plot(range(1, len(train_losses) + 1), train_losses, 'b', label='training loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, 'r', label='val loss')
plt.legend()
# Plot the accuracy curves on a separate figure so the scales don't mix
plt.figure()
plt.plot(range(1, len(train_acc) + 1), train_acc, 'b--', label='train accuracy')
plt.plot(range(1, len(val_acc) + 1), val_acc, 'r--', label='val accuracy')
plt.legend()
plt.show()
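
To round things off, here is a minimal inference sketch (an addition, not part of the original post): it takes one already-transformed image from the validation set, runs it through the trained model, and maps the argmax logit back to a CIFAR-10 class name.

# Minimal inference sketch -- assumes the script above has run, so
# `model`, `valid_set`, and `try_gpu` are already in scope.
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
device = try_gpu()
model.eval()
image, label = valid_set[0]                        # val_transform already applied
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device))  # add a batch dimension
pred = logits.argmax(dim=1).item()
print('predicted:', classes[pred], '| actual:', classes[label])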
