当前位置:   article > 正文

深度学习之手写数字分类

手写数字分类

问题描述:

将手写数字的灰度图像(28 像素×28 像素)划分到 10 个类别 中(0~9)。我们将使用 MNIST 数据集,它是机器学习领域的一个经典数据集,其历史几乎和这 个领域一样长,而且已被人们深入研究。这个数据集包含 60 000 张训练图像和 10 000 张测试图 像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即 MNIST 中 的 NIST)在 20 世纪 80 年代收集得到。你可以将“解决”MNIST 问题看作深度学习的“Hello World”,正是用它来验证你的算法是否按预期运行。当你成为机器学习从业者后,会发现 MNIST 一次又一次地出现在科学论文、博客文章等中。

1.准备数据:

  1. import numpy as np
  2. import paddle as paddle
  3. import paddle.fluid as fluid
  4. from PIL import Image
  5. import matplotlib.pyplot as plt
  6. import os
  7. train_reader = paddle.batch(paddle.reader.shuffle(paddle.dataset.mnist.train(),
  8. buf_size=512),batch_size=128)
  9. test_reader = paddle.batch(paddle.dataset.mnist.test(),batch_size=128)
  10. temp_reader = paddle.batch(paddle.dataset.mnist.train(),batch_size=1)
  11. temp_data=next(temp_reader())
  12. print(temp_data)

