当前位置:   article > 正文

C++蒙特卡洛树算法实现五子棋AI_c++蒙特卡罗树实现

c++蒙特卡罗树实现

蒙特卡洛方法

随机的对当前局面进行后续状态模拟,根据模拟结果决定下一步行动

下面是蒙特卡洛树的实现原理:

 

                

下面是具体实现,详细讲解见注释。

一些优化:

1.中心邻域搜索:根据启发式信息,每一步棋子应当选择与对手上一步棋子周围的一步,于是可以只搜索对方上一步周围的后继状态

2.必胜状态:如果已经能一步走成必胜局面,直接走这一步,或者一个后继状态对应的对手所有后继状态都是失败的,则这步必胜。

3.判断棋盘状态:使用std::set(平衡树),也可以使用hash算法

Tips:调整select_num和sta_num可以让AI获得不同效果

  1. #include<iostream>
  2. #include<map>
  3. #include<time.h>
  4. #include<stdlib.h>
  5. #include<cmath>
  6. #include<vector>
  7. #include<utility>
  8. #include<windows.h>
  9. #include<conio.h>
  10. #define select_num 300 //选择次数
  11. #define sta_num 15 //每个状态模拟次数
  12. #define mp(x,y) make_pair(x,y)
  13. const int row=11,col=11;//棋盘的行和列
  14. const int search_range=2;//搜索范围
  15. const double inf=1e9+7;
  16. using namespace std;
  17. typedef struct chess{//定义棋盘
  18. int g[12][12];
  19. }chess;
  20. //0代表无棋子,1代表白棋,2代表黑棋
  21. bool operator < (chess x,chess y){//用于搜索棋盘状态
  22. for(int i=0;i<=row;i++)
  23. {
  24. for(int j=0;j<=col;j++)
  25. {
  26. if(x.g[i][j]<y.g[i][j])return 1;
  27. else if(x.g[i][j]>y.g[i][j])return 0;
  28. }
  29. }
  30. return 0;
  31. }
  32. bool operator == (chess x,chess y){//用于判断棋盘状态是否相等
  33. for(int i=0;i<=row;i++)
  34. {
  35. for(int j=0;j<col;j++)
  36. {
  37. if(x.g[i][j]!=y.g[i][j])return 0;
  38. }
  39. }
  40. return 1;
  41. }
  42. typedef struct property{//棋盘状态的一些性质
  43. double a,b;
  44. vector<chess> vec;
  45. }property;
  46. int rd=0;//回合数
  47. map<chess,property> mp;//用于记录棋盘状态的一些性质
  48. map<chess,chess> fa;//棋盘状态的父节点,用于反向传播
  49. pair<int,int> center;//搜索中心
  50. void init_chess(chess x);//初始化棋盘状态
  51. void init_window();//初始化窗口
  52. void game_window();//游戏窗口
  53. void print_board(chess x);//打印棋盘
  54. void set_pos(int x,int y);//调整光标位置
  55. void white_win();//白棋胜利
  56. void black_win();//黑棋胜利
  57. chess UCT_search(chess x,pair<int,int> center,int player);//蒙特卡洛树搜索
  58. pair<chess,int> tree_policy(chess x,pair<int,int> center,int player);//选择子节点
  59. chess expand(chess x,pair<int,int> center,int player);//扩展当前节点
  60. double UCB(chess x,int player);//计算节点的UCB值
  61. pair<int,int> cal_centre(chess x);//计算当前局面的搜索中心
  62. double default_policy(chess x,int player);//模拟结果
  63. void back_up(chess x,chess y,int value);//反向传播
  64. pair< int,pair<int,int> > check_four(chess x);//预言胜利
  65. pair< int,pair<int,int> > check_three(chess x);//预言胜利优化
  66. int check(chess x);//检查是否胜利,1为白胜,2为黑胜
  67. int check_win(chess x,int player);//分别检查黑白棋是否胜利
  68. int is_terminal(chess x);//检查是否为叶子节点
  69. int cnt_num(chess x,int x1,int x2,int y1,int y2);//查看节点的棋子数量
  70. //player:0为白,1为黑
  71. void init_chess(chess x){
  72. property p;
  73. p.a=p.b=0.0;
  74. mp[x]=p;
  75. }
  76. void set_pos(int x,int y){
  77. HANDLE winHandle;
  78. COORD pos;
  79. pos.X=x,pos.Y=y;
  80. winHandle = GetStdHandle(STD_OUTPUT_HANDLE);
  81. SetConsoleCursorPosition(winHandle,pos);
  82. }
  83. void init_window(){
  84. system("cls");
  85. for(int i=1;i<=7;i++)cout<<"\n";
  86. for(int i=1;i<=6;i++)cout<<"\t";
  87. cout<<"五子棋"<<"\n\n\n\n";
  88. for(int i=1;i<=5;i++)cout<<"\t ";
  89. cout<<"输入任意键开始游戏";
  90. char h;
  91. h=getch();
  92. game_window();
  93. }
  94. void game_window(){
  95. rd=0;
  96. mp.clear(),fa.clear();
  97. chess now;
  98. for(int i=0;i<=row;i++)
  99. {
  100. for(int j=0;j<=col;j++)
  101. {
  102. now.g[i][j]=0;
  103. }
  104. }
  105. init_chess(now);
  106. //pair<int,int> center;
  107. center.first=row/2,center.second=col/2;
  108. srand(time(0));
  109. print_board(now);
  110. while(!check(now))
  111. {
  112. now=UCT_search(now,center,1);
  113. if(check(now)==2){
  114. print_board(now);
  115. black_win();
  116. }
  117. while(1)
  118. {
  119. print_board(now);
  120. set_pos(65,12);
  121. cout<<"轮到您执棋,请输入坐标:";
  122. set_pos(65,14);
  123. int x,y;
  124. cin>>x>>y;
  125. x--,y--;
  126. if(x<0||x>row||y<0||y>col){
  127. set_pos(65,16);
  128. cout<<"您输入的数据超出棋盘范围"<<'\n';
  129. Sleep(1500);
  130. }else if(now.g[x][y]){
  131. set_pos(65,16);
  132. cout<<"该位置已有棋子";
  133. Sleep(1500);
  134. }else{
  135. now.g[x][y]=1;
  136. center.first=cal_centre(now).first,center.second=cal_centre(now).second;
  137. rd++;
  138. break;
  139. }
  140. }
  141. print_board(now);
  142. if(check(now)==1)white_win();
  143. }
  144. }
  145. void print_board(chess x){
  146. system("cls");
  147. for(int i=1;i<=2;i++)cout<<"\t";
  148. for(int i=0;i<=col;i++)
  149. {
  150. if(i+1<10)cout<<" "<<i+1<<" ";
  151. else cout<<" "<<i+1;
  152. }
  153. cout<<"\n";
  154. for(int i=0;i<=row;i++)
  155. {
  156. for(int j=1;j<=2;j++)cout<<"\t";
  157. for(int j=0;j<=col;j++)cout<<" __";
  158. cout<<"\n";
  159. for(int j=1;j<=1;j++)cout<<"\t";
  160. cout<<i+1;
  161. cout<<"\t";
  162. for(int j=0;j<=col;j++)
  163. {
  164. char p;
  165. if(x.g[i][j]==0)p=' ';
  166. else if(x.g[i][j]==1)p='O';
  167. else if(x.g[i][j]==2)p='X';
  168. cout<<"|"<<p<<" ";
  169. }
  170. cout<<"|";
  171. cout<<"\n";
  172. }
  173. for(int j=1;j<=2;j++)cout<<"\t";
  174. for(int i=0;i<=col;i++)cout<<" __";
  175. }
  176. void white_win(){
  177. Sleep(1500);
  178. system("cls");
  179. for(int i=1;i<=8;i++)cout<<'\n';
  180. for(int i=1;i<=4;i++)cout<<'\t';
  181. cout<<"白棋胜利";
  182. cout<<"\n\n";
  183. Sleep(1500);
  184. init_window();
  185. }
  186. void black_win(){
  187. Sleep(1500);
  188. system("cls");
  189. for(int i=1;i<=8;i++)cout<<'\n';
  190. for(int i=1;i<=4;i++)cout<<'\t';
  191. cout<<"黑棋胜利";
  192. cout<<"\n\n";
  193. Sleep(1500);
  194. init_window();
  195. }
  196. chess UCT_search(chess x,pair<int,int> center,int player){
  197. chess y=x;
  198. chess ans=y;
  199. if(check_four(y).first)
  200. {
  201. pair<int,int> u=check_four(y).second;
  202. ans.g[u.first][u.second]=player+1;
  203. return ans;
  204. }
  205. if(check_three(y).first)
  206. {
  207. pair<int,int> u=check_three(y).second;
  208. ans.g[u.first][u.second]=player+1;
  209. return ans;
  210. }
  211. if(mp.find(x)==mp.end())
  212. {
  213. init_chess(x);
  214. }
  215. int cnt=0;//选择次数
  216. while(cnt<=select_num)
  217. {
  218. //int judge=rand()%100;
  219. //if(judge<1)center.first=max(0,center.first-1);
  220. //if(judge<1)center.second=min(col,center.second+1);
  221. //if(judge>98)center.first=min(row,center.first+1);
  222. //if(judge>98)center.second=max(0,center.second-1);
  223. cnt++;
  224. pair<chess,int> select_point=tree_policy(x,center,1);
  225. for(int j=1;j<=sta_num;j++)//每个状态多次模拟,增强效果
  226. {
  227. double s=default_policy(select_point.first,select_point.second^1);
  228. back_up(select_point.first,y,s);
  229. }
  230. }
  231. vector<chess>::iterator it;
  232. double maxn=UCB(*(mp[y].vec.begin()),player);
  233. for(it=mp[y].vec.begin();it!=mp[y].vec.end();it++)
  234. {
  235. if(UCB(*it,player)>=maxn)
  236. {
  237. maxn=UCB(*it,player);
  238. ans=*it;
  239. }
  240. //print_board(*it);
  241. //cout<<mp[*it].a<<" "<<mp[*it].b<<'\n';
  242. //Sleep(1500);
  243. }
  244. vector<chess>().swap(mp[y].vec);//释放内存
  245. return ans;
  246. }
  247. pair<chess,int> tree_policy(chess x,pair<int,int> center,int player){
  248. while(!check(x)&&!is_terminal(x))
  249. {
  250. int x1=max(0,center.first-search_range);
  251. int x2=min(row,center.first+search_range);
  252. int y1=max(0,center.second-search_range);
  253. int y2=min(col,center.second+search_range);
  254. if(cnt_num(x,x1,x2,y1,y2)+mp[x].vec.size()<(x2-x1+1)*(y2-y1+1))
  255. {
  256. return mp(expand(x,center,player),player);
  257. }else{
  258. chess y=x;
  259. vector<chess>::iterator it;
  260. if(mp[y].vec.size()==0)break;
  261. double maxn=UCB(*(mp[y].vec.begin()),player);
  262. for(it=mp[y].vec.begin();it!=mp[y].vec.end();it++)
  263. {
  264. if(UCB(*it,player)>=maxn)
  265. {
  266. maxn=UCB(*it,player);
  267. x=*it;
  268. }
  269. }
  270. fa[x]=y;
  271. }
  272. player^=1;
  273. }
  274. return mp(x,player);
  275. }
  276. chess expand(chess x,pair<int,int> center,int player){
  277. chess y=x;
  278. int x1=max(0,center.first-search_range);
  279. int x2=min(row,center.first+search_range);
  280. int y1=max(0,center.second-search_range);
  281. int y2=min(col,center.second+search_range);
  282. int cnt=0;
  283. while(cnt<=1000)
  284. {
  285. int i=x1+rand()%(x2-x1+1);
  286. int j=y1+rand()%(y2-y1+1);
  287. int o=x.g[i][j];
  288. y.g[i][j]=player+1;
  289. if(!x.g[i][j]&&mp.find(y)==mp.end())
  290. {
  291. init_chess(y);
  292. mp[x].vec.push_back(y);
  293. fa[y]=x;
  294. return y;
  295. }
  296. y.g[i][j]=o;
  297. cnt++;
  298. }
  299. while(1)
  300. {
  301. int i=x1+rand()%(x2-x1+1);
  302. int j=y1+rand()%(y2-y1+1);
  303. int o=x.g[i][j];
  304. y.g[i][j]=player+1;
  305. if(!x.g[i][j]){
  306. if(mp.find(y)==mp.end()){
  307. init_chess(y);
  308. }
  309. return y;
  310. }
  311. y.g[i][j]=o;
  312. }
  313. }
  314. double UCB(chess x,int player){
  315. if(mp[x].b==0)return 0;
  316. double a1=mp[x].a,b1=mp[x].b;
  317. if(a1+b1==0)return -inf;
  318. if(player==1)return a1/b1+sqrt(log(a1+b1)/b1);
  319. else if(player==0)return -a1/b1+sqrt(log(a1+b1)/b1);
  320. }
  321. pair<int,int> cal_centre(chess x){//以每个棋子横纵坐标的均值为搜索中心
  322. int cnt=0,p1=0,p2=0;
  323. for(int i=0;i<=row;i++)
  324. {
  325. for(int j=0;j<=col;j++)
  326. {
  327. if(x.g[i][j]){
  328. cnt++;
  329. p1+=i;
  330. p2+=j;
  331. }
  332. }
  333. }
  334. p1=max(0,p1/cnt);
  335. p2=max(0,p2/cnt);
  336. return mp(p1,p2);
  337. }
  338. double default_policy(chess x,int player){
  339. while(1)
  340. {
  341. if(check(x)||is_terminal(x))break;
  342. pair<int,int> h=cal_centre(x);
  343. int o=rand()%100;
  344. int i,j;
  345. if(o<75){
  346. i=min(max(0,h.first-search_range+rand()%(search_range*2+1)),row);
  347. j=min(max(0,h.second-search_range+rand()%(search_range*2+1)),col);
  348. }else{
  349. i=rand()%(row+1);
  350. j=rand()%(col+1);
  351. }
  352. if(!x.g[i][j])
  353. {
  354. x.g[i][j]=player+1;
  355. player^=1;
  356. }
  357. }
  358. if(check(x)==1)return -1;
  359. else if(check(x)==2)return 1;
  360. else return 0;
  361. }
  362. void back_up(chess x,chess y,int value){
  363. mp[x].a+=value;
  364. mp[x].b++;
  365. while(!(x==y))
  366. {
  367. if(fa.find(x)==fa.end())break;
  368. x=fa[x];
  369. mp[x].a+=value;
  370. mp[x].b++;
  371. }
  372. }
  373. pair< int,pair<int,int> > check_four(chess x){
  374. chess y=x;
  375. for(int i=0;i<=row;i++)
  376. {
  377. for(int j=0;j<=col;j++)
  378. {
  379. if(!x.g[i][j])
  380. {
  381. x.g[i][j]=2;
  382. if(check(x)==2)return mp(2,mp(i,j));
  383. x.g[i][j]=0;
  384. }
  385. }
  386. }
  387. for(int i=0;i<=row;i++)
  388. {
  389. for(int j=0;j<=col;j++)
  390. {
  391. if(!y.g[i][j])
  392. {
  393. y.g[i][j]=1;
  394. if(check(y)==1)return mp(1,mp(i,j));
  395. y.g[i][j]=0;
  396. }
  397. }
  398. }
  399. return mp(0,mp(0,0));
  400. }
  401. pair< int,pair<int,int> > check_three(chess x){
  402. chess y1=x,y2=x;
  403. for(int i=0;i<=row;i++)
  404. {
  405. for(int j=0;j<=col;j++)
  406. {
  407. if(!x.g[i][j])
  408. {
  409. x.g[i][j]=2;
  410. int flag=1;
  411. for(int k1=0;k1<=row;k1++)
  412. {
  413. for(int k2=0;k2<=col;k2++)
  414. {
  415. if(!x.g[k1][k2]){
  416. x.g[k1][k2]=1;
  417. if(check_four(x).first!=2){
  418. flag=0;
  419. x.g[k1][k2]=0;
  420. break;
  421. }else x.g[k1][k2]=0;
  422. }
  423. }
  424. if(!flag)break;
  425. }
  426. if(flag)return mp(2,mp(i,j));
  427. x.g[i][j]=0;
  428. }
  429. }
  430. }
  431. vector< pair<int,int> > vec;
  432. for(int i=0;i<=row;i++)
  433. {
  434. for(int j=0;j<=col;j++)
  435. {
  436. if(!y1.g[i][j])
  437. {
  438. y1.g[i][j]=1;
  439. int flag=1;
  440. for(int k1=0;k1<=row;k1++)
  441. {
  442. for(int k2=0;k2<=col;k2++)
  443. {
  444. if(!y1.g[k1][k2]){
  445. y1.g[k1][k2]=2;
  446. if(check_four(y1).first!=1){
  447. flag=0;
  448. y1.g[k1][k2]=0;
  449. break;
  450. }else y1.g[k1][k2]=0;
  451. }
  452. }
  453. if(!flag)break;
  454. }
  455. if(flag)return mp(1,mp(i,j));
  456. //if(flag)s.push_back(mp(i,j));
  457. y1.g[i][j]=0;
  458. }
  459. }
  460. }
  461. vector< pair<int,int> >::iterator it;
  462. pair<int,int> ret;
  463. int minn=1e9+7;
  464. for(it=vec.begin();it!=vec.end();it++)
  465. {
  466. pair<int,int> k=*it;
  467. if((k.first-cal_centre(y2).first)+(k.second-cal_centre(y2).second)<minn)
  468. {
  469. minn=(k.first-cal_centre(y2).first)+(k.second-cal_centre(y2).second);
  470. ret=k;
  471. }
  472. }
  473. if(vec.size())return mp(1,ret);
  474. return mp(0,mp(0,0));
  475. }
  476. int check(chess x){
  477. if(check_win(x,1))return 1;
  478. else if(check_win(x,2))return 2;
  479. else return 0;
  480. }
  481. int check_win(chess x,int gamer){
  482. int cnt=0;
  483. for(int i=0;i<=row;i++)
  484. {
  485. cnt=0;
  486. for(int j=0;j<=col;j++)
  487. {
  488. if(x.g[i][j]==gamer)cnt++;
  489. else cnt=0;
  490. if(cnt>=5)return 1;
  491. }
  492. }
  493. cnt=0;
  494. for(int i=0;i<=col;i++)
  495. {
  496. cnt=0;
  497. for(int j=0;j<=row;j++)
  498. {
  499. if(x.g[j][i]==gamer)cnt++;
  500. else cnt=0;
  501. if(cnt>=5)return 1;
  502. }
  503. }
  504. cnt=0;
  505. for(int i=0;i<=row;i++)
  506. {
  507. cnt=0;
  508. int l=i,r=0;
  509. while(l<=col&&r<=col)
  510. {
  511. if(x.g[r][l]==gamer)cnt++;
  512. else cnt=0;
  513. if(cnt>=5)return 1;
  514. l++,r++;
  515. }
  516. }
  517. cnt=0;
  518. for(int i=0;i<=row;i++)
  519. {
  520. cnt=0;
  521. int l=i,r=0;
  522. while(l<=col&&r<=col)
  523. {
  524. if(x.g[l][r]==gamer)cnt++;
  525. else cnt=0;
  526. if(cnt>=5)return 1;
  527. l++,r++;
  528. }
  529. }
  530. cnt=0;
  531. for(int i=row;i>=0;i--)
  532. {
  533. cnt=0;
  534. int l=i,r=0;
  535. while(l>=0&&r<=col)
  536. {
  537. if(x.g[r][l]==gamer)cnt++;
  538. else cnt=0;
  539. if(cnt>=5)return 1;
  540. l--,r++;
  541. }
  542. }
  543. cnt=0;
  544. for(int i=0;i<=row;i++)
  545. {
  546. cnt=0;
  547. int l=i,r=col;
  548. while(l<=col&&r>=0)
  549. {
  550. if(x.g[l][r]==gamer)cnt++;
  551. else cnt=0;
  552. if(cnt>=5)return 1;
  553. l++,r--;
  554. }
  555. }
  556. return 0;
  557. }
  558. int is_terminal(chess x){
  559. for(int i=0;i<=row;i++)
  560. {
  561. for(int j=0;j<=col;j++)
  562. {
  563. if(!x.g[i][j])return 0;
  564. }
  565. }
  566. return 1;
  567. }
  568. int cnt_num(chess x,int x1,int x2,int y1,int y2){
  569. int sum=0;
  570. for(int i=x1;i<=x2;i++)
  571. {
  572. for(int j=y1;j<=y2;j++)
  573. {
  574. if(x.g[i][j])sum++;
  575. }
  576. }
  577. return sum;
  578. }
  579. int main()
  580. {
  581. init_window();
  582. return 0;
  583. }

可能的改进策略:

1.快速生成棋盘序列,以快速进行大量多次模拟

2.小范围可以使用博弈思想打表或者暴力查找必胜状态

3.使用并行计算

欢迎大佬萌帮忙改进!~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/243656
推荐阅读
相关标签
  

闽ICP备14008679号