当前位置:   article > 正文

基于CNN网络的mnist手写数字数据库训练和识别_手写数字识别logisitic基于mnist 库

手写数字识别logisitic基于mnist 库

目录

一、理论基础

1.1、MNIST手写数字数据库

1.2、CNN深度学习网络

1.3、MNIST手写数字识别实现

1.4、总结

二、核心程序

三、仿真结论


一、理论基础

       手写数字识别是机器学习领域中的一个重要应用,它可以应用于自动化数据输入、智能检测等领域。MNIST手写数字数据库是机器学习领域中的一个经典数据集,包含了一系列手写数字图片和对应的标签,是进行手写数字识别算法研究和性能评估的标准数据集之一。本文将介绍基于CNN深度学习网络的MNIST手写数字数据库训练和识别的实现方法和步骤。

1.1、MNIST手写数字数据库

      MNIST手写数字数据库是一个常用的手写数字识别数据集,包含了60000个训练集和10000个测试集。每个样本都是28x28像素的灰度图像,表示了0~9之间的数字。其中,训练集用于训练模型,测试集用于测试模型的准确率。

  1. 标签信息: 标签是用于标识图像中手写数字的数字,范围从0到9。例如,标签0表示图像中的手写数字是数字0,标签1表示数字1,以此类推。

  2. 数据的应用: MNIST数据集常被用作数字识别问题的基准数据集,尤其是在机器学习领域。研究者和开发者可以使用MNIST数据集来验证算法、模型或方法的性能。例如,可以将MNIST数据集用于训练和测试各种图像分类算法,如卷积神经网络(CNN)、支持向量机(SVM)等。

  3. 数据获取: MNIST数据集可以从官方网站或许多开源机器学习库中获取,如TensorFlow、PyTorch等。这些库提供了方便的函数和接口,可以轻松加载和处理MNIST数据集。

  4. 数据预处理: 在使用MNIST数据集之前,通常需要进行一些数据预处理,以适应具体的任务和模型。预处理可能包括图像归一化(将像素值缩放到0到1之间)、数据增强(生成变换后的图像以增加样本数量)等。

  5. 应用领域: MNIST数据集的主要应用领域包括数字识别、图像分类、模式识别等。它被用于教育、学术研究以及算法验证和性能评估。

       MNIST手写数字数据库作为一个经典的数据集,为机器学习领域提供了一个常用的基准,有助于研究人员和开发者更好地理解和实验数字识别算法。同时,它也促进了新算法和模型的发展,为图像处理和模式识别领域做出了重要贡献。

1.2、CNN深度学习网络

       CNN(Convolutional Neural Network,卷积神经网络)是一种深度学习网络,特别适用于图像分类和识别任务。与传统神经网络不同,CNN可以自动提取数据的特征,从而实现对图像的高效分类和识别。CNN网络的核心组件包括卷积层、池化层和全连接层。

卷积层
       卷积层是CNN网络中的核心组件,它可以自动提取输入图像的特征。在卷积层中,使用一组可学习的卷积核对输入图像进行卷积操作,得到一组特征图。卷积核是一个小的矩阵,可以通过反向传播算法进行训练,以提取输入图像的不同特征。

 池化层
      池化层是CNN网络中的另一个核心组件,它用于对特征图进行降维操作。在池化层中,通常使用最大池化或平均池化等操作,将每个特征图中的一定区域进行压缩,从而减少网络中的参数数量和计算量。

 全连接层
      全连接层是CNN网络中的最后一层,它用于对特征图进行分类。在全连接层中,可以使用softmax函数将特征图映射到0~1之间的概率值,从而得到对输入图像的分类结果。

1.3、MNIST手写数字识别实现

       基于CNN深度学习网络的MNIST手写数字识别实现主要包括以下几个步骤:数据集加载、数据预处理、模型构建、模型训练和模型测试。下面将逐一介绍这些步骤的具体实现方法。这个模型包含了两个卷积层、两个池化层和两个全连接层。在卷积层和全连接层中,都使用了ReLU激活函数来增强模型的非线性表达能力。在最后的全连接层中,使用了softmax函数来将特征图映射为概率值。

1.4、总结

       本文介绍了基于CNN深度学习网络的MNIST手写数字识别实现方法和步骤。通过对数据集的加载、预处理、模型构建、模型训练和模型测试等步骤的介绍,可以帮助读者了解深度学习网络的基本原理和实现方法,以及如何应用深度学习网络进行手写数字识别任务。
 

