(Introduction to Machine Learning with Python)
By Andreas C. Müller (Germany) and Sarah Guido (USA); translated by Zhang Liang (hysic)
import pandas as pd
import mglearn
The first problem is a pair of FutureWarning messages that appear when running import mglearn:
\anaconda\lib\site-packages\sklearn\externals\six.py:31: FutureWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).
"(https://pypi.org/project/six/).", FutureWarning)
\anaconda\lib\site-packages\sklearn\externals\joblib\__init__.py:15: FutureWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.
warnings.warn(msg, category=FutureWarning)
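Both messages are FutureWarning notices about deprecated sklearn.externals modules, not errors, so the script keeps running. If the noise is unwanted, one option is to filter that warning category with the standard warnings module before the import. The sketch below shows the mechanism only; noisy_import is a hypothetical stand-in for import mglearn so that the example runs without mglearn installed.

```python
import warnings


def noisy_import():
    # Hypothetical stand-in for `import mglearn`: it emits a FutureWarning
    # the same way the deprecated sklearn.externals modules do.
    warnings.warn("sklearn.externals.joblib is deprecated", FutureWarning)
    return "loaded"


# Without a filter, the warning is emitted (recorded here instead of printed).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    noisy_import()
n_unfiltered = len(caught)

# With an "ignore" filter for FutureWarning, nothing is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.filterwarnings("ignore", category=FutureWarning)
    result = noisy_import()
n_filtered = len(caught)

print("warnings without filter:", n_unfiltered)
print("warnings with filter:", n_filtered)
```

In the real script the equivalent would be calling warnings.filterwarnings("ignore", category=FutureWarning) on the line before import mglearn. This hides every FutureWarning, including ones that may matter later, so it is a convenience rather than a fix.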
grr = pd.plotting.scatter_matrix(iris_dataFrame, c=y_train, figsize=(15, 15), marker='o',
                                 hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)
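This call needs one more import. Note that it is import matplotlib.pyplot as plt, not import matplotlib as plt; the complete script below then calls plt.show() to display the figure.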
# File: test.py
# Imports
import pandas as pd
import matplotlib.pyplot as plt
import mglearn
# Function that randomly splits a dataset into a training set and a test set
from sklearn.model_selection import train_test_split
# Dataset bundled with scikit-learn
from sklearn.datasets import load_iris

# Load the dataset
iris_dataset = load_iris()

# Split the dataset randomly, because the original data is sorted by target
'''
Target:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
'''
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0
)

# Convert the NumPy array into a pandas DataFrame
iris_dataFrame = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# Optional: inspect the DataFrame here. Outside Jupyter this needs
# `from IPython.display import display` first.
# display(iris_dataFrame)

# Call scatter_matrix to draw the scatter-plot matrix
grr = pd.plotting.scatter_matrix(iris_dataFrame, c=y_train, figsize=(15, 15), marker='o',
                                 hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)
plt.show()

# Use the KNN algorithm to classify a flower of unknown species
from sklearn.neighbors import KNeighborsClassifier
# Consider only one neighbor; to use more, change the n_neighbors parameter
knn = KNeighborsClassifier(n_neighbors=1)
# Train the model
knn.fit(X_train, y_train)

# Try predicting the species of a new sample
import numpy as np
X_new = np.array([[5, 2.9, 1, 0.2]])
print("X_new.shape: {}".format(X_new.shape))
'''X_new.shape: (1, 4) '''
# The shape must match X_test. For example, the full data has
# shape (150, 4), so a single sample has shape (1, 4).

# Call the predict method of the knn object to make the prediction
prediction = knn.predict(X_new)
print("Prediction: \n{}".format(prediction))
print("Prediction target name :\n {}".format(iris_dataset['target_names'][prediction]))
'''
Prediction:
[0]
Prediction target name :
 ['setosa']
# The predicted value is 0, which corresponds to the species setosa
'''

# Evaluate the model: fraction of test samples predicted correctly
y_pre = knn.predict(X_test)
corr_rate = np.mean(y_pre == y_test)
print("Test set score : {:.2f}".format(corr_rate))
''' Test set score : 0.97 '''
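To make the n_neighbors=1 setting and the accuracy calculation more concrete, here is a minimal pure-Python sketch of a one-nearest-neighbor classifier. It is not how scikit-learn implements KNeighborsClassifier; the function name predict_1nn and the toy points are made up for illustration and are not iris measurements. It assumes Euclidean distance and Python 3.8+ (for math.dist).

```python
import math


def predict_1nn(train_points, train_labels, query):
    """Return the label of the training point closest to `query`."""
    distances = [math.dist(point, query) for point in train_points]
    nearest = distances.index(min(distances))
    return train_labels[nearest]


# Hypothetical toy data: two well-separated clusters with labels 0 and 1.
train_points = [(1.0, 1.0), (1.2, 0.8), (5.0, 5.0), (5.5, 4.5)]
train_labels = [0, 0, 1, 1]

# Hypothetical test samples with their true labels.
test_points = [(0.9, 1.1), (5.2, 5.1), (3.0, 3.2)]
test_labels = [0, 1, 0]

# Predict each test sample from its single nearest training point.
predictions = [predict_1nn(train_points, train_labels, q) for q in test_points]

# Accuracy is the fraction of predictions equal to the true labels,
# the same quantity as np.mean(y_pre == y_test) in the script above.
accuracy = sum(p == t for p, t in zip(predictions, test_labels)) / len(test_labels)

print("Predictions:", predictions)
print("Test set score : {:.2f}".format(accuracy))
```

The third test point lies between the two clusters and ends up closer to a label-1 training point, so it is misclassified. That is the kind of case that keeps a test set score below 1.0.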