打印以下,观察mnist数据集

  1. G:\RJAZ\python3.6.6\py_data\python.exe I:/PaddlePaddle/code/wrirren_number/main.py
  2. [================================================= ][==================================================]
  3. [==================================================]
  4. [==================================================]
  5. [==================================================]
  6. [(array([-1. , -1. , -1. , -1. , -1. ,
  7. -1. , -1. , -1. , -1. , -1. ,
  8. -1. , -1. , -1. , -1. , -1. ,
  9. -1. , -1. , -1. , -1. , -1. ,
  10. -1. , -1. , -1. , -1. , -1. ,
  11. -1. , -1. , -1. , -1. , -1. ,
  12. -1. , -1. , -1. , -1. , -1. ,
  13. -1. , -1. , -1. , -1. , -1. ,
  14. -1. , -1. , -1. , -1. , -1. ,
  15. -1. , -1. , -1. , -1. , -1. ,
  16. -1. , -1. , -1. , -1. , -1. ,
  17. -1. , -1. , -1. , -1. , -1. ,
  18. -1. , -1. , -1. , -1. , -1. ,
  19. -1. , -1. , -1. , -1. , -1. ,
  20. -1. , -1. , -1. , -1. , -1. ,
  21. -1. , -1. , -1. , -1. , -1. ,
  22. -1. , -1. , -1. , -1. , -1. ,
  23. -1. , -1. , -1. , -1. , -1. ,
  24. -1. , -1. , -1. , -1. , -1. ,
  25. -1. , -1. , -1. , -1. , -1. ,
  26. -1. , -1. , -1. , -1. , -1. ,
  27. -1. , -1. , -1. , -1. , -1. ,
  28. -1. , -1. , -1. , -1. , -1. ,
  29. -1. , -1. , -1. , -1. , -1. ,
  30. -1. , -1. , -1. , -1. , -1. ,
  31. -1. , -1. , -1. , -1. , -1. ,
  32. -1. , -1. , -1. , -1. , -1. ,
  33. -1. , -1. , -1. , -1. , -1. ,
  34. -1. , -1. , -1. , -1. , -1. ,
  35. -1. , -1. , -1. , -1. , -1. ,
  36. -1. , -1. , -0.9764706 , -0.85882354, -0.85882354,
  37. -0.85882354, -0.01176471, 0.06666672, 0.37254906, -0.79607844,
  38. 0.30196083, 1. , 0.9372549 , -0.00392157, -1. ,
  39. -1. , -1. , -1. , -1. , -1. ,
  40. -1. , -1. , -1. , -1. , -1. ,
  41. -1. , -0.7647059 , -0.7176471 , -0.26274508, 0.20784318,
  42. 0.33333337, 0.9843137 , 0.9843137 , 0.9843137 , 0.9843137 ,
  43. 0.9843137 , 0.7647059 , 0.34901965, 0.9843137 , 0.8980392 ,
  44. 0.5294118 , -0.4980392 , -1. , -1. , -1. ,
  45. -1. , -1. , -1. , -1. , -1. ,
  46. -1. , -1. , -1. , -0.6156863 , 0.8666667 ,
  47. 0.9843137 , 0.9843137 , 0.9843137 , 0.9843137 , 0.9843137 ,
  48. 0.9843137 , 0.9843137 , 0.9843137 , 0.96862745, -0.27058822,
  49. -0.35686272, -0.35686272, -0.56078434, -0.69411767, -1. ,
  50. -1. , -1. , -1. , -1. , -1. ,
  51. -1. , -1. , -1. , -1. , -1. ,
  52. -1. , -0.85882354, 0.7176471 , 0.9843137 , 0.9843137 ,
  53. 0.9843137 , 0.9843137 , 0.9843137 , 0.5529412 , 0.427451 ,
  54. 0.9372549 , 0.8901961 , -1. , -1. , -1. ,
  55. -1. , -1. , -1. , -1. , -1. ,
  56. -1. , -1. , -1. , -1. , -1. ,
  57. -1. , -1. , -1. , -1. , -1. ,
  58. -0.372549 , 0.22352946, -0.1607843 , 0.9843137 , 0.9843137 ,
  59. 0.60784316, -0.9137255 , -1. , -0.6627451 , 0.20784318,
  60. -1. , -1. , -1. , -1. , -1. ,
  61. -1. , -1. , -1. , -1. , -1. ,
  62. -1. , -1. , -1. , -1. , -1. ,
  63. -1. , -1. , -1. , -1. , -0.8901961 ,
  64. -0.99215686, 0.20784318, 0.9843137 , -0.29411763, -1. ,
  65. -1. , -1. , -1. , -1. , -1. ,
  66. -1. , -1. , -1. , -1. , -1. ,
  67. -1. , -1. , -1. , -1. , -1. ,
  68. -1. , -1. , -1. , -1. , -1. ,
  69. -1. , -1. , -1. , -1. , 0.09019613,
  70. 0.9843137 , 0.4901961 , -0.9843137 , -1. , -1. ,
  71. -1. , -1. , -1. , -1. , -1. ,
  72. -1. , -1. , -1. , -1. , -1. ,
  73. -1. , -1. , -1. , -1. , -1. ,
  74. -1. , -1. , -1. , -1. , -1. ,
  75. -1. , -1. , -0.9137255 , 0.4901961 , 0.9843137 ,
  76. -0.45098037, -1. , -1. , -1. , -1. ,
  77. -1. , -1. , -1. , -1. , -1. ,
  78. -1. , -1. , -1. , -1. , -1. ,
  79. -1. , -1. , -1. , -1. , -1. ,
  80. -1. , -1. , -1. , -1. , -1. ,
  81. -1. , -0.7254902 , 0.8901961 , 0.7647059 , 0.254902 ,
  82. -0.15294117, -0.99215686, -1. , -1. , -1. ,
  83. -1. , -1. , -1. , -1. , -1. ,
  84. -1. , -1. , -1. , -1. , -1. ,
  85. -1. , -1. , -1. , -1. , -1. ,
  86. -1. , -1. , -1. , -1. , -1. ,
  87. -0.36470586, 0.88235295, 0.9843137 , 0.9843137 , -0.06666666,
  88. -0.8039216 , -1. , -1. , -1. , -1. ,
  89. -1. , -1. , -1. , -1. , -1. ,
  90. -1. , -1. , -1. , -1. , -1. ,
  91. -1. , -1. , -1. , -1. , -1. ,
  92. -1. , -1. , -1. , -1. , -0.64705884,
  93. 0.45882356, 0.9843137 , 0.9843137 , 0.17647064, -0.7882353 ,
  94. -1. , -1. , -1. , -1. , -1. ,
  95. -1. , -1. , -1. , -1. , -1. ,
  96. -1. , -1. , -1. , -1. , -1. ,
  97. -1. , -1. , -1. , -1. , -1. ,
  98. -1. , -1. , -1. , -0.8745098 , -0.27058822,
  99. 0.9764706 , 0.9843137 , 0.4666667 , -1. , -1. ,
  100. -1. , -1. , -1. , -1. , -1. ,
  101. -1. , -1. , -1. , -1. , -1. ,
  102. -1. , -1. , -1. , -1. , -1. ,
  103. -1. , -1. , -1. , -1. , -1. ,
  104. -1. , -1. , -1. , 0.9529412 , 0.9843137 ,
  105. 0.9529412 , -0.4980392 , -1. , -1. , -1. ,
  106. -1. , -1. , -1. , -1. , -1. ,
  107. -1. , -1. , -1. , -1. , -1. ,
  108. -1. , -1. , -1. , -1. , -1. ,
  109. -1. , -1. , -1. , -0.6392157 , 0.0196079 ,
  110. 0.43529415, 0.9843137 , 0.9843137 , 0.62352943, -0.9843137 ,
  111. -1. , -1. , -1. , -1. , -1. ,
  112. -1. , -1. , -1. , -1. , -1. ,
  113. -1. , -1. , -1. , -1. , -1. ,
  114. -1. , -1. , -1. , -1. , -0.69411767,
  115. 0.16078436, 0.79607844, 0.9843137 , 0.9843137 , 0.9843137 ,
  116. 0.9607843 , 0.427451 , -1. , -1. , -1. ,
  117. -1. , -1. , -1. , -1. , -1. ,
  118. -1. , -1. , -1. , -1. , -1. ,
  119. -1. , -1. , -1. , -1. , -1. ,
  120. -0.8117647 , -0.10588235, 0.73333335, 0.9843137 , 0.9843137 ,
  121. 0.9843137 , 0.9843137 , 0.5764706 , -0.38823527, -1. ,
  122. -1. , -1. , -1. , -1. , -1. ,
  123. -1. , -1. , -1. , -1. , -1. ,
  124. -1. , -1. , -1. , -1. , -1. ,
  125. -1. , -0.81960785, -0.4823529 , 0.67058825, 0.9843137 ,
  126. 0.9843137 , 0.9843137 , 0.9843137 , 0.5529412 , -0.36470586,
  127. -0.9843137 , -1. , -1. , -1. , -1. ,
  128. -1. , -1. , -1. , -1. , -1. ,
  129. -1. , -1. , -1. , -1. , -1. ,
  130. -1. , -1. , -0.85882354, 0.3411765 , 0.7176471 ,
  131. 0.9843137 , 0.9843137 , 0.9843137 , 0.9843137 , 0.5294118 ,
  132. -0.372549 , -0.92941177, -1. , -1. , -1. ,
  133. -1. , -1. , -1. , -1. , -1. ,
  134. -1. , -1. , -1. , -1. , -1. ,
  135. -1. , -1. , -1. , -0.5686275 , 0.34901965,
  136. 0.77254903, 0.9843137 , 0.9843137 , 0.9843137 , 0.9843137 ,
  137. 0.9137255 , 0.04313731, -0.9137255 , -1. , -1. ,
  138. -1. , -1. , -1. , -1. , -1. ,
  139. -1. , -1. , -1. , -1. , -1. ,
  140. -1. , -1. , -1. , -1. , -1. ,
  141. -1. , 0.06666672, 0.9843137 , 0.9843137 , 0.9843137 ,
  142. 0.6627451 , 0.05882359, 0.03529418, -0.8745098 , -1. ,
  143. -1. , -1. , -1. , -1. , -1. ,
  144. -1. , -1. , -1. , -1. , -1. ,
  145. -1. , -1. , -1. , -1. , -1. ,
  146. -1. , -1. , -1. , -1. , -1. ,
  147. -1. , -1. , -1. , -1. , -1. ,
  148. -1. , -1. , -1. , -1. , -1. ,
  149. -1. , -1. , -1. , -1. , -1. ,
  150. -1. , -1. , -1. , -1. , -1. ,
  151. -1. , -1. , -1. , -1. , -1. ,
  152. -1. , -1. , -1. , -1. , -1. ,
  153. -1. , -1. , -1. , -1. , -1. ,
  154. -1. , -1. , -1. , -1. , -1. ,
  155. -1. , -1. , -1. , -1. , -1. ,
  156. -1. , -1. , -1. , -1. , -1. ,
  157. -1. , -1. , -1. , -1. , -1. ,
  158. -1. , -1. , -1. , -1. , -1. ,
  159. -1. , -1. , -1. , -1. , -1. ,
  160. -1. , -1. , -1. , -1. , -1. ,
  161. -1. , -1. , -1. , -1. , -1. ,
  162. -1. , -1. , -1. , -1. ], dtype=float32), 5)]

