当前位置:   article > 正文

scikit-learn——机器学习应用开发的步骤_scikit-learn==0.21.0

scikit-learn==0.21.0

本篇笔记是关于对机器学习应用开发的步骤的理解。

scikit-learn 简介:

scikit-learn是一个开源的 Python 机器学习工具包,它涵盖了几乎所有主流机器学
习算法的实现,并且提供了一致的调用接口。它基于 Numpy scipy等Python 数值计算
库,提供了高效的算法实现。

(一)数据采集和标记

实现一个程序,需要先采集数据并且尽可能多的采集不同的数据(防止偶然性,使得数据具有代表性),然后对数据进行标记。

(二)特征选择

选择合适的特征,将数据保存为样本个数×特征个数格式。

(三)数据清洗

在采集数据完后,为了减少计算量,也为了模型的稳定性,我们需要对数据进行数据清洗,即把采集到的、不适合用来做机器学习训练的数据进行预处理,从而转化为适合机器学习的数据。

(四)模型选择

对于不同的数据集,选择不同的模型有不同的效率。因此在选择模型要考虑很多的因素,从众多的因素中找到一个最适合模型,同时这个模型要使结果模拟评分达到最高。

(五)模型训练

在进行模型训练之前,要将数据集划分为训练数据集和测试数据集,再利用划分好的数据集进行模型训练,最后得到训练出来的模型参数。

(六)模型测试

用上面训练出来的模型预测测试数据集,把预测结果 Ypred 真正的结果 Ytest 比较,看有多少个是正确的,这样就能评估出模型的准确度了。
scikit-learn 提供了现成的方法来完成这项工作:clf .score (Xtest , Ytest)

(七)模型保存与加载

当我们训练出一个满意的模型后可以将模型进行保存,这样当我们再一次需要使用此模型时可以直接利用此模型进行预测,不用再一次进行模型训练。

(八)实例

回顾前面介绍的机器学习应用开发的典型步骤,我们使用scikit-learn完成一个手写数字识别的例子,这是一个有监督的学习,数据是标记过的手写数字的图片即通过采集足够多的手写数字样本数据,选择合适的模型,并使用来集到的数据进行模型训练,最后验证手写识别程序的正确性(模型测试)
1.数据采集和标记
如果我们从头实现一个数字手写识别的程序,需要先采集数据,即让尽量多不同书写习惯的用户,写出从0~9的所有数字,然后把用户写出来的数据进行标记,即用户每写出一个数字,就标记他写出的是哪个数字。
scikit-learn 自带了一些数据集,其中一些是手写
数字识别图片的数据,使用以下代码来加载数据

# 导入库
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np


"""
sk-learn库中自带了一些数据集
此处使用的就是手写数字识别图片的数据
"""
# 导入sklearn库中datasets模块
from sklearn import datasets
# 利用datasets模块中的函数load_digits()进行数据加载
digits = datasets.load_digits() 


# 把数据所代表的图片显示出来
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Digit: %i' % label, fontsize=20);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

运行结果:
在这里插入图片描述

2.特征选择
将数据保存为样本个数×特征个数格式。

# 将数据保存为 样本个数x特征个数 格式的array对象 的数据格式进行输出
# 数据已经保存在了digits.data文件中
print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
  • 1
  • 2
  • 3
  • 4

运行结果:

shape of raw image data: (1797, 8, 8)
shape of data: (1797, 64)

3.模型训练
(此处,我们使用支持向量机来作为手写识别算法的模型)
在开始训练我们的模型之前,需要先把数据集分成训练数据集和测试数据集。接着,使用训练数据集 Xtrain和Ytrain 来训练模型。

# 把数据分成训练数据集和测试数据集(此处将数据集的百分之二十作为测试数据集)
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data, digits.target, test_size=0.20, random_state=2);

# 训练完成后clf对象就会包含我们训练出来的模型参数,可以使用这个模型对象来进行预测
# 使用支持向量机来训练模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100., probability=True)
# 使用训练数据集Xtrain和Ytrain来训练模型
clf.fit(Xtrain, Ytrain);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

训练完成后, elf 对象就会包含我们训练出来的模型参数,可以使用这个模型对象来进行预测。
4.模型测试
用上面训练出来的模型预测测试数据集,把预测结果 Ypred 真正的结果 Ytest 比较,看有多少个是正确的,这样就能评估出模型的准确度了。

# 评估模型的准确度(此处默认为true,直接返回正确的比例,也就是模型的准确度)
from sklearn.metrics import accuracy_score
# predict是训练后返回预测结果,是标签值。
Ypred = clf.predict(Xtest);
accuracy_score(Ytest, Ypred)
  • 1
  • 2
  • 3
  • 4
  • 5

运行结果:

0.9777777777777777

# 用训练好的模型在测试集上进行评分(0~11分代表最好
clf.score(Xtest, Ytest)
  • 1
  • 2

运行结果:

0.9777777777777777

除此之外,还可以直接把测试数据集里的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。

"""
将测试数据集里的部分图片显示出来
图片的左下角显示预测值,右下角显示真实值
"""
# 查看预测的情况
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)

for i, ax in enumerate(axes.flat):
    ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
    ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
            transform=ax.transAxes,
            color='green' if Ypred[i] == Ytest[i] else 'red')
    ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
            transform=ax.transAxes,
            color='black')
    ax.set_xticks([])
    ax.set_yticks([])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

运行结果:
在这里插入图片描述
从中可以看二行第一个图片预测错了,真实的数字是4,但预测成了8。

5.模型保存与加载
当我们对模型的准确度感到满意后,就可以把模型保存下来。

# 保存模型参数
from sklearn.externals import joblib
joblib.dump(clf, 'digits_svm.pkl');
  • 1
  • 2
  • 3

当需要这个模型来进行预测时,直接加载模型即可进行预测

# 导入模型参数,直接进行预测
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Xtest, Ytest)
  • 1
  • 2
  • 3
  • 4
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/294381
推荐阅读
相关标签
  

闽ICP备14008679号