当前位置:   article > 正文

golang蒙特卡洛树算法实现五子棋AI

monte carlo tree search 五子棋

已经实现蒙特卡洛树算法的通用逻辑,只需要对应结构体实现相关接口就可以直接使用该算法。

优化算法主要优化GetActions生成下一步动作,要尽可能少,去掉无意义的动作。

以及优化ActionPolicy从众多动作挑选比较优秀的动作。对应五子棋就是执行该动作后当前局面评分最高。

  1. package main
  2. import (
  3. "fmt"
  4. "math"
  5. "math/rand"
  6. "strings"
  7. "time"
  8. )
  9. func main() {
  10. var (
  11. board = NewQuZiQi(15)
  12. x, y int
  13. )
  14. board.Print()
  15. for board.IsTerminal() == 0 {
  16. board = Search(time.Second*10, board).(*WuZiQi)
  17. board.Print()
  18. if board.IsTerminal() == 1 {
  19. fmt.Println("电脑赢了")
  20. return
  21. }
  22. for {
  23. fmt.Print("轮到您执棋,请输入坐标: ")
  24. _, _ = fmt.Scanln(&x, &y)
  25. x--
  26. y--
  27. if x < 0 || y < 0 || x >= board.size || y >= board.size {
  28. fmt.Println("您输入的数据超出棋盘范围")
  29. } else if board.board[x][y] > 0 {
  30. fmt.Println("该位置已有棋子")
  31. } else {
  32. board.board[x][y] = 2
  33. board.player = 1 // 下一步该电脑下
  34. break
  35. }
  36. }
  37. board.Print()
  38. if board.IsTerminal() == 2 {
  39. fmt.Println("你赢了")
  40. return
  41. }
  42. }
  43. }
  44. // WuZiQi 五子棋游戏
  45. type WuZiQi struct {
  46. size int // 棋盘大小
  47. board [][]int // 棋盘状态
  48. player int // 1: 电脑落子,2: 玩家落子
  49. }
  50. func NewQuZiQi(size int) *WuZiQi {
  51. w := &WuZiQi{
  52. size: size,
  53. board: make([][]int, size),
  54. player: 1,
  55. }
  56. for i := 0; i < size; i++ {
  57. w.board[i] = make([]int, size)
  58. }
  59. size /= 2
  60. // 默认中间落一个棋子
  61. // 0: 表示没有落子,1: 表示电脑,2: 表示玩家
  62. w.board[size][size] = 2
  63. return w
  64. }
  65. func (w *WuZiQi) Print() {
  66. var (
  67. str strings.Builder
  68. num = func(n int) {
  69. a, b := n/10, n%10
  70. if a > 0 {
  71. str.WriteByte(byte(a + '0'))
  72. } else {
  73. str.WriteByte(' ') // 1位数前面加空格
  74. }
  75. str.WriteByte(byte(b + '0'))
  76. }
  77. )
  78. str.WriteString(" ")
  79. for i := 1; i <= w.size; i++ {
  80. str.WriteByte(' ')
  81. num(i)
  82. }
  83. str.WriteByte('\n')
  84. for i := 0; i < w.size; i++ {
  85. str.WriteString(" ")
  86. for j := 0; j < w.size; j++ {
  87. str.WriteString(" __")
  88. }
  89. str.WriteByte('\n')
  90. num(i + 1)
  91. str.WriteByte(' ')
  92. for j := 0; j < w.size; j++ {
  93. str.WriteByte('|')
  94. switch w.board[i][j] {
  95. case 0:
  96. str.WriteByte(' ')
  97. case 1:
  98. str.WriteByte('O')
  99. case 2:
  100. str.WriteByte('X')
  101. }
  102. str.WriteByte(' ')
  103. }
  104. str.WriteString("|\n")
  105. }
  106. str.WriteString(" ")
  107. for i := 0; i < w.size; i++ {
  108. str.WriteString(" __")
  109. }
  110. fmt.Println(str.String())
  111. }
  112. func (w *WuZiQi) IsTerminal() int {
  113. full := -1 // 没有空位且都没赢
  114. for i := 0; i < w.size; i++ {
  115. for j := 0; j < w.size; j++ {
  116. if wc := w.board[i][j]; wc == 0 {
  117. full = 0 // 还有空位,没结束
  118. } else {
  119. // 向右
  120. cnt, x, y := 1, 0, j+1
  121. for ; y < w.size && w.board[i][y] == wc; y++ {
  122. cnt++
  123. }
  124. if cnt >= 5 {
  125. return wc
  126. }
  127. // 向下
  128. cnt, x = 1, i+1
  129. for ; x < w.size && w.board[x][j] == wc; x++ {
  130. cnt++
  131. }
  132. if cnt >= 5 {
  133. return wc
  134. }
  135. // 向右下
  136. cnt, x, y = 1, i+1, j+1
  137. for ; x < w.size && y < w.size && w.board[x][y] == wc; x, y = x+1, y+1 {
  138. cnt++
  139. }
  140. if cnt >= 5 {
  141. return wc
  142. }
  143. // 向左下
  144. cnt, x, y = 1, i+1, j-1
  145. for ; x < w.size && y >= 0 && w.board[x][y] == wc; x, y = x+1, y-1 {
  146. cnt++
  147. }
  148. if cnt >= 5 {
  149. return wc
  150. }
  151. }
  152. }
  153. }
  154. return full
  155. }
  156. func (w *WuZiQi) Result(state int) float64 {
  157. switch state {
  158. case -1:
  159. return 0 // 都没赢且没空位
  160. case 1:
  161. return -1 // 电脑赢了
  162. case 2:
  163. return +1 // 玩家赢了
  164. default:
  165. return 0 // 都没赢且有空位
  166. }
  167. }
  168. func (w *WuZiQi) GetActions() (res []any) {
  169. // todo 敌方上一步落子附近才是最优搜索范围
  170. // 某个落子必胜,则直接落子,如果某个落子让对手所有落子都必败则直接落子
  171. // 因此后续动作进一步缩小范围
  172. // 可以使用hash判断棋盘状态
  173. m := map[[2]int]struct{}{} // 用于去重
  174. for i := 0; i < w.size; i++ {
  175. for j := 0; j < w.size; j++ {
  176. if w.board[i][j] == 0 || w.board[i][j] == w.player {
  177. continue // 跳过空位和己方棋子
  178. }
  179. x0, x1, y0, y1 := i-2, i+2, j-2, j+2
  180. for ii := x0; ii < x1; ii++ {
  181. for jj := y0; jj < y1; jj++ {
  182. if ii >= 0 && jj >= 0 && ii < w.size && jj < w.size &&
  183. w.board[ii][jj] == 0 {
  184. p := [2]int{ii, jj}
  185. _, ok := m[p]
  186. if !ok {
  187. // 在棋子周围2格范围的空位加到结果中
  188. // 超过2格的空位落子的意义不大
  189. res = append(res, p)
  190. m[p] = struct{}{}
  191. }
  192. }
  193. }
  194. }
  195. }
  196. }
  197. return
  198. }
  199. func (w *WuZiQi) ActionPolicy(action []any) any {
  200. // 目前随机选一个动作,应该是好方案先选出来
  201. return action[rand.Intn(len(action))]
  202. }
  203. func (w *WuZiQi) Action(action any) TreeState {
  204. wn := &WuZiQi{
  205. size: w.size,
  206. board: make([][]int, w.size),
  207. player: 3 - w.player, // 切换电脑和玩家
  208. }
  209. for i := 0; i < w.size; i++ {
  210. wn.board[i] = make([]int, w.size)
  211. for j := 0; j < w.size; j++ {
  212. wn.board[i][j] = w.board[i][j]
  213. }
  214. }
  215. ac := action.([2]int) // 在该位置落子
  216. wn.board[ac[0]][ac[1]] = w.player
  217. return wn
  218. }
  219. // MonteCarloTree 下面是算法部分
  220. // 你的对象只需要提供TreeState所有接口,就可以直接使用
  221. // https://github.com/int8/monte-carlo-tree-search
  222. // https://blog.csdn.net/masterhero666/article/details/126325506
  223. type (
  224. TreeState interface {
  225. IsTerminal() int // 0: 未结束,其他为自定义状态
  226. Result(int) float64 // 计算分数,传入IsTerminal结果
  227. GetActions() []any // 获取所有合法动作, todo 考虑获取不到动作时如何处理
  228. ActionPolicy([]any) any // 按策略挑选一个动作
  229. Action(any) TreeState // 执行动作生成子节点
  230. }
  231. McTreeNode struct {
  232. parent *McTreeNode
  233. children []*McTreeNode
  234. score float64
  235. visitCount float64
  236. untriedActions []any
  237. nodeState TreeState
  238. }
  239. )
  240. func Search(simulate any, state TreeState, discount ...float64) TreeState {
  241. var (
  242. root = &McTreeNode{nodeState: state}
  243. leaf *McTreeNode
  244. dp = 1.4 // 折扣参数默认值
  245. )
  246. if len(discount) > 0 {
  247. dp = discount[0]
  248. }
  249. var loop func() bool
  250. switch s := simulate.(type) {
  251. case int:
  252. loop = func() bool {
  253. s-- // 模拟指定次数后退出
  254. return s >= 0
  255. }
  256. case time.Duration:
  257. ts := time.Now().Add(s) // 超过指定时间后退出
  258. loop = func() bool { return time.Now().Before(ts) }
  259. case func() bool:
  260. loop = s // 或者由外部指定模拟结束方案
  261. default:
  262. panic(simulate)
  263. }
  264. for loop() {
  265. leaf = root.treePolicy(dp)
  266. result, curState := 0, leaf.nodeState
  267. for {
  268. if result = curState.IsTerminal(); result != 0 {
  269. break // 结束状态
  270. }
  271. // 根据该节点状态生成所有合法动作
  272. all := curState.GetActions()
  273. // 按照某种策略选出1个动作,不同于expand的顺序取出
  274. one := curState.ActionPolicy(all)
  275. // 执行该动作,重复该过程,直到结束
  276. curState = curState.Action(one)
  277. }
  278. // 根据结束状态计算结果,将该结果反向传播
  279. leaf.backPropagate(curState.Result(result))
  280. }
  281. return root.chooseBestChild(dp).nodeState // 选择最优子节点
  282. }
  283. func (cur *McTreeNode) chooseBestChild(c float64) *McTreeNode {
  284. var (
  285. idx = 0
  286. maxValue = -math.MaxFloat64
  287. childValue float64
  288. )
  289. for i, child := range cur.children {
  290. childValue = (child.score / child.visitCount) +
  291. c*math.Sqrt(math.Log(cur.visitCount)/child.visitCount)
  292. if childValue > maxValue {
  293. maxValue = childValue
  294. idx = i // 选择分值最高的子节点
  295. }
  296. }
  297. return cur.children[idx]
  298. }
  299. func (cur *McTreeNode) backPropagate(result float64) {
  300. nodeCursor := cur
  301. for nodeCursor.parent != nil {
  302. nodeCursor.score += result
  303. nodeCursor.visitCount++ // 反向传播,增加访问次数,更新分数
  304. nodeCursor = nodeCursor.parent
  305. }
  306. nodeCursor.visitCount++
  307. }
  308. func (cur *McTreeNode) expand() *McTreeNode {
  309. res := cur.untriedActions[0] // 返回1个未经尝试动作
  310. cur.untriedActions = cur.untriedActions[1:]
  311. child := &McTreeNode{
  312. parent: cur, // 当前节点按顺序弹出1个动作,执行动作生成子节点
  313. nodeState: cur.nodeState.Action(res),
  314. }
  315. cur.children = append(cur.children, child)
  316. return child
  317. }
  318. func (cur *McTreeNode) treePolicy(discountParamC float64) *McTreeNode {
  319. nodeCursor := cur // 一直循环直到结束
  320. for nodeCursor.nodeState.IsTerminal() == 0 {
  321. if nodeCursor.untriedActions == nil {
  322. // 只会初始化1次,找出该节点所有动作
  323. nodeCursor.untriedActions = nodeCursor.nodeState.GetActions()
  324. }
  325. if len(nodeCursor.untriedActions) > 0 {
  326. return nodeCursor.expand() // 存在未处理动作则添加子节点
  327. }
  328. // 处理完动作,选择最好子节点继续往下处理
  329. nodeCursor = nodeCursor.chooseBestChild(discountParamC)
  330. }
  331. return nodeCursor
  332. }
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/612523
推荐阅读
相关标签
  

闽ICP备14008679号