2.配置网络

以下的代码判断就是定义一个简单的多层感知器,一共有三层,两个大小为100的隐层和一个大小为10的输出层,因为MNIST数据集是手写0到9的灰度图像,类别有10个,所以最后的输出大小是10。最后输出层的激活函数是Softmax,所以最后的输出层相当于一个分类器。加上一个输入层的话,多层感知器的结构是:输入层-->>隐层-->>隐层-->>输出层。

  1. def multilayer_perceptron(input):
  2. # 第一个全连接层,激活函数为ReLU
  3. hidden1 = fluid.layers.fc(input=input, size=100, act='relu')
  4. # 第二个全连接层,激活函数为ReLU
  5. hidden2 = fluid.layers.fc(input=hidden1, size=100, act = 'relu')
  6. # 以softmax为激活函数的全连接输出层,大小为10
  7. prediction = fluid.layers.fc(input=hidden2, size=10, act = 'softmax')
  8. return prediction

定义输入输出层,因为输入的是28*28的灰度图像,所以它的形状是[1,28,28]的,1表示的是颜色通道,如果是彩色图,则应为[3,28,28],因为彩色图有RGB三个通道

  1. image = fluid.layers.data(name = 'image', shape=[1,28,28], dtype='float32')
  2. label = fluid.layers.data(name='label', shape=[1], dtype='int64')

