当前位置:   article > 正文

手写数字识别:基于决策树算法,KNN算法,支持向量机(SVM)算法与卷积神经网络(CNN)算法_基于决策树算法的手写数字识别

基于决策树算法的手写数字识别

目录

手写数字识别问题介绍

理论价值

研究难点

MNIST数据集介绍

数据预处理

应用算法

KNN算法

决策树算法

支持向量机算法

卷积神经网络

模型评估

决策树算法

KNN算法

SVM算法

CNN

模型比较

参考代码

决策树

KNN

SVM

CNN

绘图

main


手写数字识别问题介绍

手写数字识别是模式识别学科的一个传统研究领域,是光学字符识别技术的一个分支。这类研究主要研究如何利用电子计算机自动辨认书写在纸张上的阿拉伯数字。随着信息化的发展,手写数字识别的应用日益广泛,研究高识别率、零误识率和低拒识率的高速识别算法具有重要意义。本报告将对MNIST数据集中的手写数据进行识别,此工作将对不同数据挖掘方法进行构建与评估。期望在手写数据识别问题上获得较好的结果。

理论价值

由于手写数字识别本身的特点,对它的研究有重要的理论价值:

(1)阿拉伯数字是唯一被世界各国通用的符号,对手写体数字识别的研究基本上与文化背景无关,各地的研究工作者基于同一平台开展工作,有利于研究的比较和探讨。

(2)手写数字识别应用广泛,如邮政编码自动识别,税表系统和银行支票自动处理等。这些工作以前需要大量的手工录入,投入的人力物力较多,劳动强度较大。手写数字识别的研究适应了无纸化办公的需要,能大大提高工作效率。

(3)由于数字类别只有10个,较其他字符识别率较高,可用于验证新的理论和做深入的分析研究。许多机器学习和模式识别领域的新理论和算法都是先用手写数字识别进行检验,验证理论的有效性,然后才应用到更复杂的领域当中。这方面的典型例子就是人工神经网络和支持向量机。

(4)手写数字的识别方法很容易推广到其它一些相关问题,如对英文之类拼音文字的识别。事实上,很多学者就是把数字和英文字母的识别放在一起研究的。

研究难点

数字的类别只有10种,笔划简单,其识别问题似乎不是很困难。但事实上,一些测试结果表明,数字的正确识别率并不如印刷体汉字识别率高,甚至也不如联机手写体汉字识别率高,而只仅仅优于脱机手写体汉字识别。这其中的主要原因是:

(1)数字笔划简单,其笔划差别相对较小,字形相差不大,使得准确区分某些数字相当困难;

(2)数字虽然只有10种,且笔划简单,但同一数字写法千差万别,全世界各个国家各个地区的人都在用,其书写上带有明显的区域特性,很难做出可以兼顾世界各种写法的、识别率极高的通用性数字识别系统。

虽然目前国内外对脱机手写数字识别的研究已经取得了很大的成就,但是仍然存在两大难点:

一是识别精度需要达到更高的水平。手写数字识别没有上下文,数据中的每一个数据都至关重要。而数字识别经常涉及金融、财会领域,其严格性更是不言而喻。因此,国内外众多的学者都在为提高手写数字的识别率,降低误识率而努力。

二是识别的速度要达到很高的水平。数字识别的输入通常是很大量的数据,而高精度与高速度是相互矛盾的,因此对识别算法提出了更高的要求。

MNIST数据集介绍

1998年,美国国家标准技术研究所(NIST)发布了MNIST数据集,这是一个用来训练各种图像处理系统的二进制图像数据集,广泛应用于机器学习中的训练和测试。MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

手写数据识别是个多分类问题,其中数据标签分别是0到9共10种,利用python包中的torchvision.datasets.MNIST函数对原始数据进行下载并解压,获得二进制文件,如图。

其中解压缩共获得四个文件。文件名称中t10k代表测试集,train代表训练集;images存储图像的数量,图像的高,图像的宽与图像各个像素的值,labels存储标签的数量与标签的类别信息。其中训练集有6万个样本,测试集有1万个样本。

数据预处理

MNIST数据集构建时已经将数据进行部分预处理,比如将每个手写数据放在图片正中心;将每个样本固定在28×28固定大小;图片灰度化,将其降噪转化为黑白两色;最终得到我们下载的数据。我们又对数据进行归一化,把所有的输入数值控制在[0,1]之间。

