当前位置:   article > 正文

knn之构造kd树和最近邻求取c++实现_数据结构kd树最近邻c++

数据结构kd树最近邻c++
这份代码的测试样例如下(第一行为点数 n,接着 n 行为各点坐标,最后一行为待查询的点):
  1. 6
  2. 7 2
  3. 2 3
  4. 5 4
  5. 4 7
  6. 9 6
  7. 8 1
  8. 8 2

这里通过中位数来选取根节点。这种方法在一定程度上是有问题的:根节点的选取方法不同,整棵树的结构也会不同;而且由于数据的关系,构造出来的并不是完全二叉树,有的节点只有左孩子。因此对于某些特殊的样例程序会出错。比如查询点 (10,10):从根 (7,2) 向右走到 (9,6) 后,按第二维比较应继续向右,但 (9,6) 没有右孩子,于是找不到包含它的子节点(区域),程序解引用空指针而出错。

  1. #include<iostream>
  2. #include<algorithm>
  3. #include<cstring>
  4. #include<vector>
  5. #include<cmath>
  6. #include<queue>
  7. using namespace std;
  8. struct node{
  9. pair<int,int>x;
  10. int dim;
  11. node*left;
  12. node*right;
  13. node*father;
  14. node(pair<int,int>p=make_pair(0,0),int dim=0,node*left=0,node*right=0,node*father=0)
  15. :dim(dim),left(left),right(right),father(father)
  16. {
  17. x=p;
  18. }
  19. };
  20. bool cmp1(node*a,node* b)
  21. {
  22. return a->x.first<b->x.first;
  23. }
  24. bool cmp2(node*a,node*b)
  25. {
  26. return a->x.second<b->x.second;
  27. }
  28. vector<node*>vec;
  29. node* buildtree(vector<node*>temp,int cnt)
  30. {
  31. if(temp.size()==0)
  32. return 0;
  33. else if(temp.size()==1)
  34. return temp[0];
  35. else{
  36. if(cnt==1)
  37. sort(temp.begin(),temp.end(),cmp1);
  38. else
  39. sort(temp.begin(),temp.end(),cmp2);
  40. int mid=temp.size()/2;
  41. vector<node*>p;
  42. for(int i=0;i<mid;i++)
  43. {
  44. p.push_back(temp[i]);
  45. }
  46. vector<node*>q;
  47. for(int i=mid+1;i<temp.size();i++)
  48. {
  49. q.push_back(temp[i]);
  50. }
  51. node*left=buildtree(p,(cnt+1)%2);
  52. node*right=buildtree(q,(cnt+1)%2);
  53. node*fat=new node(make_pair(temp[mid]->x.first,temp[mid]->x.second),cnt,left,right,0);
  54. if(left!=0)
  55. left->father=fat;
  56. if(right!=0)
  57. right->father=fat;
  58. //cout<<fat->x.first<<" "<<fat->x.second<<endl;
  59. return fat;
  60. }
  61. }
  62. void traverse(node*root)
  63. {
  64. if(root==0)
  65. {
  66. }
  67. else
  68. {
  69. cout<<root->x.first<<" "<<root->x.second<<endl;
  70. traverse(root->left);
  71. traverse(root->right);
  72. }
  73. }
  74. node*find_first_belong(node*key,node*root)
  75. {
  76. node*temp=root;
  77. while(true) //遍历找到其归属的叶节点
  78. {
  79. if(temp->left==0&&temp->right==0)
  80. {
  81. break;
  82. }
  83. else
  84. {
  85. int dim=temp->dim;//选择维度比较
  86. if(dim==1)//选择x1比较
  87. {
  88. if(key->x.first<=temp->x.first)
  89. temp=temp->left;
  90. else
  91. temp=temp->right;
  92. }
  93. else //选择x2比较
  94. {
  95. if(key->x.second<=temp->x.second)
  96. temp=temp->left;
  97. else
  98. temp=temp->right;
  99. }
  100. }
  101. }
  102. return temp;
  103. }
  104. double distance(node*a,node*b)
  105. {
  106. double ax1=a->x.first;
  107. double ax2=a->x.second;
  108. double bx1=b->x.first;
  109. double bx2=b->x.second;
  110. return sqrt(pow(ax1-bx1,2)+pow(ax2-bx2,2));
  111. }
  112. node*query(node*key,node*root,double mindis)
  113. //这里就是最不明白的一点,当另一区域跟圆相交,书上说是递归进行最近邻搜索,
  114. //没搞懂到底怎么递归搜索,所以这里就直接用了很简单的遍历比较,希望以后能搞懂
  115. {
  116. node*rec=root;
  117. double mind=mindis;
  118. queue<node*>q;
  119. q.push(root);
  120. while(!q.empty())
  121. {
  122. node*temp=q.front();
  123. double dis=distance(key,temp);
  124. if(dis<mind)
  125. {
  126. mind=dis;
  127. rec=temp;
  128. }
  129. q.pop();
  130. if(temp->left!=0)
  131. q.push(temp->left);
  132. if(temp->right)
  133. q.push(temp->right);
  134. }
  135. return rec;
  136. }
  137. node*find_nearest(node*key,node*belong)
  138. {
  139. node *nearest=belong;
  140. double mindis=distance(key,belong);
  141. //cout<<mindis<<" mindis"<<endl;
  142. while(true)
  143. {
  144. //cout<<belong->x.first<<" "<<belong->x.second<<endl;
  145. node*fat=belong->father;
  146. if(fat==0)
  147. break;
  148. int fdim=fat->dim;
  149. if(distance(fat,key)<mindis)
  150. {
  151. mindis=distance(fat,key);
  152. nearest=fat;
  153. }
  154. if(fdim==1) //判断圆是否与x1=fat->x.first相交
  155. {
  156. int fx1=fat->x.first;
  157. int kx1=key->x.first;
  158. if(abs(fx1-kx1)<mindis)
  159. {
  160. node*res=query(key,fat->right,mindis);
  161. if(res!=0&&distance(res,key)<mindis)
  162. {
  163. nearest=res;
  164. mindis=distance(res,key);
  165. }
  166. }
  167. }
  168. else //反之
  169. {
  170. int fx2=fat->x.second;
  171. int kx2=key->x.second;
  172. if(abs(fx2-kx2)<mindis)
  173. {
  174. node*res=query(key,fat->right,mindis);
  175. if(res!=0&&distance(res,key)<mindis)
  176. {
  177. nearest=res;
  178. mindis=distance(res,key);
  179. }
  180. }
  181. }
  182. belong=fat;
  183. if(belong==0)
  184. break;
  185. }
  186. return nearest;
  187. }
  188. node*search(node*key,node*root)
  189. {
  190. node* belong=find_first_belong(key,root);
  191. //cout<<belong->x.first<<" "<<belong->x.second<<endl;
  192. node* nearest=find_nearest(key,belong);
  193. }
  194. int main()
  195. {
  196. int n;
  197. cin>>n;
  198. for(int i=0;i<n;i++)
  199. {
  200. int x,y;
  201. cin>>x>>y;
  202. node* temp=new node(make_pair(x,y));
  203. vec.push_back(temp);
  204. }
  205. node*root=buildtree(vec,1);
  206. //traverse(root);
  207. int x,y;
  208. cin>>x>>y;
  209. node *key=new node(make_pair(x,y));
  210. node*near=search(key,root);
  211. cout<<near->x.first<<" "<<near->x.second<<endl;
  212. }