获取分类器,,这里用定义好的函数来获取

model = multilayer_perceptron(image)

接下来定义损失函数

这里使用的是交叉熵损失函数,此函数在分类上比较常用

cost = fluid.layers.cross_entropy(input=model, label=label)#使用交叉熵损失函数,描述真实样本标签和预测概率之间的差值

定义了一个损失函数之后,还有对它求平均值,因为定义的是一个Batch的损失值。

avg_cost = fluid.layers.mean(cost)

定义一个准确率函数,这个可以在训练的时候输出分类的准确率

acc = fluid.layers.accuracy(input=model, label=label)

接着是定义优化方法,这次我们使用的是Adam优化方法,同时指定学习率为0.001。

  1. optimizer = fluid.optimizer.AdamOptimizer(learning_rate=0.001)#使用Adam算法进行优化
  2. opts = optimizer.minimize(avg_cost)

定义一个使用CPU的解析器

  1. place = fluid.CPUPlace()
  2. exe = fluid.Executor(place)

参数初始化

exe.run(fluid.default_startup_program())

定义输入数据维度

输入数据的维度是图像数据和图像对应的标签,每个类别的图像都要对应一个标签,这个标签是从0开始递增的整形值

feeder = fluid.DataFeeder(place=place, feed_list=[image, label])