获得手写数据的原始数据后,对图片进行采样处理,获得手写数据的位置并且确保手写数字在图片的中心处,防止手写数字因位置的偏移对识别造成影响;控制每个样本的大小相同,使图片都为28×28固定大小,保证测试集与验证机的数据维度相同;将图片由三通道彩色的格式(3×28×28)转化成单通道的灰度图(1×28×28);此时所有数据的范围都是[0,255]之间的灰度值,将所有数值除255,使所有的输入数值在[0,1]之间。

应用算法

本文应用了四种不同的算法对手写数字数据集进行分类,分别是决策树算法,KNN算法,支持向量机(SVM)算法与卷积神经网络(CNN)算法。

KNN算法

KNN分类算法是数据挖掘分类技术中最简单的方法之一,最初由Cover和Hart于1968年提出。KNN分类算法简单易懂,但存在存储要求高、分类响应效率低、噪声容限低等缺点。KNN算法的核心思想是,将手写数字图像样本编码映射到一个特征空间中,在这个样本空间中寻找与该样本编码距离最近的K个样本,记录该K个样本所属的类别,寻找K个样本中的大多数类别,由于相同类别的特征是相同的,则该样本就属于这一类别。该模型在确定分类决策上只依据最邻近的K个样本的类别来决定待分类的样本所属的类别。

本文使用sklearn包中的KNeighborsClassifier类构建模型,其中的K值选择默认值5。使用fit方法进行模型的训练,并使用predict_proba与predict方法进行预测。

决策树算法

决策树是一个预测模型,它代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表某个可能的属性值,而每个叶节点则对应从根节点到该叶节点所经历的路径所表示的对象的值。从数据产生决策树的机器学习技术叫做决策树学习,通俗说就是决策树。一个决策树包含三种类型的节点:决策节点,机会节点与终结节点。决策树是一树状结构,它的每一个叶节点对应着一个分类,非叶节点对应着在某个属性上的划分,根据样本在该属性上的不同取值将其划分成若干个子集。对于非纯的叶节点,多数类的标号给出到达这个节点的样本所属的类。构造决策树的核心问题是在每一步如何选择适当的属性对样本做拆分。对一个分类问题,从已知类标记的训练样本中学习并构造出决策树是一个自上而下,分而治之的过程。

本文使用sklearn包中的DecisionTreeClassifier类构建模型,其中的参数选择使用默认参数。使用fit方法进行模型的训练,并使用predict_proba与predict方法进行预测。

支持向量机算法

支持向量机(SVM)发表于1995年,它的基本模型是定义在特征空间上的间隔最大的线性分类器,由于SVM在文本分类问题中性能极其突出,很快在机器学习中得到广泛的应用。SVM的学习策略是通过间隔最大化划分超平面实现样本的分类。SVM将样本的输入向量映射到一个更高维度的特征空间中,并识别将数据点

分成不同类别的超平面。决策超平面使最接近边界的实例之间的边际距离最大化。由此产生的分类器具有相当大的泛化性,因此可用于新样品的可靠分类。

本文使用sklearn包中的SVC类构建模型,其中核函数选择高斯核函数(rbf)。使用fit方法进行模型的训练,并使用predict_proba与predict方法进行预测。

卷积神经网络

卷积神经网络(CNN)是一种具有局部连接、权值共享等特点的深层前馈神经网络,是深度学习的代表算法之一,擅长处理图像特别是图像识别等相关机器学习问题,比如图像分类、目标检测、图像分割等各种视觉任务中都有显著的提升效果,是目前应用最广泛的模型之一。

卷积神经网络具有表征学习能力,能够按其阶层结构对输入信息进行平移不变分类,可以进行监督学习和非监督学习,其隐含层内的卷积核参数共享和层间连接的稀疏性使得卷积神经网络能够以较小的计算量对格点化特征,例如像素和音频进行学习、有稳定的效果且对数据没有额外的特征工程要求,并被大量应用于计算机视觉、自然语言处理等领域。

本文使用pytorch包中的nn.Conv2d类构建模型,其中选择最大池化nn.MaxPool2d进行特征值提取,分别经过两层卷积层与最大池化层后,利用一个全连接层对数据进行输出。使用Adam对参数进行优化,使用交叉熵(CrossEntropyLoss)进行loss计算。模型的训练过程中,每一批次的大小设置为50,学习率设置为10-3,共对训练集进行100轮训练。针对模型的输出结果,使用softmax函数计算预测样本对每个类别的预测概率。

