
Training VGG16 on CIFAR-10

The complete script below builds a VGG16 with batch normalization, trains it on CIFAR-10, and keeps the checkpoint with the best test accuracy:
import os
import sys
from typing import List, cast

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from tqdm import tqdm

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class Cinic10(Dataset):
    """Loads CINIC-10 images from an ImageFolder-style directory into memory."""

    def __init__(self, img_dir):
        self.img_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.label2id = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                         'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
        self.id2label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                         5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
        self.img_dir = img_dir
        # Per-channel mean/std of the CINIC-10 dataset (values on the [0, 1] scale)
        self.cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        self.cinic_std = [0.24205776, 0.23828046, 0.25874835]
        self.transform = transforms.Compose([transforms.Normalize(mean=self.cinic_mean, std=self.cinic_std)])
        self.countPerLabel = len(
            [name for name in os.listdir(os.path.join(self.img_dir, self.img_labels[0]))
             if os.path.isfile(os.path.join(os.path.join(self.img_dir, self.img_labels[0]), name))])
        self.len = len(self.img_labels) * self.countPerLabel
        self.X_Y = []
        for label in self.img_labels:
            img_path = os.path.join(self.img_dir, label)
            images_files = [name for name in os.listdir(img_path) if os.path.isfile(os.path.join(img_path, name))]
            label_id = self.label2id[label]
            for images_file in images_files:
                image = read_image(os.path.join(img_path, images_file))
                # Some images are stored as a single grayscale channel; replicate it to get 3 x 32 x 32
                if image.shape != torch.Size([3, 32, 32]):
                    image = torch.cat([image, image, image])
                image = image.type(torch.float32) / 255.0  # scale pixels to [0, 1] so the statistics above apply
                image = self.transform(image)
                self.X_Y.append([image, label_id])

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        [image, label] = self.X_Y[idx]
        return image, label
class Cifar10(Dataset):
    """Loads CIFAR-10 images from an ImageFolder-style directory into memory."""

    def __init__(self, img_dir):
        self.img_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.label2id = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                         'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
        self.id2label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                         5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
        self.img_dir = img_dir
        # Standard ImageNet normalization statistics (attribute names kept from the class above)
        self.cinic_mean = [0.485, 0.456, 0.406]
        self.cinic_std = [0.229, 0.224, 0.225]
        self.transform = transforms.Compose([transforms.Normalize(mean=self.cinic_mean, std=self.cinic_std)])
        self.countPerLabel = len(
            [name for name in os.listdir(os.path.join(self.img_dir, self.img_labels[0]))
             if os.path.isfile(os.path.join(os.path.join(self.img_dir, self.img_labels[0]), name))])
        self.len = len(self.img_labels) * self.countPerLabel
        self.X_Y = []
        for label in self.img_labels:
            img_path = os.path.join(self.img_dir, label)
            images_files = [name for name in os.listdir(img_path) if os.path.isfile(os.path.join(img_path, name))]
            label_id = self.label2id[label]
            for images_file in images_files:
                image = read_image(os.path.join(img_path, images_file))
                # Replicate single-channel (grayscale) images to 3 channels
                if image.shape != torch.Size([3, 32, 32]):
                    image = torch.cat([image, image, image])
                image = image.type(torch.float32) / 255.0  # scale pixels to [0, 1] before normalizing
                image = self.transform(image)
                self.X_Y.append([image, label_id])

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        [image, label] = self.X_Y[idx]
        return image, label
traindataset = Cifar10('CIFAR10/train')
traindataloader = DataLoader(traindataset, batch_size=256, shuffle=True)
testdataset = Cifar10('CIFAR10/test')
testdataloader = DataLoader(testdataset, batch_size=256)
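# Both Dataset classes expect an ImageFolder-style layout: one subdirectory per class
# under img_dir, with the image files inside. The file names below only illustrate
# that assumption; they are not paths taken from the original post:
#
#   CIFAR10/train/airplane/0001.png
#   CIFAR10/train/airplane/0002.png
#   ...
#   CIFAR10/test/truck/9999.png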
def make_layers():
    # VGG16 feature extractor: integers are 3x3 conv output channels, "M" is a 2x2 max pool
    layers: List[nn.Module] = []
    in_channels = 3
    cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
class VGG(nn.Module):
    def __init__(self, features, num_classes, dropout):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
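# Shape check for the 32x32 CIFAR-10 inputs used here: the five max pools reduce
# 32 -> 16 -> 8 -> 4 -> 2 -> 1, so self.features outputs a 512 x 1 x 1 map. The
# adaptive average pool then resizes it to 512 x 7 x 7 (each output cell averages
# the single input cell), giving the classifier its 512 * 7 * 7 = 25088 input features.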
# Reference copy of the configuration hard-coded in make_layers() (not used elsewhere)
VGG16_layers = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]
device = torch.device('cuda:1')  # second GPU; change to 'cuda:0' or 'cpu' as needed
net = VGG(make_layers(), len(classes), 0.5)
net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate to the initial LR (0.05) halved every 30 epochs."""
    lr = 0.05 * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
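# The same schedule could also be expressed with PyTorch's built-in StepLR scheduler
# (a possible alternative, not what the script above uses):
#   scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
# with scheduler.step() called once per epoch instead of adjust_learning_rate().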
def train(dataloader, net, criterion, optimizer, device):
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data in tqdm(dataloader, desc='training...', file=sys.stdout):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()
    return running_loss / total, correct / total
def evaluate(dataloader, net, criterion, device):
    net.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(dataloader, desc='evaluating...', file=sys.stdout):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
    return running_loss / total, correct / total
n_epochs = 300
best_valid_acc = 0
for epoch in range(n_epochs):
    adjust_learning_rate(optimizer, epoch)
    train_loss, train_acc = train(traindataloader, net, criterion, optimizer, device)
    valid_loss, valid_acc = evaluate(testdataloader, net, criterion, device)
    print(f'epoch: {epoch + 1}')
    print(f'train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}')
    print(f'valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}')
    if valid_acc > best_valid_acc:
        print(f'{valid_acc:.3f} is better than {best_valid_acc:.3f}, best valid acc is {valid_acc:.3f}')
        best_valid_acc = valid_acc
        # Keep only the checkpoint with the best test accuracy (the target directory must already exist)
        torch.save(net.state_dict(), 'CV/CIFAR10/VGG16.pth')
    else:
        print(f'best valid acc is {best_valid_acc:.3f}')

# Reload the best checkpoint into a fresh model and confirm its accuracy
net2 = VGG(make_layers(), len(classes), 0.5)
net2 = net2.to(device)
net2.load_state_dict(torch.load('CV/CIFAR10/VGG16.pth'))
valid_loss, valid_acc = evaluate(testdataloader, net2, criterion, device)
print(f'best model valid acc: {valid_acc:.3f}')
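Once training has finished, the saved weights can also be used for single-image prediction. The sketch below is only an illustration under stated assumptions: it reuses VGG, make_layers, classes, and device from the script above, applies the same preprocessing as the Cifar10 class, and the image path is a placeholder rather than a file from the original post.

import torch
import torchvision.transforms as transforms
from torchvision.io import read_image

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

model = VGG(make_layers(), len(classes), 0.5).to(device)
model.load_state_dict(torch.load('CV/CIFAR10/VGG16.pth', map_location=device))
model.eval()

img = read_image('CIFAR10/test/cat/0001.png')       # placeholder path, any 32x32 test image
img = normalize(img.type(torch.float32) / 255.0)    # same preprocessing as the Cifar10 class
img = img.unsqueeze(0).to(device)                   # add a batch dimension

with torch.no_grad():
    pred = model(img).argmax(dim=1).item()
print(classes[pred])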
