当前位置:   article > 正文

卷积神经网络（多分类问题，PyTorch）_卷积神经网络损失函数 多分类

卷积神经网络损失函数 多分类
  1. # 手写数字识别 神经网络处理 高级处理
  2. import torch
  3. import torch.nn as nn
  4. # 数据集处理
  5. from torchvision import transforms
  6. from torchvision import datasets
  7. from torch.utils.data import DataLoader
  8. # 函数 激活函数等
  9. import torch.nn.functional as F
  10. # 优化器包
  11. import torch.optim as optim
  12. # 分批
  13. batch_size = 64
  14. # 1. 数据处理
  15. transform = transforms.Compose([
  16. transforms.ToTensor(),
  17. transforms.Normalize((0.1307, ), (0.3081, ))
  18. ])
  19. train_dataset = datasets.MNIST(root='../dataset/mnist/',
  20. train=True,
  21. download=True,
  22. transform=transform)
  23. test_dataset = datasets.MNIST(root='../dataset/mnist/',
  24. train=False,
  25. download=True,
  26. transform=transform)
  27. train_loader = DataLoader(test_dataset,
  28. shuffle=True,
  29. batch_size=batch_size)
  30. test_loader = DataLoader(test_dataset,
  31. shuffle=False,
  32. batch_size=batch_size)
  33. # 数据为1 * 28 * 28
  34. # 2. 建立模型
  35. class InceptionA(nn.Module):
  36. def __init__(self, in_channels):
  37. super(InceptionA, self).__init__()
  38. '''初始化'''
  39. """初始化"""
  40. # 池化分支
  41. self.branch_pool = nn.Conv2d(in_channels, 24, kernel_size=1)
  42. # 1 * 1 分支
  43. self.branch1x1 = nn.Conv2d(in_channels, 16, kernel_size=1)
  44. # 5 * 5 分支
  45. self.branch5x5_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
  46. self.branch5x5_2 = nn.Conv2d(16, 24, kernel_size=5, padding=2)
  47. # 3 * 3分支
  48. self.branch3x3_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
  49. self.branch3x3_2 = nn.Conv2d(16, 24, kernel_size=3, padding=1)
  50. self.branch3x3_3 = nn.Conv2d(24, 24, kernel_size=3, padding=1)
  51. def forward(self, x):
  52. branch_pool = F.avg_pool2d(x,
  53. kernel_size=3,
  54. stride=1,
  55. padding=1)
  56. branch_pool = self.branch_pool(branch_pool)
  57. branch1x1 = self.branch1x1(x)
  58. branch5x5 = self.branch5x5_1(x)
  59. branch5x5 = self.branch5x5_2(branch5x5)
  60. branch3x3 = self.branch3x3_1(x)
  61. branch3x3 = self.branch3x3_2(branch3x3)
  62. branch3x3 = self.branch3x3_3(branch3x3)
  63. outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
  64. # dim 1纬度
  65. return torch.cat(outputs, dim=1)
  66. class Net(nn.Module):
  67. def __init__(self):
  68. super(Net, self).__init__()
  69. self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
  70. self.conv2 = torch.nn.Conv2d(88, 20, kernel_size=5)
  71. self.incep1 = InceptionA(in_channels=10)
  72. self.incep2 = InceptionA(in_channels=20)
  73. self.mp = nn.MaxPool2d(2)
  74. self.fc = nn.Linear(1408, 10)
  75. def forward(self, x):
  76. in_size = x.size(0)
  77. x = F.relu(self.mp(self.conv1(x)))
  78. x = self.incep1(x)
  79. x = F.relu(self.mp(self.conv2(x)))
  80. x = self.incep2(x)
  81. x = x.view(in_size, -1)
  82. x = self.fc(x)
  83. return x
  84. model = Net()
  85. # 3.损失函数和优化器 交叉熵损失
  86. criterion = nn.CrossEntropyLoss()
  87. optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
  88. # 4.循环训练
  89. def train(epoch):
  90. running_loss = 0.0
  91. for batch_idx, data in enumerate(train_loader):
  92. inputs, target = data
  93. optimizer.zero_grad()
  94. outputs = model(inputs)
  95. loss = criterion(outputs, target)
  96. loss.backward()
  97. optimizer.step()
  98. running_loss += loss.item()
  99. if batch_idx % 300 == 0:
  100. print('[%d,%d] loss: %.10f' % (epoch+1, batch_idx+1, running_loss / 300))
  101. running_loss = 0.0
  102. # 测试验证
  103. def test():
  104. correct = 0
  105. total = 0
  106. with torch.no_grad(): # 不会再进行梯度
  107. for data in test_loader:
  108. images, labels = data
  109. outputs = model(images)
  110. _, predicted = torch.max(outputs.data, dim=1)
  111. total+=labels.size(0)
  112. correct+=(predicted==labels).sum().item()
  113. print("Accuracy on test set: %d %%" % (100 * correct / total))
  114. # 程序入口处
  115. if __name__ == '__main__':
  116. for epoch in range(10):
  117. train(epoch)
  118. test()
  119. print("训练结束...")

 

声明：本文内容由网友自发贡献，不代表【wpsshop博客】立场，版权归原作者所有，本站不承担相应法律责任。如您发现有侵权的内容，请联系我们。转载请注明出处：https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/520657
推荐阅读
相关标签
  

闽ICP备14008679号