模型评估

本文针对不同的模型,使用了正确率与错误率,ROC曲线与下方坐标轴围成的面积(AUC)与分类混淆矩阵对模型进行评估。

正确率定义为分类正确的个数除以总的样本数,错误率定义为分类错误的个数除以总的样本数。混淆矩阵也称误差矩阵,它可以非常容易的表明多个类别是否有混淆,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。AUC值是一个概率值,随机挑选一个正样本以及负样本,当前的分类算法根据计算得到的Score值将这个正样本排在负样本前面的概率就是AUC值,AUC值越大,当前分类算法越有可能将正样本排在负样本前面,从而能够更好地分类。

决策树算法

决策树算法的正确率为0.8773,错误率为0.1227。其ROC曲线如图,可以发现数字‘8’的区分在该算法的表现最差,数字‘1’的区分在该算法的表现最好。各个类别的分类效果有一定的区别。

决策树算法的混淆矩阵如图,可以发现各个类别或多或少是都被分类成的其他类别,但大部分的分类结果是正确的,其中数字‘3’与数字‘5’之间经常进行误判;数字‘2’与数字‘3’,数字‘4’与数字‘9’误判率也偏高。

KNN算法

KNN算法的正确率为0.9688,错误率为0.0312。其ROC曲线如图,可以发现数字‘8’的区分在该算法的表现最差,数字‘1’的区分在该算法的表现最好。各个类别的分类效果相差不大。

KNN算法的混淆矩阵如图,可以发现分类错误的概率较少,大部分的分类结果是正确的,其中数字‘7’经常误判成数字‘1’;数字‘4’经常误判成数字‘9’。

SVM算法

SVM算法的正确率为0.9792,错误率为0.0208。其ROC曲线如图,可以发现数字‘9’的区分在该算法的表现最差,数字‘1’的区分在该算法的表现最好。各个类别的分类效果相差不大。

SVM算法的混淆矩阵如图,可以发现分类错误的概率较少,大部分的分类结果是正确的,其中数字‘7’经常误判成数字‘2’和数字‘9’;数字‘4’与数字‘9’经常误判。

CNN

CNN算法的正确率为0.9907,错误率为0.0093。其ROC曲线如图,可以发现数字‘7’的区分在该算法的表现最差,数字‘0’的区分在该算法的表现最好。各个类别的分类效果相差不大。

CNN算法的混淆矩阵如图,可以发现分类错误的概率较少,大部分的分类结果是正确的,其中数字‘2’与数字‘7’经常误判;数字‘4’与数字‘9’经常误判。

模型比较

在本文所构建的4种模型中,每个模型对手写数字问题进行了分类,得到了不同的结果。其中效果最差的为决策树模型,其次是KNN模型与SVM模型,效果最好的为CNN模型。其中KNN模型与决策树模型在模型的构建上速度较快,而CNN模型与SVM模型构建速度较慢。其中数字‘7’的区分问题与数字‘4’与‘9’的混淆问题是所有模型的共同难题。

参考代码

决策树

  1. import os
  2. import joblib
  3. import torchvision
  4. from sklearn.svm import SVC
  5. from sklearn.tree import DecisionTreeClassifier
  6. import plot_graph
  7. def run(result_path: str):
  8. train_data = torchvision.datasets.MNIST(
  9. root='./mnist/',
  10. train=True,
  11. transform=torchvision.transforms.ToTensor(),
  12. download=False,
  13. )
  14. train_x = train_data.data.reshape(train_data.data.shape[0], -1) / 255.
  15. train_y, train_x = train_data.targets.numpy(), train_x.numpy()
  16. model = DecisionTreeClassifier()
  17. model.fit(train_x, train_y)
  18. joblib.dump(model, os.path.join(result_path, 'dt.m'))
  19. return 0
  20. def estimate(result_path: str):
  21. dt = joblib.load(os.path.join(result_path, 'dt.m'))
  22. test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
  23. test_x = test_data.data.reshape(test_data.data.shape[0], -1) / 255.
  24. test_y, test_x = test_data.targets.numpy(), test_x.numpy()
  25. pred = dt.predict(test_x)
  26. y_pred = dt.predict_proba(test_x)
  27. plot_graph.plot_auc(y=test_y, pred_y=y_pred, show=True)
  28. print("正确率:", sum(pred == test_y) / y_pred.shape[0])
  29. print("错误率:", 1 - sum(pred == test_y) / y_pred.shape[0])
  30. plot_graph.show_crosstab(y=test_y, pred=pred, show=True)

