
A Gomoku Implementation Based on Monte Carlo Tree Search


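I've been a bit frustrated and a bit bored lately. Last year my attempt at Texas Hold'em poker failed. Afterwards I learned that AlphaZero uses Monte Carlo Tree Search (MCTS), so maybe I had been heading in the wrong direction?

So I decided to study MCTS as technical groundwork for tackling Texas Hold'em with deep learning. There are countless theoretical introductions to it online, and most of them are fine. But theory without a working implementation never feels solid or trustworthy to me.

I wanted a way to check whether I had really understood the algorithm. So I decided to build Gomoku (five in a row), a game everybody on Earth knows, and use it to test my understanding.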

Below are the most essential parts of the theory, excerpted from online sources (with my own modifications):

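In the simplest form of the MCTS tree, each node only stores the historical win/loss record of its state, and each edge stores the sampled action. An MCTS search then proceeds in four steps, as illustrated in the MCTS figure on Wikipedia: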
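A minimal sketch of such a node follows. It assumes the common trick of storing the action of the incoming edge inside the child node, which is what the article's `TreeNode.m_pos` does. The class name `MctsNode` and its members are hypothetical and chosen only for illustration.

```csharp
using System.Collections.Generic;

// Hypothetical minimal MCTS node: the move on the incoming edge plus the node's win/visit record.
public sealed class MctsNode
{
    public MctsNode Parent { get; }
    public List<MctsNode> Children { get; } = new List<MctsNode>();
    public int Move { get; }          // action on the edge from Parent, e.g. a cell index; -1 for the root
    public double Wins { get; set; }  // wins recorded for the player who made Move
    public int Visits { get; set; }   // number of simulations that passed through this node

    public MctsNode(MctsNode parent, int move)
    {
        Parent = parent;
        Move = move;
    }

    // Creates a child reached by `move`, links it to this node and returns it.
    public MctsNode AddChild(int move)
    {
        var child = new MctsNode(this, move);
        Children.Add(child);
        return child;
    }
}
```

A new child starts with a 0/0 record, which matches the "no history yet" state described in the expansion step.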

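Step 1 is Selection. Starting from the root, we repeatedly pick the "child most worth searching", usually the one with the highest UCT score, until we reach a node that still has unexpanded children, such as the 3/3 node in the figure. It is called a node "with unexpanded children" because this position still has follow-up moves that have never been tried. In other words, MCTS has no statistics for those actions yet. At this point we move on to step 2.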

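My note: during selection the layers alternate. If the first layer holds the opponent's moves, the second layer holds mine, the third the opponent's again, and so on. At every layer I pick the node with the highest win ratio (assuming the opponent is just as strong).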
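The UCT rule can be sketched as follows. The score is the win rate plus an exploration bonus, with c = 1.414 (about √2) as in the article's code. Unvisited children score +∞, so they are tried first.

The sketch assumes each child's wins are recorded from the viewpoint of the player who made that child's move. Under that assumption, the same "take the maximum" rule applies at the AI's layers and the opponent's layers alike. The class `Uct` and its methods are hypothetical names.

```csharp
using System;

public static class Uct
{
    // UCT score: exploitation (win rate) + exploration bonus.
    // Assumes parentVisits >= 1 whenever visits > 0.
    public static double Score(double wins, int visits, int parentVisits, double c = 1.414)
    {
        if (visits == 0)
            return double.PositiveInfinity;   // untried children are selected first
        return wins / visits + c * Math.Sqrt(Math.Log(parentVisits) / visits);
    }

    // Index of the child with the highest UCT score (the first one wins ties); -1 if there are no children.
    public static int SelectChild(double[] wins, int[] visits, int parentVisits)
    {
        int best = -1;
        double bestScore = double.NegativeInfinity;
        for (int i = 0; i < visits.Length; i++)
        {
            double s = Score(wins[i], visits[i], parentVisits);
            if (s > bestScore)
            {
                bestScore = s;
                best = i;
            }
        }
        return best;
    }
}
```

For example, with a parent visited 5 times, a child at 3/4 scores about 1.65, while a child at 1/1 scores about 2.79. The less-explored child is chosen even though both look promising.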

Step 2 is Expansion. Under the node found in step 1, we add a child with a 0/0 record, meaning there is no history for it yet. Then we move on to step 3.

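My note: when expanding, remember that actions must not repeat. In Gomoku this means you cannot place a stone on a point that is already occupied. I expand all children at once.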
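Expanding "all at once" amounts to creating one child per legal move. The sketch below lists the legal moves on a Gomoku board and skips occupied cells, so no move can repeat. It uses the article's encoding (0 = empty, 1 and 2 = stones). The class `Expansion` and the method `LegalMoves` are hypothetical names.

```csharp
using System.Collections.Generic;

public static class Expansion
{
    // Every empty cell of the board as a (Row, Col) pair, scanned row by row.
    // Creating one child node per entry expands a node completely in a single step.
    public static List<(int Row, int Col)> LegalMoves(short[,] board)
    {
        var moves = new List<(int Row, int Col)>();
        for (int r = 0; r < board.GetLength(0); r++)
            for (int c = 0; c < board.GetLength(1); c++)
                if (board[r, c] == 0)          // occupied cells are skipped: moves never repeat
                    moves.Add((r, c));
        return moves;
    }
}
```

The board passed in must already contain the stones of the path from the root down to the node being expanded. Otherwise moves that are illegal in that line of play would be offered.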

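Step 3 is Simulation. Starting from the untried move above, we play the game out to the end with a simple policy, such as a fast rollout policy, and obtain a win/loss result. A rollout policy is usually chosen to be fast, even if it is not very accurate.

A slower policy would give more accurate results. But each playout would take longer, so fewer simulations would fit in a given time. The engine would therefore not necessarily be stronger, and could even be weaker. This is also why we usually simulate only once per iteration: more simulations would be more accurate, but slower.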

My note: the simulation plays completely random moves all the way to the end of the game.
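A purely random rollout can be sketched as follows, assuming the article's board encoding (0 = empty, 1 = AI, 2 = human). The sketch copies the board, so the caller's position is left untouched. It then plays uniformly random moves, alternating players. After every single move it checks only the lines through the stone just placed, which is cheaper than rescanning the whole board. `Rollout`, `Play` and `FiveAt` are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

public static class Rollout
{
    static readonly int[,] Dirs = { { 0, 1 }, { 1, 0 }, { 1, 1 }, { 1, -1 } };

    // True if the stone at (r, c) belongs to five or more in a row in any of the four directions.
    public static bool FiveAt(short[,] b, int r, int c)
    {
        short who = b[r, c];
        if (who == 0) return false;
        for (int d = 0; d < 4; d++)
        {
            int run = 1;
            foreach (int sign in new[] { 1, -1 })   // walk both ways along the direction
            {
                int dr = Dirs[d, 0] * sign, dc = Dirs[d, 1] * sign;
                int i = r + dr, j = c + dc;
                while (i >= 0 && i < b.GetLength(0) && j >= 0 && j < b.GetLength(1) && b[i, j] == who)
                {
                    run++;
                    i += dr;
                    j += dc;
                }
            }
            if (run >= 5) return true;
        }
        return false;
    }

    // Random playout on a copy of `board`, with `toMove` (1 or 2) moving first.
    // Returns the winner (1 or 2), or 0 if the board fills up without five in a row.
    public static int Play(short[,] board, short toMove, Random rnd)
    {
        var b = (short[,])board.Clone();
        var empty = new List<(int, int)>();
        for (int r = 0; r < b.GetLength(0); r++)
            for (int c = 0; c < b.GetLength(1); c++)
                if (b[r, c] == 0) empty.Add((r, c));

        while (empty.Count > 0)
        {
            int k = rnd.Next(empty.Count);
            var (r, c) = empty[k];
            empty[k] = empty[empty.Count - 1];   // O(1) removal: move the last cell into slot k
            empty.RemoveAt(empty.Count - 1);
            b[r, c] = toMove;
            if (FiveAt(b, r, c)) return toMove;  // checked after every single move
            toMove = (short)(3 - toMove);        // switch players: 1 <-> 2
        }
        return 0;
    }
}
```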

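Step 4 is Backpropagation. The final win/loss result is propagated back up the MCTS tree. Note that besides the existing nodes on the path, the newly added node also records this result, as shown at the far right of the figure.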

My note: on a win, add 1 to the win count; on a loss, only increase the node's visit count.
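The update rule can be sketched as follows. Every node on the path gets one more visit. A node gets a win only when the player who moved into it is the winner, so the credit alternates from layer to layer.

The winner encoding (1 = AI, 2 = human, 0 = draw) follows the article's code, and a draw only adds a visit. The `Node` and `Backprop` types are hypothetical.

```csharp
// Hypothetical minimal node: just what backpropagation needs.
public sealed class Node
{
    public Node Parent;
    public bool MoverIsAi;   // true if the move leading into this node was the AI's
    public double Wins;
    public int Visits;
}

public static class Backprop
{
    // Walks from the simulated node up to the root, updating each node's record.
    public static void Update(Node leaf, int winner)
    {
        for (Node n = leaf; n != null; n = n.Parent)
        {
            n.Visits++;
            // A win counts for this node only if its mover is the winner; losses and draws add a visit only.
            if (winner != 0 && (winner == 1) == n.MoverIsAi)
                n.Wins += 1;
        }
    }
}
```

Because each node's record is kept from its own mover's viewpoint, the parent can always choose the child with the highest win rate. This matches the "assume the opponent is just as strong" rule in the selection note.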

 

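I originally wanted to implement this in Python, but building the UI with C# WinForms is more convenient, so I dropped that idea for now. I'll port it to Python later as the project progresses. The code is below; corrections and discussion are welcome!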

 

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Linq;
using System.Threading;
using System.Windows.Forms;

namespace WindowsFormsApp1
{
    // SYlabel (status label) and the start button wired to StartBtn_Click live in Form1.Designer.cs (not shown).
    public partial class Form1 : Form
    {
        // Signalled by the UI thread after the human has placed a stone.
        public AutoResetEvent waitEvent = new AutoResetEvent(false);
        const int Row = 8;
        const int Col = 8;
        const double SearchSeconds = 20;                // thinking time per AI move
        // Board encoding: 0 = empty, 1 = AI stone, 2 = human stone.
        short[,] m_board = new short[Row, Col];         // the real game position
        short[,] m_boardSearch = new short[Row, Col];   // scratch copy used by one MCTS iteration

        public class Pos
        {
            public short x;
            public short y;
            public Button btn = null;
            public Pos(short ax, short ay)
            {
                x = ax;
                y = ay;
            }
        }

        class TreeNode
        {
            public TreeNode parentNode;
            public List<TreeNode> m_ChildNodes = new List<TreeNode>();
            public double m_win_count = 0;    // wins for the player who played m_pos
            public double m_visit_count = 0;
            public Pos m_pos;                 // move leading to this node (null at the root)
            public bool IsAI;                 // true if m_pos was played by the AI
        }

        public Form1()
        {
            InitializeComponent();
            CreateBoardUI();
        }

        // One Pos (with its button) per cell, indexed directly instead of through an "x-y" string key.
        Pos[,] m_posGrid = new Pos[Row, Col];

        void CreateBoardUI()
        {
            int x_s = 100, y_s = 100;
            for (short i = 0; i < Row; ++i)
                for (short j = 0; j < Col; ++j)
                {
                    Button btn = new Button();
                    Pos p = new Pos(i, j);
                    p.btn = btn;
                    btn.Tag = p;
                    btn.Left = x_s + i * 20;
                    btn.Top = y_s + j * 20;
                    btn.Width = 20;
                    btn.Height = 20;
                    btn.Click += Btn_Click;
                    this.Controls.Add(btn);
                    m_posGrid[i, j] = p;
                }
        }

        private void Btn_Click(object sender, EventArgs e)
        {
            if (m_AI_thinking)
            {
                MessageBox.Show("The AI is searching, please wait!");
                return;
            }
            Button btn = (Button)sender;
            Pos p = (Pos)btn.Tag;
            if (m_board[p.x, p.y] != 0)
                return;                       // occupied cell
            m_board[p.x, p.y] = 2;
            btn.Text = "2";
            btn.ForeColor = Color.Red;
            m_AI_thinking = true;             // block further clicks until the AI has replied
            waitEvent.Set();
        }

        // Backpropagation from the simulated node up to the root.
        // winner: 1 = AI, 2 = human, 0 = draw (a draw only adds a visit).
        void BackUp(int winner, TreeNode leafnode)
        {
            for (TreeNode n = leafnode; n != null; n = n.parentNode)
            {
                ++n.m_visit_count;
                // Credit the win to the player who made this node's move, so every layer maximizes its own win rate.
                if (winner != 0 && (winner == 1) == n.IsAI)
                    ++n.m_win_count;
            }
        }

        // Simulation: random playout on m_boardSearch after leafNode's move.
        // Returns 1 if the AI wins, 2 if the human wins, 0 for a draw.
        int StartSimulate(TreeNode leafNode)
        {
            // The move leading to the leaf may already have ended the game.
            if (leafNode.m_pos != null && IsFiveAt(m_boardSearch, leafNode.m_pos.x, leafNode.m_pos.y))
                return leafNode.IsAI ? 1 : 2;
            short who = leafNode.IsAI ? (short)2 : (short)1;   // the leaf mover's opponent plays next
            while (true)
            {
                Pos p = GetAnEmptyPos();
                if (p == null)
                    return 0;                                   // board full: draw
                m_boardSearch[p.x, p.y] = who;
                if (IsFiveAt(m_boardSearch, p.x, p.y))          // check after every move
                    return who;
                who = (short)(3 - who);                         // switch players: 1 <-> 2
            }
        }

        static readonly int[,] s_dirs = { { 0, 1 }, { 1, 0 }, { 1, 1 }, { 1, -1 } };

        // Number of consecutive `who` stones next to (x, y) going in direction (dx, dy).
        static int CountDir(short[,] board, int x, int y, int dx, int dy, short who)
        {
            int n = 0;
            for (int i = x + dx, j = y + dy;
                 i >= 0 && i < Row && j >= 0 && j < Col && board[i, j] == who;
                 i += dx, j += dy)
                ++n;
            return n;
        }

        // True if the stone at (x, y) is part of five or more in a row (horizontal, vertical or diagonal).
        static bool IsFiveAt(short[,] board, int x, int y)
        {
            short who = board[x, y];
            if (who == 0)
                return false;
            for (int d = 0; d < 4; ++d)
            {
                int dx = s_dirs[d, 0], dy = s_dirs[d, 1];
                if (1 + CountDir(board, x, y, dx, dy, who) + CountDir(board, x, y, -dx, -dy, who) >= 5)
                    return true;
            }
            return false;
        }

        // True if player `who` has five in a row anywhere on the board.
        bool GameOver(short[,] board, short who)
        {
            for (int i = 0; i < Row; i++)
                for (int j = 0; j < Col; j++)
                    if (board[i, j] == who && IsFiveAt(board, i, j))
                        return true;
            return false;
        }

        static bool BoardFull(short[,] board) => board.Cast<short>().All(v => v != 0);

        // Expansion: one child per empty cell, all at once (occupied cells are skipped, so moves never repeat).
        // m_boardSearch must hold the position at `node`; selection has already placed the path's stones.
        void ExpandNode(TreeNode node)
        {
            for (short i = 0; i < Row; ++i)
                for (short j = 0; j < Col; ++j)
                    if (m_boardSearch[i, j] == 0)
                    {
                        TreeNode oneNode = new TreeNode();
                        oneNode.IsAI = !node.IsAI;      // players alternate between layers
                        oneNode.m_pos = m_posGrid[i, j];
                        oneNode.parentNode = node;
                        node.m_ChildNodes.Add(oneNode);
                    }
        }

        // One shared generator: on .NET Framework, Random objects created in quick succession can share a time-based seed.
        Random m_rnd = new Random();

        // A uniformly random empty cell of m_boardSearch, or null if the board is full.
        Pos GetAnEmptyPos()
        {
            var empty = new List<Pos>();
            for (int i = 0; i < Row; ++i)
                for (int j = 0; j < Col; ++j)
                    if (m_boardSearch[i, j] == 0)
                        empty.Add(m_posGrid[i, j]);
            return empty.Count == 0 ? null : empty[m_rnd.Next(empty.Count)];
        }

        // Selection: descend to the child with the highest UCT score until reaching a node without children,
        // placing each chosen move on m_boardSearch along the way.
        TreeNode mcts_select(TreeNode pNode)
        {
            while (pNode.m_ChildNodes.Count > 0)
            {
                double max_score = double.NegativeInfinity;
                List<TreeNode> canSelectList = new List<TreeNode>();
                foreach (var node in pNode.m_ChildNodes)
                {
                    // Unvisited children score +infinity, so they are always tried first.
                    double score = node.m_visit_count == 0
                        ? double.PositiveInfinity
                        : Math.Round(node.m_win_count / node.m_visit_count
                            + 1.414 * Math.Sqrt(Math.Log(pNode.m_visit_count) / node.m_visit_count), 4);
                    if (score > max_score)
                    {
                        max_score = score;
                        canSelectList.Clear();
                        canSelectList.Add(node);
                    }
                    else if (score == max_score)
                        canSelectList.Add(node);        // collect ties and break them randomly
                }
                pNode = canSelectList[m_rnd.Next(canSelectList.Count)];
                m_boardSearch[pNode.m_pos.x, pNode.m_pos.y] = pNode.IsAI ? (short)1 : (short)2;
            }
            return pNode;
        }

        // Play the root child with the most visits (the most robust choice). Returns false if there is no legal move.
        bool GotoNext(TreeNode rootNode)
        {
            TreeNode bestNode = null;
            foreach (var node in rootNode.m_ChildNodes)
                if (bestNode == null || node.m_visit_count > bestNode.m_visit_count)
                    bestNode = node;
            if (bestNode == null)
                return false;
            m_board[bestNode.m_pos.x, bestNode.m_pos.y] = 1;
            UpdateBoardUI(bestNode.m_pos);
            return true;
        }

        private void StartBtn_Click(object sender, EventArgs e)
        {
            Array.Clear(m_board, 0, m_board.Length);
            foreach (var one in m_posGrid)
                one.btn.Text = "";
            m_AI_thinking = false;
            waitEvent.Reset();                 // the human moves first
            Thread th = new Thread(this.AIMainThread);
            th.IsBackground = true;
            th.Start();
        }

        private volatile bool m_AI_thinking = false;   // written by both the UI thread and the AI thread

        private void AIMainThread()
        {
            while (true)
            {
                waitEvent.WaitOne();                   // wait for the human's move
                if (GameOver(m_board, 2))
                {
                    UpdateMsg("Congratulations, you win!");
                    break;
                }
                UpdateMsg("AI is thinking...");
                TreeNode rootNode = new TreeNode();
                rootNode.m_visit_count = 1;            // so the root is expanded on the first iteration
                rootNode.IsAI = false;                 // the last real move was the human's
                Stopwatch sw = Stopwatch.StartNew();
                while (sw.Elapsed.TotalSeconds < SearchSeconds)
                {
                    // Each iteration starts from the real position.
                    Array.Copy(m_board, m_boardSearch, m_board.Length);
                    // 1. Selection: the selected path's stones are placed on m_boardSearch.
                    TreeNode leafNode = mcts_select(rootNode);
                    bool terminal = leafNode.m_pos != null
                        && IsFiveAt(m_boardSearch, leafNode.m_pos.x, leafNode.m_pos.y);
                    // 2. Expansion: only for non-terminal leaves that have already been simulated once.
                    if (!terminal && leafNode.m_visit_count > 0)
                    {
                        ExpandNode(leafNode);
                        if (leafNode.m_ChildNodes.Count > 0)
                        {
                            leafNode = leafNode.m_ChildNodes[0];
                            m_boardSearch[leafNode.m_pos.x, leafNode.m_pos.y] = leafNode.IsAI ? (short)1 : (short)2;
                        }
                        else if (leafNode == rootNode)
                            break;                     // no legal move at all
                    }
                    // 3. Simulation, then 4. Backpropagation.
                    int res = StartSimulate(leafNode);
                    BackUp(res, leafNode);
                }
                if (!GotoNext(rootNode))               // actually play the best move found
                {
                    UpdateMsg("Draw!");
                    break;
                }
                if (GameOver(m_board, 1))
                {
                    UpdateMsg("The AI wins!");
                    break;
                }
                if (BoardFull(m_board))
                {
                    UpdateMsg("Draw!");
                    break;
                }
                UpdateMsg("Your move...");
                m_AI_thinking = false;
            }
            m_AI_thinking = false;
        }

        public void UpdateMsg(string str)
        {
            this.BeginInvoke(new Action(() =>
            {
                SYlabel.Text = str;
            }));
        }

        public void UpdateBoardUI(Pos p)
        {
            this.BeginInvoke(new Action(() =>
            {
                p.btn.Text = "1";
                p.btn.ForeColor = Color.Black;
                p.btn.Focus();
            }));
        }
    }
}
```

On 8×8 and 10×10 boards, with a search time of up to 20 seconds per move, the AI basically plays correct moves!

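Disclaimer: this article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site assumes no related legal liability. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/我家小花儿/article/detail/612510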