二、核心程序

  1. clc;
  2. clear;
  3. close all;
  4. warning off;
  5. addpath(genpath(pwd));
  6. rng('default')
  7. inputSize = 28 * 28;
  8. numLabels = 5;
  9. hiddenSize = 200;
  10. sparsityParam = 0.1; % desired average activation of the hidden units.
  11. % (This was denoted by the Greek alphabet rho, which looks like a lower-case "p",
  12. % in the lecture notes).
  13. lambda = 3e-3; % weight decay parameter
  14. beta = 3; % weight of sparsity penalty term
  15. %% ======================================================================
  16. % STEP 1: Load data from the MNIST database
  17. % Load MNIST database files
  18. mnistData = loadMNISTImages('mnist/train-images-idx3-ubyte');
  19. mnistLabels = loadMNISTLabels('mnist/train-labels-idx1-ubyte');
  20. % Set Unlabeled Set (All Images)
  21. % Simulate a Labeled and Unlabeled set
  22. labeledSet = find(mnistLabels >= 0 & mnistLabels <= 4);
  23. unlabeledSet = find(mnistLabels >= 5);
  24. numTrain = round(numel(labeledSet)/2);
  25. trainSet = labeledSet(1:numTrain);
  26. testSet = labeledSet(numTrain+1:end);
  27. unlabeledData = mnistData(:, unlabeledSet);
  28. trainData = mnistData(:, trainSet);
  29. trainLabels = mnistLabels(trainSet)' + 1; % Shift Labels to the Range 1-5
  30. testData = mnistData(:, testSet);
  31. testLabels = mnistLabels(testSet)' + 1; % Shift Labels to the Range 1-5
  32. % Output Some Statistics
  33. fprintf('# examples in unlabeled set: %d\n', size(unlabeledData, 2));
  34. fprintf('# examples in supervised training set: %d\n\n', size(trainData, 2));
  35. fprintf('# examples in supervised testing set: %d\n\n', size(testData, 2));
  36. %% ======================================================================
  37. % STEP 2: Train the sparse autoencoder
  38. theta = initializeParameters(hiddenSize, inputSize);
  39. addpath minFunc/
  40. autoencoderOptions.Method = 'lbfgs'; % Here, we use L-BFGS to optimize our cost
  41. % function. Generally, for minFunc to work, you
  42. % need a function pointer with two outputs: the
  43. % function value and the gradient. In our problem,
  44. % sparseAutoencoderCost.m satisfies this.
  45. autoencoderOptions.maxIter = 400; % Maximum number of iterations of L-BFGS to run
  46. autoencoderOptions.display = 'on';
  47. if exist('opttheta.mat','file')==2
  48. load('opttheta.mat');
  49. else
  50. [opttheta, cost] = minFunc( @(p) sparseAutoencoderCost(p, ...
  51. inputSize, hiddenSize, ...
  52. lambda, sparsityParam, ...
  53. beta, unlabeledData), ...
  54. theta, autoencoderOptions);
  55. save('opttheta.mat','opttheta');
  56. end
  57. %% -----------------------------------------------------
  58. % Visualize weights
  59. W1 = reshape(opttheta(1:hiddenSize * inputSize), hiddenSize, inputSize);
  60. display_network(W1');
  61. %%======================================================================
  62. %% STEP 3: Extract Features from the Supervised Dataset
  63. trainFeatures = feedForwardAutoencoder(opttheta, hiddenSize, inputSize, ...
  64. trainData);
  65. testFeatures = feedForwardAutoencoder(opttheta, hiddenSize, inputSize, ...
  66. testData);
  67. %%======================================================================
  68. %% STEP 4: Train the softmax classifier
  69. softmaxOptions.maxIter = 100;
  70. lambdaSoftmax = 1e-4; % Weight decay parameter for Softmax
  71. trainNumber = size(trainData,2);
  72. % softmaxTrain 默认数据中已包含截距项
  73. softmaxModel = softmaxTrain(hiddenSize+1, numLabels, lambdaSoftmax, [trainFeatures;ones(1,trainNumber)], trainLabels, softmaxOptions); % learn by features
  74. %softmaxModel = softmaxTrain(inputSize+1, numLabels, lambdaSoftmax, [trainData;ones(1,trainNumber)], trainLabels, softmaxOptions); % learn by raw data
  75. %% -----------------------------------------------------
  76. %%======================================================================
  77. %% STEP 5: Testing
  78. %% ----------------- YOUR CODE HERE ----------------------
  79. % Compute Predictions on the test set (testFeatures) using softmaxPredict
  80. % and softmaxModel
  81. testNumber = size(testData,2);
  82. % softmaxPredict 默认数据中已包含截距项
  83. [pred] = softmaxPredict(softmaxModel, [testFeatures;ones(1,testNumber)]); % predict by test features
  84. %[pred] = softmaxPredict(softmaxModel, [testData;ones(1,testNumber)]); % predict by test raw data
  85. %% -----------------------------------------------------
  86. % Classification Score
  87. fprintf('Test Accuracy: %f%%\n', 100*mean(pred(:) == testLabels(:)));
  88. up2112

三、仿真结论

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/343032
推荐阅读
相关标签
  

闽ICP备14008679号