当前位置:   article > 正文

C++实现KNN算法_knn分类 c++代码

knn分类 c++代码

C++实现KNN算法、

/*
 * @Description: C++实现KNN算法
 * @Author: szq
 * @Github: https://github.com/MrQqqq
 * @Date: 2020-07-08 19:13:25
 * @LastEditors: szq
 * @LastEditTime: 2020-07-09 16:50:55
 * @FilePath: \cpp\src\KNN\KNN.cpp
 */ 

#include<iostream>
#include<vector>
#include<fstream>
#include<random>
#include<time.h>
#include<map>
#include<algorithm>
using namespace std;

/**
 * @destription: 分割字符串
 * @param s:源字符
 * @param mode:分割的字符 
 * @return: 字符串分割后的字符串数组
 */
vector<string> split(string &s,char mode){
    vector<string> res;
    while(s.size() > 0){
        int index = s.find(mode);
        if(index != -1){
            res.push_back(s.substr(0,index+1));
            s = s.substr(index + 1);
        }
        else{
            res.push_back(s);
            break;
        }
        
    }
    return res;
}

/**
 * @destription: 获取文件的行数并将每一行的内容按','分割,保存起来
 * @param in:输入流
 * @param lines:保存每一行的结果
 * @return: 返回文件的行数
 */
int getFileRows(ifstream &in,vector<vector<string>> &lines){
    int rows = 0;
    char line[512];
    while(!in.eof()){
        in.getline(line,512,'\n');
        string src = string(line);
        lines.push_back(split(src,','));
        rows++;
    }
    return rows;
}

/**
 * @destription: 按一定的比例划分训练集和测试集
 * @param filepath:保存数据的文件地址
 * @param rate:训练集占中总数据的比例
 * @param trainingSet:训练集
 * @param testSet:测试集 
 * @return: 没有返回,结果都保存在对应的参数中
 */
void loadDataset(string &filepath,double &rate,vector<vector<double>> &trainingSet,vector<vector<double>> &testSet){
    ifstream input;
    input.open("irisdata.txt",ios::in | ios :: binary);//读或二进制打开文件
    vector<vector<string>> lines;//获取文件每一行内容
    int rows = getFileRows(input,lines);//获取行数和行内容
    srand((unsigned int)time(NULL));//设置随机数种子
    //将文本中的行内容转换为数组,并放入训练集或者测试集
    vector<vector<double>> dataset(rows,vector<double>(5));
    for(int i = 0;i < rows;i++){
        //前四个数字转换为double类型
        for(int j = 0;j < 5;j++){
            dataset[i][j] = atof(lines[i][j].c_str());
        }
        //划分训练集和测试集
        if(rand()/double(RAND_MAX) < rate){
            trainingSet.push_back(dataset[i]);
        }
        else{
            testSet.push_back(dataset[i]);
        }
    }
    input.close();
}

/**
 * @destription: 计算距离
 * @param instance1:实例1
 * @param instance2:实例2
 * @param length:特征数 
 * @return: 计算的距离
 */
double calculateDistance(vector<double> &instance1,vector<double> &instance2,int length){
    double distance = 0;
    for(int i = 0;i < length;i++){
        distance += pow(instance1[i] - instance2[i],2);
    }
    return sqrt(distance);
}

/**
 * @destription: 获取训练集中距离最小的k个近邻
 * @param trainingSet:训练集 
 * @param testInstance:测试的实例对象
 * @param k:选取的近邻数量
 * @return:选取的k个近邻集合
 */
vector<vector<double>> getNeighbors(vector<vector<double>> &trainingSet,vector<double> &testInstance,int k){
    vector<pair<vector<double>,double>> distances;
    int len = testInstance.size() - 1;//特征数
    //保存实例和距离
    for(int i = 0;i < trainingSet.size();i++){
        double distance = calculateDistance(testInstance,trainingSet[i],len);
        distances.push_back(make_pair(trainingSet[i],distance));
    }
    //按照距离排序
    sort(distances.begin(),distances.end(),[](pair<vector<double>,double> &p1,pair<vector<double>,double> &p2){
            return p1.second < p2.second;
        });
    //选取距离最小的k个实例作为近邻
    vector<vector<double>> neighbors;
    for(int i = 0;i < k;i++){
        neighbors.push_back(distances[i].first);
    }
    return neighbors;
}

/**
 * @destription: 获得选取近邻的反馈的结果,根据k个近邻中分类结果最多的一个
 * @param neighbors:选取的k个近邻集合 
 * @return: 分类的结果
 */
double getResponse(vector<vector<double>> &neighbors){
    map<int,int> classVotes;
    //遍历k个近邻,统计每个种类的个数
    for(int i = 0;i < neighbors.size();i++){
        classVotes[neighbors[i][4]]++;
    }
    int maxVote = 0;
    double res = 0;
    //计算种类个数最多那个种类
    for(auto vote : classVotes){
        if(vote.second > maxVote){
            maxVote = vote.second;
            res = vote.first;
        }
    }
    return res;

}

/**
 * @destription: 计算预测的准确率
 * @param testSet:测试集合
 * @param predictions:测试的结果集合 
 * @return: 预测的准确率
 */
double getAccuracy(vector<vector<double>> &testSet,vector<double> &predictions){
    int correct = 0;
    //统计预测正确的个数
    for(int i = 0;i < testSet.size();i++){
        if(testSet[i][4] == predictions[i]){
            correct++;
        }
    }
    //返回准确率
    return correct / (double)(testSet.size()) * 100.0;
}

/**
 * @destription: 预测测试集结果
 * @param trainSet:训练集
 * @param testSet:测试集 
 * @return: 预测的结果集合
 */
vector<double> pridict(vector<vector<double>> &trainSet,vector<vector<double>> &testSet){
    vector<double> predictions;
    int k = 3;
    for(int i = 0;i < testSet.size();i++){
        vector<std::vector<double>> neighbors = getNeighbors(trainSet,testSet[i],k);
        double res = getResponse(neighbors);
        predictions.push_back(res);
    }
    return predictions;
}

int main(){
    vector<vector<double>> trainSet;
    vector<vector<double>> testSet;
    double rate = 0.8;
    string filepath = "./irisdata.txt";
    loadDataset(filepath,rate,trainSet,testSet);

    cout << "------------trainSet:--------------" << endl;
    for(auto traindata : trainSet){
        for(double num : traindata){
            cout << num << " ";
        }
        cout << endl;
    }

    cout << "------------testSet:--------------" << endl;
    for(auto testdata : testSet){
        for(double num : testdata){
            cout << num << " ";
        }
        cout << endl;
    }

    vector<double> predictions;
    predictions = pridict(trainSet,testSet);
    cout << "------------测试结果为::--------------" << endl;
    for(int i = 0;i < testSet.size();i++){
        cout << "测试数据" << i << ":";
        for(int j = 0;j < 4;j++){
            cout << testSet[i][j] << " ";
        }
        cout << "预测值:" << predictions[i] << " " << "真实值:" << testSet[i][4] << endl;
    }

    double accuracy = getAccuracy(testSet,predictions);
    cout << "准确率为:" << accuracy << endl;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/801247
推荐阅读
相关标签
  

闽ICP备14008679号