赞
踩
慕课:《深度学习应用开发-TensorFlow实践》
章节:第七讲 MNIST手写数字识别:分类应用入门
TensorFlow版本为2.3
MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。数据集由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口
普查局 (the Census Bureau) 的工作人员。
这一份数据集共有训练集60000个,测试集10000个,下面这张图展示了一小部分
数据集的获取可以直接去网站下:https://s3.amazonaws.com/img-datasets/mnist.npz
当然也可以通过TensorFlow
的代码去获取数据集,不过值得注意的是,TensorFlow1.*
和TensorFlow2.*
的数据集获取方式并不相同,下面只提供TensorFlow2.*
的获取代码
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
tf.__version__
#加载数据集
mnist=tf.keras.datasets.mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
然后他就会去对应的网站上下载对应的数据集,不过有个问题,他的下载地址是对应谷歌的一个地址,因此有可能会出现下载错误或者下载失败或者下载比较的慢,还是建议去上面提供的那个链接下,当然如果有某些方式,那随意哈
如果你选择手动下载,那么请把下载的数据集mnist.npz
存放在用户目录的“.keras/dataset”
子目录下(Windows 下用户目录为 C:\Users\用户名
,Linux 下用户目录为 /home/用户名
)。如果是第一次运行(在用户目录下没有找到数据文件),则会自动先从网络下载后再加载。如果用户目录下已经存在数据文件,则直接加载。
可以来看一下这一个数据集的一些信息
print(f"Train image shape:{train_images.shape} Train label shape:{train_labels.shape}");
print(f"Test image shape:{test_images.shape} Test label shape:{test_labels.shape}");
可以看到,他是28*28
的一个黑白图片,也就是一张图片会有784
个像素点。
可以用matplotlib.pyplot
来看一下他的一个图像
def plot_image(image):
plt.imshow(image.reshape(28,28),cmap='binary')
plt.show()
plot_image(train_images[0])
去第一张图片看一下就长上面这个样子
许多问题的预测结果是一个在连续空间的数值,比如房价预测问题,可以用线性模型来描述:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。