
K-Nearest Neighbor (KNN) C++ Implementation

For an introduction to KNN, see: http://blog.csdn.net/fengbingchun/article/details/78464169

This post gives a C++ implementation of KNN used for classification. The training and test data both come from MNIST; for an introduction to MNIST, see: http://blog.csdn.net/fengbingchun/article/details/49611549 . 80 images were extracted from MNIST, 20 for each of the four classes 0, 1, 2, and 3. For each class, the first 10 images come from the MNIST training set and are used for training, and the last 10 come from the MNIST test set and are used for testing, as shown in the figure below.


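To classify a test sample, the implementation below computes the squared Euclidean distance between the test sample x and every training sample x^(i) over its m features,

$$d\bigl(x, x^{(i)}\bigr) = \sum_{j=1}^{m} \bigl(x_j - x^{(i)}_j\bigr)^2,$$

keeps the k training samples with the smallest distance, and predicts the label that occurs most often among them.
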
The implementation code is as follows:

knn.hpp:

#ifndef FBC_NN_KNN_HPP_
#define FBC_NN_KNN_HPP_

#include <memory>
#include <vector>

namespace ANN {

template<typename T>
class KNN {
public:
	KNN() = default;
	void set_k(int k);
	// copy the training samples and their labels; returns 0 on success
	int set_train_samples(const std::vector<std::vector<T>>& samples, const std::vector<T>& labels);
	// predict the label of a single sample; returns 0 on success, -1 on error
	int predict(const std::vector<T>& sample, T& result) const;

private:
	int k = 3; // number of nearest neighbors
	int feature_length = 0; // number of features per sample
	int samples_number = 0; // number of training samples
	std::unique_ptr<T[]> samples; // flattened training samples, row-major
	std::unique_ptr<T[]> labels; // one label per training sample
};

} // namespace ANN

#endif // FBC_NN_KNN_HPP_
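
A minimal usage sketch of this interface, with toy 2-D data of my own (not part of the original post), might look like this:

#include "knn.hpp"
#include <cstdio>
#include <vector>

int main()
{
	// toy data: four 2-D training points, two per class (for illustration only)
	std::vector<std::vector<float>> samples{ { 0.f, 0.f }, { 0.f, 1.f }, { 5.f, 5.f }, { 5.f, 6.f } };
	std::vector<float> labels{ 0.f, 0.f, 1.f, 1.f };

	ANN::KNN<float> knn;
	knn.set_k(3);
	knn.set_train_samples(samples, labels);

	float result{ -1.f };
	if (knn.predict({ 4.5f, 5.5f }, result) == 0)
		fprintf(stdout, "predicted label: %f\n", result); // should print 1.000000
	return 0;
}
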
knn.cpp:

#include "knn.hpp"
#include <limits>
#include <algorithm>
#include <functional>
#include "common.hpp"

namespace ANN {

template<typename T>
void KNN<T>::set_k(int k)
{
	this->k = k;
}

template<typename T>
int KNN<T>::set_train_samples(const std::vector<std::vector<T>>& samples, const std::vector<T>& labels)
{
	CHECK(samples.size() == labels.size());
	this->samples_number = samples.size();
	if (this->k > this->samples_number) this->k = this->samples_number;
	this->feature_length = samples[0].size();

	// copy the training samples into a flat, row-major buffer
	this->samples.reset(new T[this->feature_length * this->samples_number]);
	this->labels.reset(new T[this->samples_number]);

	T* p = this->samples.get();
	for (int i = 0; i < this->samples_number; ++i) {
		T* q = p + i * this->feature_length;
		for (int j = 0; j < this->feature_length; ++j) {
			q[j] = samples[i][j];
		}
		this->labels.get()[i] = labels[i];
	}

	return 0;
}

template<typename T>
int KNN<T>::predict(const std::vector<T>& sample, T& result) const
{
	if (sample.size() != (size_t)this->feature_length) {
		fprintf(stderr, "their feature lengths mismatch: %d, %d\n", (int)sample.size(), this->feature_length);
		return -1;
	}

	// info holds the k smallest (squared distance, label) pairs seen so far,
	// plus one extra slot used to insert the current sample before re-sorting
	typedef std::pair<T, T> value;
	std::vector<value> info;
	for (int i = 0; i < this->k + 1; ++i) {
		info.push_back(std::make_pair(std::numeric_limits<T>::max(), (T)-1.));
	}

	for (int i = 0; i < this->samples_number; ++i) {
		// squared Euclidean distance between the query and training sample i
		T s{ 0. };
		const T* p = this->samples.get() + i * this->feature_length;
		for (int j = 0; j < this->feature_length; ++j) {
			s += (p[j] - sample[j]) * (p[j] - sample[j]);
		}

		info[this->k] = std::make_pair(s, this->labels.get()[i]);
		std::stable_sort(info.begin(), info.end(), [](const std::pair<T, T>& p1, const std::pair<T, T>& p2) {
			return p1.first < p2.first; });
	}

	// collect the distinct labels among the k nearest neighbors
	std::vector<T> vec(this->k);
	for (int i = 0; i < this->k; ++i) {
		vec[i] = info[i].second;
	}
	std::sort(vec.begin(), vec.end(), std::greater<T>());
	vec.erase(std::unique(vec.begin(), vec.end()), vec.end());

	// count the votes for each distinct label
	std::vector<std::pair<T, int>> ret;
	for (int i = 0; i < (int)vec.size(); ++i) {
		ret.push_back(std::make_pair(vec[i], 0));
	}
	for (int i = 0; i < this->k; ++i) {
		for (int j = 0; j < (int)ret.size(); ++j) {
			if (info[i].second == ret[j].first) {
				++ret[j].second;
				break;
			}
		}
	}

	// the label with the most votes wins
	int max = -1, index = -1;
	for (int i = 0; i < (int)ret.size(); ++i) {
		if (ret[i].second > max) {
			max = ret[i].second;
			index = i;
		}
	}

	result = ret[index].first;
	return 0;
}

template class KNN<float>;
template class KNN<double>;

} // namespace ANN
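
In predict, the k nearest neighbors are selected by keeping a buffer of k+1 (distance, label) pairs and re-sorting it after every training sample. A rough alternative sketch (my own, not part of the original code) computes all distances first and then uses std::partial_sort to extract the k smallest:

// sketch: select the k nearest neighbors with std::partial_sort
// (assumes the same flat, row-major sample/label layout as KNN<T> above)
#include <algorithm>
#include <utility>
#include <vector>

template<typename T>
std::vector<std::pair<T, T>> k_nearest(const T* samples, const T* labels,
	int samples_number, int feature_length, const std::vector<T>& query, int k)
{
	std::vector<std::pair<T, T>> dist(samples_number); // (squared distance, label)
	for (int i = 0; i < samples_number; ++i) {
		T s{ 0. };
		const T* p = samples + i * feature_length;
		for (int j = 0; j < feature_length; ++j)
			s += (p[j] - query[j]) * (p[j] - query[j]);
		dist[i] = std::make_pair(s, labels[i]);
	}

	// move the k smallest distances to the front, in sorted order
	std::partial_sort(dist.begin(), dist.begin() + k, dist.end(),
		[](const std::pair<T, T>& a, const std::pair<T, T>& b) { return a.first < b.first; });
	dist.resize(k);
	return dist;
}
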
The test code is as follows:

#include "funset.hpp"
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <cstring>
#include "perceptron.hpp"
#include "BP.hpp"
#include "CNN.hpp"
#include "linear_regression.hpp"
#include "naive_bayes_classifier.hpp"
#include "logistic_regression.hpp"
#include "common.hpp"
#include "knn.hpp"
#include <opencv2/opencv.hpp>

// =========================== KNN(K-Nearest Neighbor) ======================
int test_knn_classifier_predict()
{
	const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };
	const int K{ 3 };
	cv::Mat tmp = cv::imread(image_path + "0_1.jpg", 0);
	const int train_samples_number{ 40 }, predict_samples_number{ 40 };
	const int every_class_number{ 10 };

	cv::Mat train_data(train_samples_number, tmp.rows * tmp.cols, CV_32FC1);
	cv::Mat train_labels(train_samples_number, 1, CV_32FC1);
	float* p = (float*)train_labels.data;
	for (int i = 0; i < 4; ++i) {
		std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });
	}

	// train data: images 1..10 of each class
	for (int i = 0; i < 4; ++i) {
		static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };
		static const std::string suffix{ ".jpg" };

		for (int j = 1; j <= every_class_number; ++j) {
			std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;
			cv::Mat image = cv::imread(image_name, 0);
			CHECK(!image.empty() && image.isContinuous());
			image.convertTo(image, CV_32FC1);

			image = image.reshape(0, 1);
			tmp = train_data.rowRange(i * every_class_number + j - 1, i * every_class_number + j);
			image.copyTo(tmp);
		}
	}

	ANN::KNN<float> knn;
	knn.set_k(K);

	std::vector<std::vector<float>> samples(train_samples_number);
	std::vector<float> labels(train_samples_number);
	const int feature_length{ tmp.rows * tmp.cols };

	for (int i = 0; i < train_samples_number; ++i) {
		samples[i].resize(feature_length);
		const float* p1 = train_data.ptr<float>(i);
		float* p2 = samples[i].data();
		memcpy(p2, p1, feature_length * sizeof(float));
	}

	const float* p1 = (const float*)train_labels.data;
	float* p2 = labels.data();
	memcpy(p2, p1, train_samples_number * sizeof(float));

	knn.set_train_samples(samples, labels);

	// predict data: images 11..20 of each class
	cv::Mat predict_data(predict_samples_number, tmp.rows * tmp.cols, CV_32FC1);
	for (int i = 0; i < 4; ++i) {
		static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };
		static const std::string suffix{ ".jpg" };

		for (int j = 11; j <= every_class_number + 10; ++j) {
			std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;
			cv::Mat image = cv::imread(image_name, 0);
			CHECK(!image.empty() && image.isContinuous());
			image.convertTo(image, CV_32FC1);

			image = image.reshape(0, 1);
			tmp = predict_data.rowRange(i * every_class_number + j - 10 - 1, i * every_class_number + j - 10);
			image.copyTo(tmp);
		}
	}

	cv::Mat predict_labels(predict_samples_number, 1, CV_32FC1);
	p = (float*)predict_labels.data;
	for (int i = 0; i < 4; ++i) {
		std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });
	}

	std::vector<float> sample(feature_length);
	int count{ 0 };
	for (int i = 0; i < predict_samples_number; ++i) {
		float value1 = ((float*)predict_labels.data)[i];
		float value2;
		memcpy(sample.data(), predict_data.ptr<float>(i), feature_length * sizeof(float));

		CHECK(knn.predict(sample, value2) == 0);
		fprintf(stdout, "expected value: %f, actual value: %f\n", value1, value2);

		if (int(value1) == int(value2)) ++count;
	}
	fprintf(stdout, "when K = %d, accuracy: %f\n", K, count * 1.f / predict_samples_number);

	return 0;
}
The execution results are as follows; they are similar to the results obtained with KNN in OpenCV.
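
For reference, a rough sketch of the OpenCV comparison (my own illustration, assuming the train_data, train_labels, predict_data, and predict_labels matrices built in the test code above) might look like this:

// sketch: run the same train/predict data through OpenCV's KNN for comparison
// (train_data, train_labels, predict_data, predict_labels come from the test above)
#include <opencv2/opencv.hpp>

static float opencv_knn_accuracy(const cv::Mat& train_data, const cv::Mat& train_labels,
	const cv::Mat& predict_data, const cv::Mat& predict_labels, int K)
{
	cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create();
	knn->setDefaultK(K);
	knn->setIsClassifier(true);
	knn->train(train_data, cv::ml::ROW_SAMPLE, train_labels);

	cv::Mat results;
	knn->findNearest(predict_data, K, results);

	int count = 0;
	for (int i = 0; i < predict_data.rows; ++i) {
		if ((int)results.at<float>(i, 0) == (int)predict_labels.at<float>(i, 0)) ++count;
	}
	return count * 1.f / predict_data.rows;
}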



GitHub:  https://github.com/fengbingchun/NN_Test 