下面开始训练并测试

我们这次训练5个Pass,并且在每一轮Pass之后在进行一次测试,使用测试集进行测试,并求出当前的cost和准确率的average

  1. for pass_id in range(5):
  2. # 进行训练
  3. for batch_id, data in enumerate(train_reader()): #遍历train_reader
  4. train_cost, train_acc = exe.run(program=fluid.default_main_program(),#运行主程序
  5. feed=feeder.feed(data), #给模型喂入数据
  6. fetch_list=[avg_cost, acc])#fetch 误差、准确率
  7. # 每100个batch打印一次信息 误差、准确率
  8. if batch_id % 100 == 0:
  9. print('Pass:%d, Batch:%d, Cost:%0.5f, Accuracy:%0.5f' %
  10. (pass_id, batch_id, train_cost[0], train_acc[0]))
  11. # 进行测试
  12. test_accs = []
  13. test_costs = []
  14. #每训练一轮 进行一次测试
  15. for batch_id, data in enumerate(test_reader()): #遍历test_reader
  16. test_cost, test_acc = exe.run(program=fluid.default_main_program(), #执行训练程序
  17. feed=feeder.feed(data), #喂入数据
  18. fetch_list=[avg_cost, acc]) #fetch 误差、准确率
  19. test_accs.append(test_acc[0]) #每个batch的准确率
  20. test_costs.append(test_cost[0]) #每个batch的误差
  21. # 求测试结果的平均值
  22. test_cost = (sum(test_costs) / len(test_costs)) #每轮的平均误差
  23. test_acc = (sum(test_accs) / len(test_accs)) #每轮的平均准确率
  24. print('Test:%d, Cost:%0.5f, Accuracy:%0.5f' % (pass_id, test_cost, test_acc))
  25. # 保存模型
  26. model_save_dir = "./hand.inference.model"
  27. # 如果保存路径不存在就创建
  28. if not os.path.exists(model_save_dir):
  29. os.makedirs(model_save_dir)
  30. print('save models to %s' % (model_save_dir))
  31. fluid.io.save_inference_model(model_save_dir, # 保存推理model的路径
  32. ['image'], # 推理(inference)需要 feed 的数据
  33. [model], # 保存推理(inference)结果的 Variables
  34. exe) # executor 保存 inference model

先运行以下看下输出

  1. Pass:0, Batch:0, Cost:2.50275, Accuracy:0.07812
  2. Pass:0, Batch:100, Cost:0.32129, Accuracy:0.89062
  3. Pass:0, Batch:200, Cost:0.28232, Accuracy:0.89062
  4. Pass:0, Batch:300, Cost:0.30894, Accuracy:0.90625
  5. Pass:0, Batch:400, Cost:0.23480, Accuracy:0.92188
  6. Test:0, Cost:0.22551, Accuracy:0.93068
  7. save models to ./hand.inference.model
  8. Pass:1, Batch:0, Cost:0.21532, Accuracy:0.91406
  9. Pass:1, Batch:100, Cost:0.25423, Accuracy:0.92188
  10. Pass:1, Batch:200, Cost:0.16427, Accuracy:0.96094
  11. Pass:1, Batch:300, Cost:0.12058, Accuracy:0.96875
  12. Pass:1, Batch:400, Cost:0.13565, Accuracy:0.94531
  13. Test:1, Cost:0.14589, Accuracy:0.95481
  14. save models to ./hand.inference.model
  15. Pass:2, Batch:0, Cost:0.14125, Accuracy:0.95312
  16. Pass:2, Batch:100, Cost:0.14312, Accuracy:0.96094
  17. Pass:2, Batch:200, Cost:0.07529, Accuracy:0.96875
  18. Pass:2, Batch:300, Cost:0.10757, Accuracy:0.96094
  19. Pass:2, Batch:400, Cost:0.19855, Accuracy:0.94531
  20. Test:2, Cost:0.11291, Accuracy:0.96242
  21. save models to ./hand.inference.model
  22. Pass:3, Batch:0, Cost:0.10000, Accuracy:0.96875
  23. Pass:3, Batch:100, Cost:0.12072, Accuracy:0.96094
  24. Pass:3, Batch:200, Cost:0.05219, Accuracy:0.97656
  25. Pass:3, Batch:300, Cost:0.09506, Accuracy:0.97656
  26. Pass:3, Batch:400, Cost:0.13533, Accuracy:0.95312
  27. Test:3, Cost:0.09648, Accuracy:0.96915
  28. save models to ./hand.inference.model
  29. Pass:4, Batch:0, Cost:0.18229, Accuracy:0.95312
  30. Pass:4, Batch:100, Cost:0.12360, Accuracy:0.96875
  31. Pass:4, Batch:200, Cost:0.10631, Accuracy:0.96875
  32. Pass:4, Batch:300, Cost:0.10291, Accuracy:0.97656
  33. Pass:4, Batch:400, Cost:0.06706, Accuracy:0.97656
  34. Test:4, Cost:0.07825, Accuracy:0.97498
  35. save models to ./hand.inference.model

