当前位置:   article > 正文

[Python] scikit-learn - K近邻算法介绍和使用案例_使用python sklearn构建k近邻回归和分类 statlib 的加州房产价格数据集

使用python sklearn构建k近邻回归和分类 statlib 的加州房产价格数据集

什么是K近邻算法?

K近邻算法(K-Nearest Neighbors,简称KNN)是一种基于实例的学习方法,主要用于分类和回归任务。它的基本思想是:给定一个训练数据集,对于一个新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数类别就是该输入实例的类别。

思路:

  1. 计算输入实例与训练数据集中每个实例之间的距离。
  2. 对距离进行排序,找到距离最近的K个实例。
  3. 根据这K个实例的类别进行投票,得到输入实例的类别。

K近邻算法使用场景和注意事项

K近邻算法(K-Nearest Neighbors,简称KNN)是一种基于实例的学习方法,主要用于分类和回归任务。它的使用场景包括:

  1. 数据集较小的情况:当数据集较小时,KNN算法可以快速地进行训练和预测,而不需要大量的计算资源。
  2. 数据集中存在噪声的情况:由于KNN算法是基于实例的,因此它对数据集中的噪声具有一定的容忍度。
  3. 数据集中存在异常值的情况:KNN算法在处理异常值时,会根据邻近实例的类别来进行投票,从而降低了异常值对结果的影响。
  4. 数据集中存在不平衡类别的情况:KNN算法在处理不平衡类别的数据集时,可以通过调整K值来平衡各个类别之间的样本数量。

在使用KNN算法时,需要注意以下几点:

  1. 选择合适的K值:K值的选择对算法的性能有很大影响。通常情况下,可以通过交叉验证等方法来选择合适的K值。
  2. 特征选择:KNN算法对特征的数量和质量要求较高,因此需要对特征进行选择和预处理,以提高算法的性能。
  3. 距离度量:KNN算法需要计算实例之间的距离,因此需要选择合适的距离度量方法,如欧氏距离、曼哈顿距离等。
  4. 性能评估:为了确保算法的性能,需要对算法进行性能评估,如准确率等指标。

K近邻算法python实现

  1. from sklearn.datasets import load_iris
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.metrics import accuracy_score
  4. import numpy as np
  5. from collections import Counter
  6. def euclidean_distance(x1, x2):
  7. # 计算欧氏距离
  8. return np.sqrt(np.sum((x1 - x2) ** 2))
  9. class KNN:
  10. def __init__(self, k=3):
  11. self.k = k
  12. def fit(self, X, y):
  13. self.X_train = X
  14. self.y_train = y
  15. def predict(self, X):
  16. y_pred = [self._predict(x) for x in X]
  17. return np.array(y_pred)
  18. def _predict(self, x):
  19. # 计算输入实例与训练数据集中每个实例之间的距离
  20. distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
  21. # 对距离进行排序,找到距离最近的K个实例的索引
  22. k_indices = np.argsort(distances)[:self.k]
  23. # 根据这K个实例的类别进行投票,得到输入实例的类别
  24. k_nearest_labels = [self.y_train[i] for i in k_indices]
  25. most_common = Counter(k_nearest_labels).most_common(1)
  26. return most_common[0][0]
  27. data = load_iris()
  28. X, y = data.data, data.target
  29. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  30. knn = KNN(k=3)
  31. knn.fit(X_train, y_train)
  32. predictions = knn.predict(X_test)
  33. print("Accuracy:", accuracy_score(y_test, predictions))

scikit-learn中的K近邻算法

K近邻算法用于分类任务

sklearn.neighbors.KNeighborsClassifier — scikit-learn 1.4.0 documentation

 

  1. from sklearn.datasets import load_iris
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.metrics import accuracy_score
  4. from sklearn.neighbors import KNeighborsClassifier
  5. data = load_iris()
  6. X, y = data.data, data.target
  7. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  8. knc = KNeighborsClassifier(n_neighbors=3)
  9. knc.fit(X_train, y_train)
  10. predictions = knc.predict(X_test)
  11. print("Accuracy:", accuracy_score(y_test, predictions))

在这个示例中,我们首先从scikit-learn库中加载了iris花卉数据集,并将其划分为训练集和测试集。然后,我们创建了一个KNeighborsClassifier对象,并设置了K值为3。接下来,我们使用训练集对模型进行训练,并使用测试集进行预测。最后,我们计算了预测结果的准确度。 

K近邻算法用于回归任务

sklearn.neighbors.KNeighborsRegressor — scikit-learn 1.4.0 documentation

 

  1. from sklearn.datasets import load_iris
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.neighbors import KNeighborsRegressor
  4. from sklearn.metrics import mean_squared_error
  5. # 加载iris花卉数据集
  6. data = load_iris()
  7. X = data.data
  8. y = data.target
  9. # 将数据集划分为训练集和测试集
  10. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  11. # 创建KNeighborsRegressor对象,设置K值为3
  12. knn = KNeighborsRegressor(n_neighbors=3)
  13. # 使用训练集对模型进行训练
  14. knn.fit(X_train, y_train)
  15. # 使用测试集进行预测
  16. y_pred = knn.predict(X_test)
  17. # 计算预测结果的均方误差
  18. mse = mean_squared_error(y_test, y_pred)
  19. print("均方误差:", mse)

在这个示例中,我们首先从scikit-learn库中加载了iris花卉数据集,并将其划分为训练集和测试集。然后,我们创建了一个KNeighborsRegressor对象,并设置了K值为3。接下来,我们使用训练集对模型进行训练,并使用测试集进行预测。最后,我们计算了预测结果的均方误差。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/549049
推荐阅读
相关标签
  

闽ICP备14008679号