以上代码经过测试,除了 (10,10) 这类特殊数据会出错外,其他情况基本正确,只是代码写得很乱……

这里还有一个很大的问题:我不知道一旦判定圆与其他区域相交之后该如何进行递归搜索,所以这里直接用了遍历……


总算搞懂了什么是递归搜索:

下面的是第二个版本:

  1. #include<iostream>
  2. #include<algorithm>
  3. #include<cstring>
  4. #include<vector>
  5. #include<cmath>
  6. #include<queue>
  7. using namespace std;
  8. struct node{
  9. pair<int,int>x;
  10. int dim;
  11. node*left;
  12. node*right;
  13. node*father;
  14. node(pair<int,int>p=make_pair(0,0),int dim=0,node*left=0,node*right=0,node*father=0)
  15. :dim(dim),left(left),right(right),father(father)
  16. {
  17. x=p;
  18. }
  19. };
  20. bool cmp1(node*a,node* b)
  21. {
  22. return a->x.first<b->x.first;
  23. }
  24. bool cmp2(node*a,node*b)
  25. {
  26. return a->x.second<b->x.second;
  27. }
  28. vector<node*>vec;
  29. node* buildtree(vector<node*>temp,int cnt)
  30. {
  31. if(temp.size()==0)
  32. return 0;
  33. else if(temp.size()==1)
  34. return temp[0];
  35. else{
  36. if(cnt==1)
  37. sort(temp.begin(),temp.end(),cmp1);
  38. else
  39. sort(temp.begin(),temp.end(),cmp2);
  40. int mid=temp.size()/2;
  41. vector<node*>p;
  42. for(int i=0;i<mid;i++)
  43. {
  44. p.push_back(temp[i]);
  45. }
  46. vector<node*>q;
  47. for(int i=mid+1;i<temp.size();i++)
  48. {
  49. q.push_back(temp[i]);
  50. }
  51. node*left=buildtree(p,(cnt+1)%2);
  52. node*right=buildtree(q,(cnt+1)%2);
  53. node*fat=new node(make_pair(temp[mid]->x.first,temp[mid]->x.second),cnt,left,right,0);
  54. if(left!=0)
  55. left->father=fat;
  56. if(right!=0)
  57. right->father=fat;
  58. //cout<<fat->x.first<<" "<<fat->x.second<<endl;
  59. return fat;
  60. }
  61. }
  62. void traverse(node*root)
  63. {
  64. if(root==0)
  65. {
  66. }
  67. else
  68. {
  69. cout<<root->x.first<<" "<<root->x.second<<endl;
  70. traverse(root->left);
  71. traverse(root->right);
  72. }
  73. }
  74. node*find_first_belong(node*key,node*root)
  75. {
  76. node*temp=root;
  77. while(true) //遍历找到其归属的叶节点
  78. {
  79. if(temp->left==0&&temp->right==0)
  80. {
  81. break;
  82. }
  83. else
  84. {
  85. int dim=temp->dim;//选择维度比较
  86. if(dim==1)//选择x1比较
  87. {
  88. if(key->x.first<=temp->x.first)
  89. temp=temp->left;
  90. else
  91. temp=temp->right;
  92. }
  93. else //选择x2比较
  94. {
  95. if(key->x.second<=temp->x.second)
  96. temp=temp->left;
  97. else
  98. temp=temp->right;
  99. }
  100. }
  101. }
  102. return temp;
  103. }
  104. double distance(node*a,node*b)
  105. {
  106. double ax1=a->x.first;
  107. double ax2=a->x.second;
  108. double bx1=b->x.first;
  109. double bx2=b->x.second;
  110. return sqrt(pow(ax1-bx1,2)+pow(ax2-bx2,2));
  111. }
  112. node*query(node*key,node*root,double mindis)//没有用的函数
  113. {
  114. node*rec=root;
  115. double mind=mindis;
  116. queue<node*>q;
  117. q.push(root);
  118. while(!q.empty())
  119. {
  120. node*temp=q.front();
  121. double dis=distance(key,temp);
  122. if(dis<mind)
  123. {
  124. mind=dis;
  125. rec=temp;
  126. }
  127. q.pop();
  128. if(temp->left!=0)
  129. q.push(temp->left);
  130. if(temp->right)
  131. q.push(temp->right);
  132. }
  133. return rec;
  134. }
  135. node*find_nearest(node*key,node*belong,node*root)
  136. {
  137. node *nearest=belong;
  138. double mindis=distance(key,belong);
  139. //cout<<belong->x.first<<" belong "<<belong->x.second<<endl;
  140. //cout<<mindis<<" mindis"<<endl;
  141. while(true)
  142. {
  143. //cout<<belong->x.first<<" "<<belong->x.second<<endl;
  144. node*fat=belong->father;
  145. if(fat==0||fat==root->father)
  146. break;
  147. node*other=new node(); //相比第一个这里还更加对了,因为这里还考虑到了万一归属的叶节点不是左节点的情况
  148. if(fat->left==belong)
  149. {
  150. other=fat->right;
  151. }
  152. else
  153. other=fat->left;
  154. //cout<<fat->x.first<<" "<<" fat "<<fat->x.second<<endl;
  155. int fdim=fat->dim;
  156. if(distance(fat,key)<mindis)
  157. {
  158. mindis=distance(fat,key);
  159. nearest=fat;
  160. }
  161. if(fdim==1) //判断圆是否与x1=fat->x.first相交
  162. {
  163. int fx1=fat->x.first;
  164. int kx1=key->x.first;
  165. if(abs(fx1-kx1)<mindis)
  166. {
  167. node*tm=find_first_belong(key,other);
  168. node*res=find_nearest(key,tm,other); //传说中的递归搜索在这里,利用他之前的函数
  169. if(res!=0&&distance(res,key)<mindis)
  170. {
  171. nearest=res;
  172. mindis=distance(res,key);
  173. }
  174. }
  175. //cout<<fx1<<" xxxx "<<kx1<<" "<<mindis<<endl;
  176. }
  177. else //反之
  178. {
  179. int fx2=fat->x.second;
  180. int kx2=key->x.second;
  181. if(abs(fx2-kx2)<mindis)
  182. {
  183. node*tm=find_first_belong(key,other);
  184. //cout<<tm->x.first<<" **** "<<tm->x.second<<endl;
  185. //cout<<other->x.first<<" other "<<other->x.second<<endl;
  186. node*res=find_nearest(key,tm,other);
  187. if(res!=0&&distance(res,key)<mindis)
  188. {
  189. nearest=res;
  190. mindis=distance(res,key);
  191. //cout<<mindis<<" mindis"<<endl;
  192. }
  193. }
  194. }
  195. belong=fat;
  196. if(belong==0)
  197. break;
  198. }
  199. return nearest;
  200. }
  201. node*search(node*key,node*root)
  202. {
  203. node* belong=find_first_belong(key,root);
  204. //cout<<belong->x.first<<" "<<belong->x.second<<endl;
  205. node* nearest=find_nearest(key,belong,root);
  206. return nearest;
  207. }
  208. int main()
  209. {
  210. int n;
  211. cin>>n;
  212. for(int i=0;i<n;i++)
  213. {
  214. int x,y;
  215. cin>>x>>y;
  216. node* temp=new node(make_pair(x,y));
  217. vec.push_back(temp);
  218. }
  219. node*root=buildtree(vec,1);
  220. //traverse(root);
  221. int x,y;
  222. cin>>x>>y;
  223. node *key=new node(make_pair(x,y));
  224. node*near=search(key,root);
  225. cout<<"the nearest point is "<<near->x.first<<" "<<near->x.second<<endl;
  226. }
然而还是没有解决 (10,10) 的情况,明天再说!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/562852
推荐阅读
相关标签
  

闽ICP备14008679号