可以看到最终损失率在不断减小,准确率在不断接近1

4.模型预测

在预测之前,要对图像进行预处理,处理方式要跟训练时的一样

首先进行灰度化,然后压缩图像大小为28*28,接着将图像转换成一维向量,最后对一维向量进行归一化处理

  1. def load_image(file):
  2. #将RGB转化为灰度图像,L代表灰度图像,灰度图像的像素值在0~255之间
  3. im = Image.open(file).convert('L')
  4. im = im.resize((28,28), Image.ANTIALIAS) #resize image with high-quality 图像大小为28*28
  5. #返回新形状的数组,把它变成一个 numpy 数组以匹配数据馈送格式。
  6. im = np.array(im).reshape(1,1,28,28).astype(np.float32)
  7. im = im / 255.0 * 2.0 -1.0 #归一化到【-1~1】之间
  8. print(im)
  9. return im
  10. img = Image.open('./6.png')
  11. plt.imshow(img) #根据数组绘制图像
  12. plt.show() #显示图像

运行一下看下结果:原图像随便用画图板画的

已经转换成功,将刚三行代码注释掉

重新定义一个CPU解析器并预测作用域

  1. infer_exe = fluid.Executor(place)
  2. inference_scope = fluid.core.Scope() #预测作用域

加载数据并开始预测

  1. with fluid.scope_guard(inference_scope):
  2. #获取训练好的模型
  3. #从指定目录中加载 推理model(inference model)
  4. [inference_program,#推理Program
  5. feed_target_names,#是一个str列表,它包含需要在推理 Program 中提供数据的变量的名称。
  6. fetch_targets #fetch_targets:是一个 Variable 列表,从中我们可以得到推断结果。
  7. ] = fluid.io.load_inference_model(
  8. model_save_dir,#model_save_dir:模型保存的路径
  9. infer_exe)#infer_exe: 运行 inference model的 executor
  10. img = load_image('./6.png')
  11. results = exe.run(program=inference_program,#运行推测程序
  12. feed={feed_target_names[0]:img}, #喂入要预测的img
  13. fetch_list=fetch_targets) #得到推测结果

获取概率最大的label

  1. lab = np.argsort(results)#argsort函数返回的是result数组值从小到大的索引值
  2. print("该图片的预测结果的label是:%d" % lab[0][0][-1]) #-1代表读取数组中倒数第一列

开始运行:打印结果为

该图片的预测结果的label是:0

又运行了几次,有1,有3,就是没有6......

换张图片试一下,这次用粗笔画的6,最后...

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/339436
推荐阅读
相关标签
  

闽ICP备14008679号