KNN

  1. import os
  2. import joblib
  3. import torchvision
  4. from sklearn.neighbors import KNeighborsClassifier
  5. import plot_graph
  6. def run(result_path: str):
  7. train_data = torchvision.datasets.MNIST(
  8. root='./mnist/',
  9. train=True,
  10. transform=torchvision.transforms.ToTensor(),
  11. download=False,
  12. )
  13. train_x = train_data.data.reshape(train_data.data.shape[0], -1) / 255.
  14. train_y, train_x = train_data.targets.numpy(), train_x.numpy()
  15. model = KNeighborsClassifier()
  16. model.fit(train_x, train_y)
  17. joblib.dump(model, os.path.join(result_path, 'knn.m'))
  18. return 0
  19. def estimate(result_path: str):
  20. knn = joblib.load(os.path.join(result_path, 'knn.m'))
  21. test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
  22. test_x = test_data.data.reshape(test_data.data.shape[0], -1) / 255.
  23. test_y, test_x = test_data.targets.numpy(), test_x.numpy()
  24. pred = knn.predict(test_x)
  25. y_pred = knn.predict_proba(test_x)
  26. plot_graph.plot_auc(y=test_y, pred_y=y_pred, show=True)
  27. print("正确率:", sum(pred == test_y) / y_pred.shape[0])
  28. print("错误率:", 1 - sum(pred == test_y) / y_pred.shape[0])
  29. plot_graph.show_crosstab(y=test_y, pred=pred, show=True)

SVM

  1. import os
  2. import joblib
  3. import torchvision
  4. from sklearn.svm import SVC
  5. import plot_graph
  6. def run(result_path: str):
  7. train_data = torchvision.datasets.MNIST(
  8. root='./mnist/',
  9. train=True,
  10. transform=torchvision.transforms.ToTensor(),
  11. download=False,
  12. )
  13. train_x = train_data.data.reshape(train_data.data.shape[0], -1) / 255.
  14. train_y, train_x = train_data.targets.numpy(), train_x.numpy()
  15. model = SVC(kernel='rbf', probability=True)
  16. model.fit(train_x, train_y)
  17. joblib.dump(model, os.path.join(result_path, 'svm.m'))
  18. return 0
  19. def estimate(result_path: str):
  20. svm = joblib.load(os.path.join(result_path, 'svm.m'))
  21. test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
  22. test_x = test_data.data.reshape(test_data.data.shape[0], -1) / 255.
  23. test_y, test_x = test_data.targets.numpy(), test_x.numpy()
  24. pred = svm.predict(test_x)
  25. y_pred = svm.predict_proba(test_x)
  26. plot_graph.plot_auc(y=test_y, pred_y=y_pred, show=True)
  27. print("正确率:", sum(pred == test_y) / y_pred.shape[0])
  28. print("错误率:", 1 - sum(pred == test_y) / y_pred.shape[0])
  29. plot_graph.show_crosstab(y=test_y, pred=pred, show=True)

