当前位置:   article > 正文

CIFAR-10 数据集的 VGGNet 网络结构_cifar10 vggnet

cifar10 vggnet
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  class VGGbase(nn.Module):
      """Small VGG-style convolutional classifier for 3-channel images.

      The network stacks seven 3x3 convolution stages (each followed by batch
      normalisation and ReLU), interleaved with four 2x2 max-pooling layers,
      and ends in a single fully connected layer.

      Spatial sizes through the pooling layers:
          * 28x28 input: 28 -> 14 -> 7 -> 4 -> 2
          * 32x32 input: 32 -> 16 -> 8 -> 5 -> 2
      Both end in a ``512 x 2 x 2`` feature map, which is what ``self.fc``
      expects. Other input resolutions will generally not match the size of
      the fully connected layer.

      Attribute names and the layout of every ``nn.Sequential`` are kept
      stable (``conv1.0`` is the convolution, ``conv1.1`` the batch norm, ...)
      so that previously saved ``state_dict`` checkpoints still load.

      Args:
          num_classes: Number of output classes. Defaults to 10.
      """

      @staticmethod
      def _conv_bn_relu(in_channels, out_channels):
          """Return a 3x3 same-padding conv -> BatchNorm -> ReLU stage."""
          return nn.Sequential(
              nn.Conv2d(in_channels, out_channels,
                        kernel_size=3, stride=1, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(),
          )

      def __init__(self, num_classes=10):
          super(VGGbase, self).__init__()

          # Stage 1: 3 -> 64 channels, then halve the resolution.
          self.conv1 = self._conv_bn_relu(3, 64)
          self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

          # Stage 2: 64 -> 128 channels, then halve the resolution.
          self.conv2_1 = self._conv_bn_relu(64, 128)
          self.conv2_2 = self._conv_bn_relu(128, 128)
          self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

          # Stage 3: 128 -> 256 channels. padding=1 on the pool lets an odd
          # 7x7 map become 4x4 instead of being truncated to 3x3.
          self.conv3_1 = self._conv_bn_relu(128, 256)
          self.conv3_2 = self._conv_bn_relu(256, 256)
          self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)

          # Stage 4: 256 -> 512 channels, final pooling down to 2x2.
          self.conv4_1 = self._conv_bn_relu(256, 512)
          self.conv4_2 = self._conv_bn_relu(512, 512)
          self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)

          # batch x 512 x 2 x 2 --> batch x (512 * 4) --> batch x num_classes
          self.fc = nn.Linear(512 * 4, num_classes)

      def forward(self, x):
          """Run the network on a batch of images.

          Args:
              x: Tensor of shape ``(batch, 3, H, W)``; ``H = W = 28`` or
                 ``H = W = 32`` yield the feature size ``self.fc`` expects.

          Returns:
              Tensor of shape ``(batch, num_classes)`` holding per-class
              log-probabilities (``log_softmax`` over dimension 1).
          """
          out = self.max_pooling1(self.conv1(x))

          out = self.conv2_1(out)
          out = self.conv2_2(out)
          out = self.max_pooling2(out)

          out = self.conv3_1(out)
          out = self.conv3_2(out)
          out = self.max_pooling3(out)

          out = self.conv4_1(out)
          out = self.conv4_2(out)
          out = self.max_pooling4(out)

          # batch x c x h x w --> batch x n. flatten() copies when required,
          # so unlike view() it also accepts non-contiguous activations
          # (e.g. channels_last inputs).
          out = out.flatten(1)

          out = self.fc(out)
          return F.log_softmax(out, dim=1)  # batch x num_classes
  def VGGNet():
      """Factory: construct a new ``VGGbase`` with no arguments and return it."""
      model = VGGbase()
      return model

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/957492
推荐阅读
  

闽ICP备14008679号