
PyTorch Deep Learning: Handwritten Digit Classification (personally tested)


This is just a quick write-up for now; I will add more details later when I have time.

1. Define the CNN model

    import torch
    import torch.nn as nn

    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv = nn.Sequential(
                # 2D convolution over the h and w dimensions: the input image has 1 channel,
                # the output has 32 channels, and the kernel size is 5
                nn.Conv2d(1, 32, kernel_size=5, padding=2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
            # after the conv block the 28*28 input becomes 14*14 with 32 channels;
            # the fully connected layer maps it to the 10 digit classes
            self.fc = nn.Linear(14 * 14 * 32, 10)

        def forward(self, x):
            out = self.conv(x)
            out = out.view(out.size(0), -1)
            out = self.fc(out)
            return out
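
A quick shape check helps confirm that the flattened size fed to the fully connected layer (14 * 14 * 32) is correct. This is just a sketch with a randomly generated dummy batch; it assumes the class above is saved in model.py, since the training script below imports it from there.

    import torch
    from model import CNN

    model = CNN()
    dummy = torch.randn(4, 1, 28, 28)   # a fake batch of 4 grayscale 28*28 images
    out = model(dummy)
    print(out.shape)                    # expected: torch.Size([4, 10])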

2. Build your own dataset

    import os
    from PIL import Image
    from torch.utils.data import Dataset

    # map the sub-directory name (the digit) to an integer label
    species = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4,
               '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}

    class MyDataset(Dataset):
        def __init__(self, root, transform=None):
            # root: 'mnist\\test' or 'mnist\\train'
            self.root = root
            self.transform = transform
            self.data = []
            # the sub-directories are '0', '1', '2', '3', ...
            for sub_root in os.listdir(self.root):
                # all image file names inside the sub-directory
                sub_image_name_list = os.listdir(os.path.join(self.root, sub_root))
                for sub_image_name in sub_image_name_list:
                    # full path of each image
                    image_path = os.path.join(self.root, sub_root, sub_image_name)
                    # the label is the name of the sub-directory
                    label = species[sub_root]
                    # store (image path, label) tuples
                    self.data.append((image_path, label))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            image_path, label = self.data[index]
            image_original = Image.open(image_path).convert('RGB')
            image_tensor = self.transform(image_original)
            return image_tensor, label

    class MyDataset_pre(Dataset):
        def __init__(self, root, transform=None):
            # root: 'test_images'
            self.root = root
            self.transform = transform
            self.data = []
            for image_name in os.listdir(self.root):
                image_path = os.path.join(self.root, image_name)
                self.data.append(image_path)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            image_path = self.data[index]
            image_original = Image.open(image_path).convert('RGB')
            image_tensor = self.transform(image_original)
            return image_tensor
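
A minimal usage sketch, assuming the images are organised by digit into sub-folders such as mnist1_new\train\0\..., mnist1_new\train\1\..., and so on (these are the paths used by the training script; adjust them to your own layout):

    from torchvision import transforms
    from preprocess_dataset import MyDataset

    trans = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()])
    train_dataset = MyDataset('mnist1_new\\train', transform=trans)
    image_tensor, label = train_dataset[0]
    print(len(train_dataset), image_tensor.shape, label)   # the tensor shape should be [1, 28, 28]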

3. Train the model

    import torch
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from preprocess_dataset import MyDataset
    from model import CNN

    BATCH_SIZE = 32
    EPOCHS = 5
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()])
    # trans = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

    train_dataset = MyDataset('mnist1_new\\train', transform=trans)
    test_dataset = MyDataset('mnist1_new\\test', transform=trans)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # model = torch.load("model/mnist_model_nn.pkl")
    model = CNN()
    # model = NeuralNetwork()
    # model = VGG16().vgg16_model()
    net = model.to(DEVICE)

    # loss function
    criterion = torch.nn.CrossEntropyLoss()
    # optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    # training: one full pass over all training samples is one epoch
    acc = 0.0   # best test accuracy seen so far, used to decide when to save the model
    for epoch in range(EPOCHS):
        model.train()
        for i, data in enumerate(train_loader):
            images, labels = data
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss = criterion(outputs, labels)
            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch is {}, batch is {}/{}, loss is {}".format(
                epoch + 1, i, len(train_loader), loss.item()))

        # eval/test: compute the accuracy on the test set
        loss_test = 0.0
        accuracy = 0.0
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                images, labels = data
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = net(images)
                loss_test += criterion(outputs, labels).item()
                _, pred = outputs.max(1)
                # count the correct predictions to compute the accuracy
                accuracy += (pred == labels).sum().item()
        accuracy = accuracy / len(test_dataset)
        loss_test = loss_test / len(test_loader)
        # print the accuracy and loss for this epoch
        print("epoch is {}, accuracy is {}, loss test is {}".format(epoch + 1, accuracy, loss_test))
        # keep the checkpoint with the best test accuracy (the model/ folder must already exist)
        if accuracy > acc:
            acc = accuracy
            torch.save(net, "model/mnist_model_nn.pkl")
    print('accuracy', acc)
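
The script above saves the whole model object with torch.save(net, ...), which is what the prediction script later loads. A common alternative (not what this post does) is to save only the state_dict, which is less fragile when the model code changes; a minimal sketch, with a hypothetical file name:

    # alternative checkpointing sketch: save only the weights instead of the whole pickled model
    torch.save(net.state_dict(), "model/mnist_model_state.pth")   # hypothetical file name

    # later: rebuild the model and load the weights back
    model = CNN()
    model.load_state_dict(torch.load("model/mnist_model_state.pth", map_location="cpu"))
    model.eval()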

Before step 4: transform the image and cut a string of digits into individual digit images, from left to right. (For reference only.)

    import os
    import cv2
    import numpy as np
    from PIL import Image
    from torchvision import transforms

    def sort_contours(cnts, method='left-to-right'):
        # sort contours from left to right (default)
        reverse = False
        i = 0
        # handle sorting in reverse
        if method == 'right-to-left' or method == 'bottom-to-top':
            reverse = True
        # handle sorting against y rather than x of the bounding box
        if method == 'bottom-to-top' or method == 'top-to-bottom':
            i = 1
        boundingBoxes = [cv2.boundingRect(c) for c in cnts]
        (cnts, boundingBoxes) = zip(*sorted(zip(cnts, boundingBoxes),
                                            key=lambda b: b[1][i], reverse=reverse))
        return (cnts, boundingBoxes)

    def cut_image_sign():
        '''
        Cut the digit string in the middle of the image into individual square crops,
        black digits on a white background.
        '''
        root_dir = 'output_me\\'
        for im_name in os.listdir(root_dir):
            # the input image is roughly 149 x 341 x 3
            image_writer_recongnize = cv2.imread(os.path.join(root_dir, im_name))
            h, w = image_writer_recongnize.shape[:2]
            SIZE = 138
            w_size = 256
            image_writer_recongnize = cv2.resize(image_writer_recongnize, (w_size, SIZE))
            # cv2.imshow('image_writer_recongnize', image_writer_recongnize)
            # cv2.waitKey()
            gray_new = cv2.cvtColor(image_writer_recongnize, cv2.COLOR_BGR2GRAY)
            # cv2.imshow('gray_new', gray_new)
            # cv2.waitKey()
            threshold, adaptive_image_1 = cv2.threshold(gray_new, 100, 255, cv2.THRESH_BINARY)
            adaptive_image_1 = cv2.dilate(adaptive_image_1, (15, 15))
            # cv2.imshow('adaptive_image_1', adaptive_image_1)
            # cv2.waitKey()
            adaptive_image_copy = adaptive_image_1.copy()
            # findContours returns (contours, hierarchy) in OpenCV 4.x
            cnts_1, hierarchy = cv2.findContours(adaptive_image_1, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
            # contourPic_1 = cv2.drawContours(adaptive_image_copy, cnts_1, -1, (0, 0, 255))
            (cnts_new, boundingboxes) = sort_contours(cnts_1)
            count1 = 0
            for c in cnts_new:
                # keep only contours whose area looks like a single digit
                if 300 < cv2.contourArea(c) < 4000:
                    count1 = count1 + 1
                    x, y, w, h = cv2.boundingRect(c)
                    image_result = image_writer_recongnize[y:y + h, x:x + w]
                    image_path = os.path.join('output_me_cut', 'me_cut{}.png'.format(count1))
                    cv2.imwrite(image_path, image_result)

    def image_address():
        '''
        Resize each crop to 28*28 and invert it so the digits become
        white on a black background, matching the training data.
        '''
        root_dir = 'output_me_cut\\'
        count3 = 0
        for im_name in os.listdir(root_dir):
            count3 = count3 + 1
            image_writer_recongnize = Image.open(os.path.join(root_dir, im_name))
            image_1 = transforms.Resize((28, 28))(image_writer_recongnize)
            image2 = np.array(image_1)
            image1 = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)
            ret, thresh1 = cv2.threshold(image1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            thresh2 = cv2.bitwise_not(thresh1)
            thresh3 = cv2.erode(thresh2, (15, 15))
            cv2.imwrite(os.path.join('output_me_cut_black', '0_{}.png'.format(count3)), thresh2)
            # cv2.imshow('image_writer_recongnize', thresh3)
            # cv2.waitKey()
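
The two functions above only define the pipeline; a minimal driver sketch is shown below. The folder names output_me, output_me_cut and output_me_cut_black are the ones hard-coded in the functions; the os.makedirs calls are added here because cv2.imwrite will silently fail if the target folder does not exist.

    if __name__ == '__main__':
        # make sure the output folders exist before writing the crops
        os.makedirs('output_me_cut', exist_ok=True)
        os.makedirs('output_me_cut_black', exist_ok=True)
        cut_image_sign()    # step 1: cut the digit string into single-digit crops
        image_address()     # step 2: resize to 28*28 and invert to white-on-black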

4. Predict single images

    import torch
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from preprocess_dataset import MyDataset_pre

    species = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    model = torch.load("model\\mnist_model_nn.pkl", map_location=torch.device("cpu"))
    trans = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor(), transforms.Grayscale()])

    # the images to predict should be preprocessed the same way as the training data:
    # white digits on a black background, 28*28 in size
    predict_dataset = MyDataset_pre('output_me_cut_black', transform=trans)
    predict_loader = DataLoader(predict_dataset, batch_size=32)

    model.eval()
    predict = []
    with torch.no_grad():
        for data in predict_loader:
            images = data
            output = model(images)
            _, pred = torch.max(output, 1)
            for p in pred:
                class_name = species[int(p.item())]
                predict.append(class_name)

    # concatenate the predicted digits into one string and print it
    s = ''.join(predict)
    print(s)
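
If you only want to classify one image file rather than a whole folder, the Dataset/DataLoader can be skipped entirely. A minimal sketch (single_digit.png is just a placeholder file name, and the image is assumed to already be a white-on-black digit like the training data):

    import torch
    from PIL import Image
    from torchvision import transforms

    species = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    trans = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor(), transforms.Grayscale()])

    model = torch.load("model\\mnist_model_nn.pkl", map_location=torch.device("cpu"))
    model.eval()

    image = Image.open('single_digit.png').convert('RGB')    # placeholder file name
    image_tensor = trans(image).unsqueeze(0)                  # add the batch dimension: [1, 1, 28, 28]
    with torch.no_grad():
        output = model(image_tensor)
    print(species[int(output.argmax(1).item())])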
