当前位置:   article > 正文

Pytorch关于CIFAR-10测试完整代码

Pytorch关于CIFAR-10测试完整代码

 

  1. #_*_ coding:utf-8 _*_
  2. # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
  3. import torch
  4. from torch import nn
  5. from torch.nn import Conv2d,MaxPool2d,Linear,Sequential,Flatten
  6. from torch.utils.data import DataLoader
  7. import torch.optim.optimizer
  8. import torchvision
  9. from torch.utils.tensorboard import SummaryWriter
  10. print(SummaryWriter)
  11. class Model(nn.Module):
  12. def __init__(self):
  13. super().__init__()
  14. self.covn1=Conv2d(in_channels=3,out_channels=32,kernel_size=5,padding=2)
  15. self.maxpool1=MaxPool2d(kernel_size=2)
  16. self.covn2=Conv2d(in_channels=32,out_channels=32,kernel_size=5,padding=2)
  17. self.maxpool2=MaxPool2d(kernel_size=2)
  18. self.covn3=Conv2d(in_channels=32,out_channels=64,kernel_size=5,padding=2)
  19. self.maxpool3=MaxPool2d(kernel_size=2)
  20. self.flaten=Flatten()
  21. self.linear1=Linear(in_features=1024,out_features=64)
  22. self.linear2=Linear(in_features=64,out_features=10)
  23. def forward(self,x):
  24. x= self.covn1(x)
  25. x=self.maxpool1(x)
  26. x=self.covn2(x)
  27. x=self.maxpool2(x)
  28. x=self.covn3(x)
  29. x=self.maxpool3(x)
  30. x=self.flaten(x)
  31. x=self.linear1(x)
  32. x=self.linear2(x)
  33. return x
  34. if __name__ == "__main__":
  35. model = Model()
  36. print(model)
  37. #下载数据集
  38. train_data = torchvision.datasets.CIFAR10('../CIFAR-10/',train=True,transform=torchvision.transforms.ToTensor(),download=False)
  39. test_data = torchvision.datasets.CIFAR10('../CIFAR-10/',train=False,transform=torchvision.transforms.ToTensor(),download=False)
  40. #数据加载器
  41. train_data_loader = DataLoader(train_data,batch_size=64)
  42. test_data_loader = DataLoader(test_data,batch_size=64)
  43. #损失函数
  44. loss = nn.CrossEntropyLoss()
  45. #优化器
  46. optim = torch.optim.SGD(model.parameters(),lr=0.01)
  47. #创建可视化
  48. write = SummaryWriter('CIFAR_logs')
  49. #开始训练数据
  50. eplo =10
  51. train_step = 0
  52. for i in range(10):
  53. print("开始训练:".format(i))
  54. for img,laber in train_data_loader:
  55. input_d = model(img)
  56. loss_fn =loss(input_d,laber)
  57. optim.zero_grad() #梯度清零
  58. loss_fn.backward() #反向计算
  59. optim.step()#更新梯度
  60. train_step += 1
  61. if train_step%100==0:
  62. print("训练次数:{}".format(train_step),"损失:{}".format(loss_fn.item()))
  63. write.add_scalar('train_loss',loss_fn.item(),train_step)
  64. total_accuracy = 0
  65. total_test_loss = 0
  66. total_test_step = 0
  67. with torch.no_grad():
  68. for imgs,labler in test_data_loader:
  69. output = model(imgs)
  70. loss_re=loss(output,labler)
  71. accuracy = (output.argmax(1) == labler).sum()
  72. total_accuracy += accuracy
  73. total_test_loss += loss_re
  74. total_test_step += 1
  75. print("测试集上的正确率{}".format(total_accuracy/len(test_data)))
  76. print("测试集上的损失{}".format(total_test_loss))
  77. write.add_scalar("test_loss",total_test_loss.item(),total_test_step)
  78. torch.save(model,"model10")
  79. write.close()
  80. # tensorboard --logdir=CIFAR_logs --port=2017

预测代码:

  1. from torchvision import datasets, transforms
  2. import numpy as np
  3. from PIL import Image
  4. import torch
  5. import torch.nn.functional as F
  6. from cov01 import Model
  7. classes = ('plane', 'car', 'bird', 'cat',
  8. 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  9. if __name__ == '__main__':
  10. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  11. model = torch.load('model10') # 加载模型
  12. model = model.to(device)
  13. model.eval() # 把模型转为test模式
  14. img = Image.open("../bird.jpg")
  15. trans = transforms.Compose(
  16. [
  17. transforms.CenterCrop(32),
  18. transforms.ToTensor(),
  19. # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
  20. ])
  21. img = trans(img)
  22. img = img.to(device)
  23. img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
  24. # 扩展后,为[112828]
  25. output = model(img)
  26. prob = F.softmax(output, dim=1) # prob是10个分类的概率
  27. print(prob)
  28. value, predicted = torch.max(output.data, 1) #按照维度返回最大概率dim = 0 表示按列求最大值,并返回最大值的索引,dim = 1 表示按行求最大值,并返回最大值的索引
  29. print(predicted.item())
  30. print(value)
  31. pred_class = classes[predicted.item()]
  32. print(pred_class)
  33. '''
  34. 记住:
  35. torch.max()[0], 只返回最大值的每个数
  36. troch.max()[1], 只返回最大值的每个索引
  37. torch.max()[1].data 只返回variable中的数据部分(去掉Variable containing:)
  38. torch.max()[1].data.numpy() 把数据转化成numpy ndarry
  39. torch.max()[1].data.numpy().squeeze() 把数据条目中维度为1 的删除掉
  40. torch.max(tensor1,tensor2) element-wise 比较tensor1 和tensor2 中的元素,返回较大的那个值
  41. '''

E:\开发工具\pythonProject\studyLL\venv\Scripts\python.exe E:/开发工具/pythonProject/studyLL/pytorch01/predict.py
<class 'torch.utils.tensorboard.writer.SummaryWriter'>
tensor([[0.1517, 0.0265, 0.3715, 0.1244, 0.1860, 0.0556, 0.0084, 0.0167, 0.0457,
         0.0134]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
2
tensor([1.9028], device='cuda:0')
bird

Process finished with exit code 0

 

成功预测为bird 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/165744
推荐阅读
相关标签
  

闽ICP备14008679号