赞
踩
作者:LeonG
本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow 机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。本文全部代码和数据集保存在我的github-----LeonG的github
机器学习的监督学习任务中最常见的任务是回归(用于预测某个值)和分类(预测某个类别)。
在这一章,我们重点来学习分类任务。
Scikit-Learn包是一个集成了大部分机器学习经典算法的python库,它可以帮助我们快速上手机器学习的众多应用。
注意:本项目主要使用Jupyter Notebook进行开发,Jupyter Notebook是一个简便易上手的python工作环境,源代码请使用Jupyter Notebook打开。
首先给大家介绍一下机器学习的经典数据集–MNIST,它是拥有70000张手写数字图片的数据集合。这个数据集经典程度相当于机器学习领域的“Hello World”。
from sklearn.datasets import fetch_mldata
#先将mnist数据集下载到项目目录(./mldata/mnist-original.mat)
#下载地址:https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat
mnist = fetch_mldata('MNIST Original',data_home='./')
mnist
介绍一下这三个标签的意义:
DESCR
键描述数据集
target
键存放一个标签数组
data
键存放一个图片数组,数组的一行表示一个样例,也就是一张图片
然后看看这个数据集的结构,主要是查看维度
x,y = mnist["data"],mnist["target"]
x.shape
y.shape
现在我们总结一下MNIST数据集的特点:
MNIST有70000张图片,众所周知,图片是由像素点构成的。
784意味着每张图片拥有784个像素点,这是因为每张图片都是28*28像素的,且每个像素点都介于0~255之间。
target标签表示对应位置的图片是什么数字。
我相信每个接触到MNIST数据集的人都迫不及待地想看一下图片是什么样子了,说实话我也是(搓手手),让我们随便展示一张图片。
import matplotlib
from matplotlib import pyplot as plt
some_digit = x[36000] #第36001张图片
img = some_digit.reshape(28,28) #还原成28*28的结构
plt.imshow(img,cmap = matplotlib.cm.binary, interpolation="nearest")
plt.show()
y[36000] #顺便看一下标签
哟,不错,整挺好,可以看到图片上的数字和标签正好对应起来了,有了这样一个利器,我们就能慢慢揭开机器学习分类任务的神秘面纱了。
少侠稍安勿躁,在进行下一步之前,我们总是要先将数据集分为训练集和测试集两个部分,最好还能打乱一下顺序,因为有的算法对顺序的敏感度很高。
#分割数据集
x_train,x_test,y_train,y_test = x[:60000],x[60000:],y[:60000],y[60000:]
#打乱训练集
import numpy as np
shuffie_index = np.random.permutation(60000)
x_train,y_train = x_train[shuffie_index],y_train[shuffie_index]
接下来正式开始学习分类任务。
鲁迅曾经说过:学走先学爬(鲁迅:我没说过),多分类属于比较高级的分类任务,我们先做个简单的,比如鉴定一张图片是否是数字5
这个“数字 5 检测器”就是一个二分类器,能够识别两个类,“是5”和“非5”。让我们为这个分类任务创建新的标签,也就是将所有数字标签转为是和否两种标签(bool类型)
y_train_5 = (y_train == 5) #训练集的标签
y_test_5 = (y_test == 5) #测试集的标签
接下来就是选择一个分类器了,给大家介绍一下本次的主角SGD(随机梯度下降分类器)。其实随机梯度下降分类器并不是一个独立的算法,而是一系列利用随机梯度下降求解参数的算法集合。
这个分类器默认用的算法是SVM(线性支持向量机),既然是线性的,自然就适合二分类方法,毕竟非黑即白嘛。而且该分类器的好处是处理大量数据时非常高效,让我们先创建一个SGDClassifier
分类器然后训练一遍。
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(x_train,y_train_5) #训练分类器
这里的代码很常用所以解释一下,首先调用SGDClassifier函数新建了一个分类器,随机种子设为42,算法中需要的参数都是按照某种方式随机生成的,然后调用fit训练函数,输入训练数据,通过梯度下降法训练参数,最后得到的sgd_clf是训练好的分类器
现在随便测试一张图片:
sgd_clf.predict(X[36000])
array([ True], dtype=bool)
分类器猜测这个数字代表 5( True )。看起来在这个例子当中,它猜对了。现在让我们评估 这个模型的性能。
首先我们应该使用交叉验证法,具体原理参考上一章中提到的K这交叉验证法。在这里我们只管调用函数 cross_val_score
就好了。
K = 3意味着将把训练集分成3折,使用其他折进行训练,剩下的折用于测试精确度。
from sklearn.model_selection import cross_val_score
#cv代表折数,scoring为计算方式,这里计算精准度
cross_val_score(sgd_clf,x_train,y_train_5,cv = 3,scoring="accuracy")
array([ 0.9502 , 0.96565, 0.96495]
看起来准确率非常高,但是并不是这样的,想象一下,假如你全猜“非5”,准确率依然有90%!这是因为图片5的数量太少了,只占总数据的十分之一。所以准确率并不能很好的表达模型的精度。
那怎么样才能判断模型的精准度呢,现在我们打开新世界的大门。
对于分类器来说,混淆矩阵是个不错的判断精度工具,什么是混淆矩阵呢?
大致的意思就是类别A被判断为B的次数,我们做一个简单的表格就能看懂了,其中加粗的对角线就是正确分类的数量。
首先,我们使用K折交叉验证得出一系列的预测,也就是模型得出的预测值。使用函数cross_val_predict
,注意,有别于上面提到的 cross_val_score
,那个是分数,这个是预测值,然后使用函数confusion_matrix
得到混淆矩阵,是不是很简单?
from sklearn.model_selection import cross_val_predict
#计算预测值
y_train_pred = cross_val_predict(sgd_clf,x_train,y_train_5,cv = 3)
from sklearn.metrics import confusion_matrix
#计算混淆矩阵
confusion_matrix(y_train_5,y_train_pred)
array([[53887, 692],
[ 1279, 4142]], dtype=int64)
来,总结一下这个矩阵中四个值代表的意义:(正例:是5,反例:非5)
53887:代号TN,全名true negative,真反例,就是反例分对了的意思。
1279:代号FN,全名false negative,假反例,就是错分为反例的意思。
692:代号FP,全名fasle positive,假正例,就是错分为正例的意思。
4142:代号TP,全名true positive,真正例,就是正例分对了的意思。
好了,我不是有意要绕晕你的,如果你没看懂,直接看图啦
形象生动(并没有)
基于以上的划分方式,我们要引出准确率、召回率这对欢喜冤家。
准确率: p r e c i s i o n = T P T P + F P precision = {TP\over TP+FP} precision=TP+FPTP
召回率:$recall = {TP\over TP+FN} $
准确率的意义很简单,当模型给出一个预测时,该预测的可靠程度。
召回率是指所有正例中,被正确预测的个数。
我们测试一下刚才建立的分类模型,准确率的计算函数是precision_score
,召回率是recall_score
。
from sklearn.metrics import precision_score,recall_score
precision_score(y_train_5,y_train_pred)
recall_score(y_train_5,y_train_pred)
0.8568473314025652
0.7640656705404907
也就是说,该模型有85.6%的几率预测正确,但是所有的图片5,只有76%被正确识别了。
而准确率和召回率是一对难以调和的冤家,因为当你提高准确率,也就是更加严格的判断是否为5,可能很多的图片5被误杀,导致召回率直线下降。
通俗点说:
准确率高:宁缺毋滥(判断正例的数据少,但是准度高)
召回率高:宁可错杀一千,不可放过一个(判断正例的数据多,但是准度低)
其中意味,需要你自己细细体会,这里将准确率和召回率呈现的趋势展示给大家看:
提高阈值,准确率上升,召回率下降,反之准确率下降,召回率提升,那么最好的位置就是两者都在0.8左右。
两者无法兼得,所以提出新的判断标准,F1值,近似为两者的平均值。但是计算方式考虑到了较小的值有更大的权重。
F 1 = 2 1 p r e c c s i o n + 1 r e c a l l F1 = {2\over {1\over preccsion}+{1\over recall}} F1=preccsion1+recall12
看不懂公式不重要,只要知道F1代表的是两者的平均值就好,咱们直接调函数f1_score
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
0.78468208092485547
有时候你需要很高的准确率,比如判断水果是否变质,我们只需要输出的水果都是好水果就行了。
有时候你需要很高的召回率,比如判断背包内是否携带易燃易爆物品,即使是错误识别了某些背包,也不能让危险的背包通过。
有时候你需要兼顾二者,就着力于提高F1值。
那么如何调整准确率和召回率呢?最直接的办法就是调整阈值了,简单地说就是调整判断图片是否为5的标准。
一般来说,分类器会得出一张图片是正例(是5)的得分值,我们只要调整得分值的判断标准,就相当于调整了了阈值。
调用decision_function
方法,这个方法返回了每个样例的得分,然后基于这个得分,你可以任意的调整阈值。
计算所有的样例得分值方法很简单,把交叉验证函数中的method改为decision_function
就可以了:
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3,method="decision_function")
下面我们单独对第36001张图片进行判断,阈值设为0,这张图我们在前面已经看到了,是图片5
y_score = sgd_clf.decision_function([x[36000]]) #第36001张图片的得分
y_score
threshold = 0 #阈值设为0
y_some_digit_pred = (y_score > threshold) #分类
y_some_digit_pred
array([1066.83525987])
array([ True])
当阈值等于0的时候,这张图片被判断为正例(是5)
阈值对准确率和召回率的影响在前面那张曲线图已经体现出来了, 你可以翻上去看看。
全称受试者工作特征曲线(ROC)是二分类器中常用的工具,这个曲线能反应预测为正例的数据中正确数据的比例,计算公式也很简单: R O C = P ( T P ) P ( F P ) ROC = {P(TP) \over P(FP)} ROC=P(FP)P(TP)
TP是真正例,也就是把图片5预测成5的数量
FP是假正例,也就是把不是图片5预测成5的数量
函数P就是对应的数量除以样本总数,也就是占比的意思。
再让大家复习一下这张图:
从下图的ROC曲线来看,在曲线最左上角的位置,大概是坐标(0.05,0.9)的位置。在左上角位置的时候,被错分为正例的数量最少,效果也是最好的。
调用函数roc_auc_score
可以很快的算出roc值。
from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)
0.9659714668088117
最后为性能评估做个总结,分类的任务远不止二分类,还有多分类方法,但是分类器的性能评估基本上都是用这些工具。
检验自己的分类器性能一般可以使用交叉验证来评估你的分类器,然后选择满足你需要的准确率/召回率位置或者是最佳ROC位置,找到合适的阈值点。至于选择哪种性能评估工具,取决于你的分类器更注重哪方面的性能。
关于机器学习的分类任务上半部分就讲完了,这一章我们主要学习了如何训练一个二分类器,如何评估一个分类器的性能,下一章我们继续学习分类任务,不过主要是侧重于多分类任务~
欢迎来我的博客留言讨论,我的博客主页:LeonG的博客
我的知乎机器学习专栏:LeonG与机器学习
本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。