赞
踩
K近邻分类算法是机器学习常用的算法之一,很多机器学习库都提供了具体实现,可以直接调用相应的方法。下面用C语言实现该算法,用了一点C++的输入输出语句。
算法思想:现在要预测某个样本的分类结果,首先到预测样本附近找K个最近邻居,然后统计这些邻居的分类,出现频率最高的分类就是预测分类结果。
数据集:海伦约会数据集。每行有4个数据,前3个是样本特征(飞行公里数,玩游戏时间百分比,吃冰激凌公升数),最后1个是分类结果(很喜欢,有点喜欢,讨厌)。原始数据的分隔符是制表符\t。
C/C++程序如下,函数形参列表有很多数组,站在C语言语法角度,[ ]里的数值都是没用的,目的只是方便知道数组里数据个数。另外,偷了个懒,没有动态申请内存,因为样本数据量不大,栈空间基本够用。
原始数据有1000个样本,选择10%的数据(前100个样本)作为测试数据。代码框架如下(省略了函数的具体实现),主要包括:读txt文件,数据归一化处理,计算测试样本预测准确率,对某个样本进行分类预测。
- /k近邻分类器
-
- #include <windows.h>
- #include <iostream>
- using namespace std;
-
- #define MaxDataNum 1000
- #define FileMaxCol 100
- #define FeatureNum 3
- #define ClassNum 3
- #define K 7
- #define TestDataRatio 0.1
-
- int FileToArray(const char fileName[], double featureData[MaxDataNum][FeatureNum], int classLabel[MaxDataNum]);
- void Normalization(double featureData[MaxDataNum][FeatureNum], double normFeatureData[MaxDataNum][FeatureNum], double range[FeatureNum], double minVal[FeatureNum]);
- int Classify(double normOnePersonData[FeatureNum], double normTrainingFeatureData[int(MaxDataNum * (1.0 - TestDataRatio))][FeatureNum], const int trainingClassLabel[int(MaxDataNum * (1.0 - TestDataRatio))]);
- void ClassifyTest(double normFeatureData[MaxDataNum][FeatureNum], const int classLabel[MaxDataNum]);
- int ClassifyOnePerson(double normOnePersonData[FeatureNum], double normFeatureData[int(MaxDataNum * (1.0 - TestDataRatio))][FeatureNum], const int trainingClassLabel[int(MaxDataNum * (1.0 - TestDataRatio))]);
-
- int main()
- {
- clock_t t1 = clock();
-
- char fileName[] = "e:\\helenData.txt";
-
- double featureData[MaxDataNum][FeatureNum];
- int classLabel[MaxDataNum];
- double normFeatureData[MaxDataNum][FeatureNum];
- double range[FeatureNum];
- double minVal[FeatureNum];
-
- double onePersonData[FeatureNum] = { 0.0 };
- double normOnePersonData[FeatureNum] = { 0.0 };
- int testDataNum = int(MaxDataNum * TestDataRatio);
- double normTrainingFeatureData[int(MaxDataNum * (1.0 - TestDataRatio))][FeatureNum];
- int trainingClassLabel[int(MaxDataNum * (1.0 - TestDataRatio))];
-
- FileToArray(fileName, featureData, classLabel);
- Normalization(featureData, normFeatureData, range, minVal);
- ClassifyTest(normFeatureData, classLabel);
-
- for (int i = testDataNum; i < MaxDataNum; i++)
- {
- for (int j = 0; j < FeatureNum; j++)
- {
- normTrainingFeatureData[i - testDataNum][j] = normFeatureData[i][j];
- }
- trainingClassLabel[i - testDataNum] = classLabel[i];
- }
-
- clock_t t2 = clock();
- cout << t2 - t1 << "毫秒" << endl;
-
- while (true)
- {
- cout << "----------------------------------------------" << endl;
- cout << "输入每年飞行里程数:" << endl;
- cin >> onePersonData[0];
- cout << "输入玩视频游戏所耗时间百分比:" << endl;
- cin >> onePersonData[1];
- cout << "输入每周消费冰激淋公升数:" << endl;
- cin >> onePersonData[2];
-
- for (int i = 0; i < FeatureNum; i++)
- {
- normOnePersonData[i] = (onePersonData[i] - minVal[i]) / range[i];
- }
-
- ClassifyOnePerson(normOnePersonData, normTrainingFeatureData, trainingClassLabel);
- }
-
- return 0;
- }
运行结果如下(1,2,3分别表示讨厌,有点喜欢,很喜欢),100个测试样本错了4个,算法没有问题,一般是样本有问题,比如一个男的,他的3个特征反映出这个男的很优秀,10个女的有9个女的会喜欢这个男的,偏偏有1个女的不喜欢,对于这样一个样本,程序给出的预测结果和真实结果会不一致。
下面用一个冷门的函数式编程语言实现上述算法,使用Spark分布式计算框架。该程序很容易改成鸢尾花预测和手写数字识别程序,只要修改main函数的前5行代码即可。缺点是数据量大的时候,运行速度有点慢,比C语言慢至少2个数量级。
- import org.apache.spark.{SparkConf, SparkContext}
- import org.apache.spark.rdd.RDD
- import scala.collection.mutable.ListBuffer
- import scala.io.Source
- import java.util.Date
-
- object SparkKNN {
- def main(args: Array[String]): Unit = {
- val trainingDataFile="E:\\helenTrainingData.txt"
- val testDataFile="E:\\helenTestData.txt"
- val featureNumber=3
- val splitChar="\t"
- val K=7
-
- val start_time=new Date().getTime
- println("begin")
- var source = Source.fromFile(trainingDataFile, "UTF-8")
- var lines = source.getLines().toArray
- source.close()
-
- val feature_list: Array[ListBuffer[Double]] = new Array[ListBuffer[Double]](featureNumber)
- for(i<-0 until featureNumber){
- feature_list(i)=new ListBuffer[Double]
- }
-
- for (i <- 0 until lines.length) {
- val featureArray = lines(i).trim.split(splitChar)
- for(j<-0 until featureNumber){
- feature_list(j).append(featureArray(j).toDouble)
- }
- }
-
- //部分代码省略
-
- println("--------------------------------")
- println("预测错误次数:"+errorCount)
- println("错误率:"+100*errorCount.toDouble/lines.length+"%")
-
- val end_time=new Date().getTime
- println(end_time-start_time+"毫秒")
- }
- }
运行结果和C程序对比:
如果想知道分类预测错误的原因,可以输出相关信息:k个最近邻居的距离值和分类结果,以及这些分类结果的统计情况(每个分类结果出现了几次),例如:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。