当前位置:   article > 正文

C++实现最基本的KNN手写数字识别

KNN算法概述:KNN算法的思想即,一个样本属于其在特征空间里的最近邻样本中数目最多的分类。

该算法有几个关键要素:

一.K的取值,即选取K个最近邻样本中所属数目最多的分类,显然算法的效果很大程度上取决于K的大小。

二.样本对象之间距离的定义,一般使用欧氏距离或曼哈顿距离:

KNN算法流程

1.计算预测样本与每个训练集样本的距离

2.将样本按距离从小到大排序

3.从前往后取K个样本,统计各个标签对应样本数目

4.找出对应数目最多的标签,即为测试样本所属分类

 

这里我们选择欧式距离来计算样本间距离,测试不同K的取值下算法的预测精度。对MNIST数据集进行一些处理,即将像素点不为0的位置全部定义为1。因为朴素的KNN算法计算复杂度非常高,仅选用部分训练数据和测试数据进行实验。KNN算法并不会训练出一个用于分类或回归的模型,所以每次进行预测时,我们都需要打包所有的训练数据。

  1. #include <bits/stdc++.h>
  2. using namespace std ;
  3. vector<double>labels;
  4. vector<vector<double> >images;//训练集
  5. vector<double>labels1;
  6. vector<vector<double> >images1;//测试集
  7. const int train_number=10000;
  8. const int test_number=500;
  9. int a[20];
  10. int KNN(int i,int k);
  11. struct node
  12. {
  13. int labels;
  14. int dis;
  15. }q[train_number+100];
  16. bool cmp(node a,node b)
  17. {
  18. return a.dis<b.dis;
  19. }
  20. /**********************************/
  21. int ReverseInt(int i)
  22. {
  23. unsigned char ch1, ch2, ch3, ch4;
  24. ch1 = i & 255;
  25. ch2 = (i >> 8) & 255;
  26. ch3 = (i >> 16) & 255;
  27. ch4 = (i >> 24) & 255;
  28. return((int)ch1 << 24) + ((int)ch2 << 16) + ((int)ch3 << 8) + ch4;
  29. }
  30. void read_Mnist_Label(string filename, vector<double>&labels)
  31. {
  32. ifstream file;
  33. file.open("train-labels.idx1-ubyte", ios::binary);
  34. if (file.is_open())
  35. {
  36. int magic_number = 0;
  37. int number_of_images = 0;
  38. file.read((char*)&magic_number, sizeof(magic_number));
  39. file.read((char*)&number_of_images, sizeof(number_of_images));
  40. magic_number = ReverseInt(magic_number);
  41. number_of_images = ReverseInt(number_of_images);
  42. cout << "magic number = " << magic_number << endl;
  43. cout << "number of images = " << number_of_images << endl;
  44. for (int i = 0; i < number_of_images; i++)
  45. {
  46. unsigned char label = 0;
  47. file.read((char*)&label, sizeof(label));
  48. labels.push_back((double)label);
  49. }
  50. }
  51. }
  52. void read_Mnist_Images(string filename, vector<vector<double> >&images)
  53. {
  54. ifstream file("train-images.idx3-ubyte", ios::binary);
  55. if (file.is_open())
  56. {
  57. int magic_number = 0;
  58. int number_of_images = 0;
  59. int n_rows = 0;
  60. int n_cols = 0;
  61. unsigned char label;
  62. file.read((char*)&magic_number, sizeof(magic_number));
  63. file.read((char*)&number_of_images, sizeof(number_of_images));
  64. file.read((char*)&n_rows, sizeof(n_rows));
  65. file.read((char*)&n_cols, sizeof(n_cols));
  66. magic_number = ReverseInt(magic_number);
  67. number_of_images = ReverseInt(number_of_images);
  68. n_rows = ReverseInt(n_rows);
  69. n_cols = ReverseInt(n_cols);
  70. cout << "magic number = " << magic_number << endl;
  71. cout << "number of images = " << number_of_images << endl;
  72. cout << "rows = " << n_rows << endl;
  73. cout << "cols = " << n_cols << endl;
  74. for (int i = 0; i < number_of_images; i++)
  75. {
  76. vector<double>tp;
  77. for (int r = 0; r < n_rows; r++)
  78. {
  79. for (int c = 0; c < n_cols; c++)
  80. {
  81. unsigned char image = 0;
  82. file.read((char*)&image, sizeof(image));
  83. tp.push_back(image);
  84. }
  85. }
  86. images.push_back(tp);
  87. }
  88. }
  89. }
  90. void read_Mnist_Label1(string filename, vector<double>&labels)
  91. {
  92. ifstream file;
  93. file.open("t10k-labels.idx1-ubyte", ios::binary);
  94. if (file.is_open())
  95. {
  96. int magic_number = 0;
  97. int number_of_images = 0;
  98. file.read((char*)&magic_number, sizeof(magic_number));
  99. file.read((char*)&number_of_images, sizeof(number_of_images));
  100. magic_number = ReverseInt(magic_number);
  101. number_of_images = ReverseInt(number_of_images);
  102. for (int i = 0; i < number_of_images; i++)
  103. {
  104. unsigned char label = 0;
  105. file.read((char*)&label, sizeof(label));
  106. labels.push_back((double)label);
  107. }
  108. }
  109. }
  110. void read_Mnist_Images1(string filename, vector<vector<double> >&images)
  111. {
  112. ifstream file("t10k-images.idx3-ubyte", ios::binary);
  113. if (file.is_open())
  114. {
  115. int magic_number = 0;
  116. int number_of_images = 0;
  117. int n_rows = 0;
  118. int n_cols = 0;
  119. unsigned char label;
  120. file.read((char*)&magic_number, sizeof(magic_number));
  121. file.read((char*)&number_of_images, sizeof(number_of_images));
  122. file.read((char*)&n_rows, sizeof(n_rows));
  123. file.read((char*)&n_cols, sizeof(n_cols));
  124. magic_number = ReverseInt(magic_number);
  125. number_of_images = ReverseInt(number_of_images);
  126. n_rows = ReverseInt(n_rows);
  127. n_cols = ReverseInt(n_cols);
  128. for (int i = 0; i < number_of_images; i++)
  129. {
  130. vector<double>tp;
  131. for (int r = 0; r < n_rows; r++)
  132. {
  133. for (int c = 0; c < n_cols; c++)
  134. {
  135. unsigned char image = 0;
  136. file.read((char*)&image, sizeof(image));
  137. tp.push_back(image);
  138. }
  139. }
  140. images.push_back(tp);
  141. }
  142. }
  143. }
  144. /**************以上为MNIST数据集读取部分,下面开始KNN算法**************/
  145. void test(int k)
  146. {
  147. int sum=0;
  148. for(int i=0;i<test_number;i++)
  149. {
  150. int predict=KNN(i,k);
  151. //printf("pre:%d label:%d\n",predict,(int)labels1[i]);
  152. if(predict==(int)labels1[i]) sum++;
  153. }
  154. printf("k=%d precision: %.5f\n",k,1.0*sum/test_number);
  155. }
  156. int KNN(int number,int k)//预测函数
  157. {
  158. memset(q,0,sizeof(q));
  159. memset(a,0,sizeof(a));
  160. int dis=0;
  161. for(int i=0;i<train_number;i++)
  162. {
  163. for(int j=0;j<784;j++)
  164. dis+=(images[i][j]-images1[number][j])*(images[i][j]-images1[number][j]);
  165. dis=sqrt(dis);//获得欧式距离
  166. q[i].dis=(int)dis;
  167. q[i].labels=(int)labels[i];
  168. }
  169. sort(q,q+train_number,cmp);
  170. for(int i=0;i<k;i++)
  171. {
  172. a[q[i].labels]++;
  173. }
  174. int ans=-1,minn=-1;
  175. for(int i=0;i<10;i++)
  176. {
  177. if(a[i]>minn)
  178. {
  179. minn=a[i];
  180. ans=i;
  181. }
  182. }
  183. return ans;
  184. }
  185. int main()
  186. {
  187. read_Mnist_Label("t10k-labels.idx1-ubyte", labels);
  188. read_Mnist_Images("t10k-images.idx3-ubyte", images);
  189. read_Mnist_Label1("t10k-labels.idx1-ubyte", labels1);
  190. read_Mnist_Images1("t10k-images.idx3-ubyte", images1);//读取mnist数据集
  191. for (int i = 0; i < images1.size(); i++)
  192. {
  193. for (int j = 0; j < images1[0].size(); j++)
  194. {
  195. images1[i][j]=(images1[i][j]>0)?1:0;
  196. }
  197. }
  198. for (int i = 0; i < images.size(); i++)
  199. {
  200. for (int j = 0; j < images[0].size(); j++)
  201. {
  202. images[i][j]=(images[i][j]>0)?1:0;
  203. }
  204. }
  205. test(1);
  206. test(2);
  207. test(3);
  208. test(4);
  209. return 0;
  210. }

部分测试截图如下:

 

KNN算法存在几个主要的缺点:

1.当样本数据分布不均衡的时候,很有可能出现当输入一个未知样本时,该样本的K个邻居中大数量类的样本占多数,但是这类样本并不接近目标样本(就像下图的Y点),会被误判为蓝色分类。针对此类情况,可以采用距离加权的优化算法,即距离越近的样本对应分类获得更大的权值。

2.数据量庞大时,计算复杂度过高。执行一次KNN算法,需要遍历一遍所有数据集。可以采用数据结构优化的办法,常见的有K-D树和球树等优化。

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/357209
推荐阅读
相关标签
  

闽ICP备14008679号