CNN

  1. import os
  2. import torch
  3. import torch.nn as nn
  4. import torch.utils.data as Data
  5. import torchvision
  6. import tqdm
  7. from torch import softmax
  8. import plot_graph
  9. class CNN(nn.Module):
  10. def __init__(self):
  11. super(CNN, self).__init__()
  12. self.conv1 = nn.Sequential(
  13. nn.Conv2d(
  14. in_channels=1,
  15. out_channels=16,
  16. kernel_size=5,
  17. stride=1,
  18. padding=2,
  19. ),
  20. nn.ReLU(),
  21. nn.MaxPool2d(kernel_size=2),
  22. )
  23. self.conv2 = nn.Sequential(
  24. nn.Conv2d(16, 32, 5, 1, 2),
  25. nn.ReLU(),
  26. nn.MaxPool2d(2),
  27. )
  28. self.out = nn.Linear(32 * 7 * 7, 10)
  29. def forward(self, x):
  30. x = self.conv1(x)
  31. x = self.conv2(x)
  32. x = x.view(x.size(0), -1)
  33. output = self.out(x)
  34. return output
  35. def run(result_path: str):
  36. device = torch.device('cuda:0')
  37. torch.manual_seed(1)
  38. EPOCH = 100
  39. BATCH_SIZE = 50
  40. LR = 0.001
  41. DOWNLOAD_MNIST = False
  42. train_data = torchvision.datasets.MNIST(
  43. root='./mnist/',
  44. train=True,
  45. transform=torchvision.transforms.ToTensor(),
  46. download=DOWNLOAD_MNIST,
  47. )
  48. train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
  49. cnn = CNN().to(device)
  50. optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
  51. loss_func = nn.CrossEntropyLoss().to(device)
  52. for epoch in tqdm.tqdm(range(EPOCH), total=EPOCH, desc="EPOCH"):
  53. for step, (b_x, b_y) in enumerate(train_loader):
  54. b_x, b_y = b_x.to(device), b_y.to(device)
  55. output = cnn(b_x)
  56. loss = loss_func(output, b_y)
  57. optimizer.zero_grad()
  58. loss.backward()
  59. optimizer.step()
  60. torch.save(cnn, os.path.join(result_path, 'CNN.pkl'))
  61. def estimate(result_path: str):
  62. device = torch.device('cuda:0')
  63. cnn = torch.load(os.path.join(result_path, 'CNN.pkl'))
  64. test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
  65. test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor) / 255.
  66. test_y = test_data.targets
  67. test_output = cnn(test_x.to(device))
  68. predictions = softmax(test_output.to('cpu'), dim=-1)
  69. plot_graph.plot_auc(y=test_y.numpy(), pred_y=predictions.detach().numpy(), show=True)
  70. pred_y = torch.max(test_output.to('cpu'), 1)[1].data.numpy().squeeze()
  71. print("正确率:", sum(pred_y == test_y.numpy()) / pred_y.shape[0])
  72. print("错误率:", 1 - sum(pred_y == test_y.numpy()) / pred_y.shape[0])
  73. plot_graph.show_crosstab(y=test_y.numpy(), pred=pred_y, show=True)

绘图

  1. import pandas as pd
  2. from matplotlib import pyplot as plt
  3. from sklearn.metrics import auc, roc_curve
  4. import seaborn as sns
  5. from sklearn.preprocessing import LabelBinarizer
  6. def plot_auc(y, pred_y, show: bool):
  7. encoder = LabelBinarizer()
  8. y = encoder.fit_transform(y)
  9. fpr = dict()
  10. tpr = dict()
  11. roc_auc = dict()
  12. for i in range(10):
  13. fpr[i], tpr[i], _ = roc_curve(y[:, i], pred_y[:, i])
  14. roc_auc[i] = auc(fpr[i], tpr[i])
  15. plt.plot(fpr[i], tpr[i], label='AUC'+str(i)+'=%.5f' % roc_auc[i])
  16. plt.plot([0, 1], [0, 1], linestyle='--')
  17. plt.legend()
  18. plt.title('ROC')
  19. plt.ylabel('True Positive Rate')
  20. plt.xlabel('False Positive Rate')
  21. plt.show()
  22. def show_crosstab(pred, y, show: bool):
  23. a = pd.crosstab(y, pred, rownames=['Actual'], colnames=['Predicted'])
  24. sns.set()
  25. sns.heatmap(a, cmap='YlOrRd',annot=True,fmt="d",annot_kws={'color':'blue'})
  26. plt.show()
  27. print(a)
  28. return a

main

  1. import argparse
  2. import os
  3. import KNN
  4. import CNN
  5. import SVM
  6. import DT
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('--base_path', help='path',
  9. type=str, default=os.getcwd())
  10. parser.add_argument('--model', help='model',
  11. type=str, default='CNN')
  12. args = parser.parse_args()
  13. if args.model == 'CNN':
  14. CNN.run(result_path=args.base_path)
  15. elif args.model == 'CNN-estimate':
  16. CNN.estimate(result_path=args.base_path)
  17. elif args.model == 'KNN':
  18. KNN.run(result_path=args.base_path)
  19. elif args.model == 'KNN-estimate':
  20. KNN.estimate(result_path=args.base_path)
  21. elif args.model == 'SVM':
  22. SVM.run(result_path=args.base_path)
  23. elif args.model == 'SVM-estimate':
  24. SVM.estimate(result_path=args.base_path)
  25. elif args.model == 'DT':
  26. DT.run(result_path=args.base_path)
  27. elif args.model == 'DT-estimate':
  28. DT.estimate(result_path=args.base_path)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/775056
推荐阅读
  

闽ICP备14008679号