当前位置:   article > 正文

机器学习入门实例-MNIST手写数据集-简单探索&二分分类_mnist_784

mnist_784

MNIST数据集介绍

MNIST数据集包含7w张带标签的手写数字图片。每次有新的分类算法出现时,常常会在改数据集测试效果。

from sklearn.datasets import fetch_openml

# 获取的mnist是一个字典
mnist = fetch_openml('mnist_784', version=1)
print(mnist.keys())
# dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
# DESCR是描述信息,data是数据集,target是标签

X, y = mnist["data"], mnist["target"]
print(X.shape)
print(y.shape)
# (70000, 784)
# (70000,)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

X.shape表示一共7w张图,每个有784个特征。每个特征是28 x 28 像素中的一个点的数值,在0(白)~ 255(黑)之间。

查看其中一个图:

import matplotlib.pyplot as plt
print(y[5])
print(type(y[5]))
some_digit = X[5]
some_digit_image = some_digit.reshape(28, 28)
# cmp表示颜色映射,即实数值通过什么方法转成RGB图像。常用的还有'viridis'(很多颜色细节)、
# 'gray'(适合灰度图像)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

# 输出: 2
# <class 'str'>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

在这里插入图片描述
注意,这段代码很可能报错如下:

Traceback (most recent call last):
  File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 3621, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas\_libs\index.pyx", line 136, in pandas._libs.index.IndexEngine.get_loc
  File "pandas\_libs\index.pyx", line 163, in pandas._libs.index.IndexEngine.get_loc
  File "pandas\_libs\hashtable_class_helper.pxi", line 5198, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas\_libs\hashtable_class_helper.pxi", line 5206, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\xxxx\Desktop\study\classification.py", line 24, in <module>
    plt.imshow(X[0], cmap="gray")
  File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\frame.py", line 3505, in __getitem__
    indexer = self.columns.get_loc(key)
  File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 3623, in get_loc
    raise KeyError(key) from err
KeyError: 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

解决:获取数据集时添加参数 as_frame=False。 这个表示以原格式返回。

mnist = fetch_openml('mnist_784', version=1, as_frame=False)
  • 1

保存数据集到本地&导入

pickle库可以序列化任何Python对象,所以可以用它保存数据集到本地。

from sklearn.datasets import fetch_openml
import pickle

# 获取数据集
# False表示以原始格式返回,每个特征是一个单独的数组。True表示返回Pandas
# DataFrame对象
# 自0.24.0(2020 年 12 月)以来,as_frame参数为auto(而不是之前的False默认选项)
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

# 保存数据集到本地
with open('mnist_data.pkl', 'wb') as f:
    pickle.dump(mnist, f)

# 从本地导入数据集
with open('mnist_data.pkl', 'rb') as f:
    mnist = pickle.load(f)
    X, y = mnist["data"], mnist["target"]

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

划分训练集和测试集

MNIST已经划分好了,前6w个是训练集,后1w个是测试集。而且已经打乱过顺序了。

通过第一节我们已经知道,y的所有元素都是字符串。因为很多算法的预测结果都是数字,将标签转为数字也有助于计算error,所以使用astype(np.uint8)将y里所有元素转为8位无符号整数。

# 将数组中所有元素转为8位无符号整数
y = y.astype(np.uint8)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
  • 1
  • 2
  • 3
  • 4

Binary Classifier

Binary Classifier就是分两类。比如以数字 2 为例,我们训练一个分类器,将图片分成是2的和不是2的。

这里使用一个Stochastic Gradient Descent (SGD,随机梯度下降)分类器。这个适合高效处理较大的数据集,而且每个训练实例是单独处理的,一次一个,所以可以online learning。

y_train_2 = (y_train == 2)
y_test_2 = (y_test == 2)

from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_2)
some_digit = X[100]
print(sgd_clf.predict([some_digit]))
print(y[100])

some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

其中,为了使结果可以复现,设置了random_state=42。
输出为:

[False]
5
  • 1
  • 2

在这里插入图片描述
因此可以知道,分类正确。

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号