本文件已经实现了蒙特卡洛树搜索(MCTS)算法的通用逻辑:只要你的结构体实现了 TreeState 接口,就可以直接使用该算法。
要提升算法效果,主要是优化 GetActions:
生成下一步的候选动作时,数量要尽可能少,去掉无意义的动作;
以及优化 ActionPolicy:
从众多候选动作中挑选较优的动作。对应到五子棋,就是挑选执行后当前局面评分最高的动作。
- package main
-
import (
	"bufio"
	"fmt"
	"math"
	"math/rand"
	"os"
	"strings"
	"time"
)
-
- func main() {
- var (
- board = NewQuZiQi(15)
- x, y int
- )
-
- board.Print()
- for board.IsTerminal() == 0 {
- board = Search(time.Second*10, board).(*WuZiQi)
-
- board.Print()
- if board.IsTerminal() == 1 {
- fmt.Println("电脑赢了")
- return
- }
-
- for {
- fmt.Print("轮到您执棋,请输入坐标: ")
- _, _ = fmt.Scanln(&x, &y)
- x--
- y--
- if x < 0 || y < 0 || x >= board.size || y >= board.size {
- fmt.Println("您输入的数据超出棋盘范围")
- } else if board.board[x][y] > 0 {
- fmt.Println("该位置已有棋子")
- } else {
- board.board[x][y] = 2
- board.player = 1 // 下一步该电脑下
- break
- }
- }
-
- board.Print()
- if board.IsTerminal() == 2 {
- fmt.Println("你赢了")
- return
- }
- }
- }
-
// WuZiQi holds the state of a Gomoku (five-in-a-row) game.
type WuZiQi struct {
	// size is the side length of the square board.
	size int

	// board[row][col] holds the cell content:
	// 0 = empty, 1 = computer stone, 2 = human stone.
	board [][]int

	// player is the side that places the next stone:
	// 1 = computer, 2 = human.
	player int
}
-
- func NewQuZiQi(size int) *WuZiQi {
- w := &WuZiQi{
- size: size,
- board: make([][]int, size),
- player: 1,
- }
- for i := 0; i < size; i++ {
- w.board[i] = make([]int, size)
- }
- size /= 2
- // 默认中间落一个棋子
- // 0: 表示没有落子,1: 表示电脑,2: 表示玩家
- w.board[size][size] = 2
- return w
- }
-
- func (w *WuZiQi) Print() {
- var (
- str strings.Builder
- num = func(n int) {
- a, b := n/10, n%10
- if a > 0 {
- str.WriteByte(byte(a + '0'))
- } else {
- str.WriteByte(' ') // 1位数前面加空格
- }
- str.WriteByte(byte(b + '0'))
- }
- )
- str.WriteString(" ")
- for i := 1; i <= w.size; i++ {
- str.WriteByte(' ')
- num(i)
- }
- str.WriteByte('\n')
- for i := 0; i < w.size; i++ {
- str.WriteString(" ")
- for j := 0; j < w.size; j++ {
- str.WriteString(" __")
- }
-
- str.WriteByte('\n')
- num(i + 1)
- str.WriteByte(' ')
-
- for j := 0; j < w.size; j++ {
- str.WriteByte('|')
- switch w.board[i][j] {
- case 0:
- str.WriteByte(' ')
- case 1:
- str.WriteByte('O')
- case 2:
- str.WriteByte('X')
- }
- str.WriteByte(' ')
- }
- str.WriteString("|\n")
- }
- str.WriteString(" ")
- for i := 0; i < w.size; i++ {
- str.WriteString(" __")
- }
- fmt.Println(str.String())
- }
-
- func (w *WuZiQi) IsTerminal() int {
- full := -1 // 没有空位且都没赢
- for i := 0; i < w.size; i++ {
- for j := 0; j < w.size; j++ {
- if wc := w.board[i][j]; wc == 0 {
- full = 0 // 还有空位,没结束
- } else {
- // 向右
- cnt, x, y := 1, 0, j+1
- for ; y < w.size && w.board[i][y] == wc; y++ {
- cnt++
- }
- if cnt >= 5 {
- return wc
- }
- // 向下
- cnt, x = 1, i+1
- for ; x < w.size && w.board[x][j] == wc; x++ {
- cnt++
- }
- if cnt >= 5 {
- return wc
- }
- // 向右下
- cnt, x, y = 1, i+1, j+1
- for ; x < w.size && y < w.size && w.board[x][y] == wc; x, y = x+1, y+1 {
- cnt++
- }
- if cnt >= 5 {
- return wc
- }
- // 向左下
- cnt, x, y = 1, i+1, j-1
- for ; x < w.size && y >= 0 && w.board[x][y] == wc; x, y = x+1, y-1 {
- cnt++
- }
- if cnt >= 5 {
- return wc
- }
- }
- }
- }
- return full
- }
-
- func (w *WuZiQi) Result(state int) float64 {
- switch state {
- case -1:
- return 0 // 都没赢且没空位
- case 1:
- return -1 // 电脑赢了
- case 2:
- return +1 // 玩家赢了
- default:
- return 0 // 都没赢且有空位
- }
- }
-
- func (w *WuZiQi) GetActions() (res []any) {
- // todo 敌方上一步落子附近才是最优搜索范围
- // 某个落子必胜,则直接落子,如果某个落子让对手所有落子都必败则直接落子
- // 因此后续动作进一步缩小范围
- // 可以使用hash判断棋盘状态
-
- m := map[[2]int]struct{}{} // 用于去重
- for i := 0; i < w.size; i++ {
- for j := 0; j < w.size; j++ {
- if w.board[i][j] == 0 || w.board[i][j] == w.player {
- continue // 跳过空位和己方棋子
- }
-
- x0, x1, y0, y1 := i-2, i+2, j-2, j+2
- for ii := x0; ii < x1; ii++ {
- for jj := y0; jj < y1; jj++ {
- if ii >= 0 && jj >= 0 && ii < w.size && jj < w.size &&
- w.board[ii][jj] == 0 {
-
- p := [2]int{ii, jj}
- _, ok := m[p]
- if !ok {
- // 在棋子周围2格范围的空位加到结果中
- // 超过2格的空位落子的意义不大
- res = append(res, p)
- m[p] = struct{}{}
- }
- }
- }
- }
- }
- }
- return
- }
-
- func (w *WuZiQi) ActionPolicy(action []any) any {
- // 目前随机选一个动作,应该是好方案先选出来
- return action[rand.Intn(len(action))]
- }
-
- func (w *WuZiQi) Action(action any) TreeState {
- wn := &WuZiQi{
- size: w.size,
- board: make([][]int, w.size),
- player: 3 - w.player, // 切换电脑和玩家
- }
- for i := 0; i < w.size; i++ {
- wn.board[i] = make([]int, w.size)
- for j := 0; j < w.size; j++ {
- wn.board[i][j] = w.board[i][j]
- }
- }
-
- ac := action.([2]int) // 在该位置落子
- wn.board[ac[0]][ac[1]] = w.player
- return wn
- }
-
// The declarations below are the generic Monte Carlo tree search part of the
// file. Any type that implements TreeState can be searched directly.
//
// https://github.com/int8/monte-carlo-tree-search
// https://blog.csdn.net/masterhero666/article/details/126325506

// TreeState is what the search needs to know about a game position.
type TreeState interface {
	// IsTerminal returns 0 while the game is running; any other value is an
	// implementation-defined final status.
	IsTerminal() int

	// Result turns a status returned by IsTerminal into a score.
	Result(int) float64

	// GetActions returns every legal action in this state.
	// TODO: decide how an empty result should be handled.
	GetActions() []any

	// ActionPolicy picks one action out of the given candidates.
	ActionPolicy([]any) any

	// Action applies an action and returns the resulting child state.
	Action(any) TreeState
}

// McTreeNode is one node of the search tree.
type McTreeNode struct {
	parent         *McTreeNode   // nil for the root
	children       []*McTreeNode // nodes created so far by expanding this one
	score          float64       // sum of the results propagated through this node
	visitCount     float64       // number of propagations through this node
	untriedActions []any         // actions not yet turned into children
	nodeState      TreeState     // game state represented by this node
}
-
- func Search(simulate any, state TreeState, discount ...float64) TreeState {
- var (
- root = &McTreeNode{nodeState: state}
- leaf *McTreeNode
- dp = 1.4 // 折扣参数默认值
- )
- if len(discount) > 0 {
- dp = discount[0]
- }
-
- var loop func() bool
- switch s := simulate.(type) {
- case int:
- loop = func() bool {
- s-- // 模拟指定次数后退出
- return s >= 0
- }
- case time.Duration:
- ts := time.Now().Add(s) // 超过指定时间后退出
- loop = func() bool { return time.Now().Before(ts) }
- case func() bool:
- loop = s // 或者由外部指定模拟结束方案
- default:
- panic(simulate)
- }
-
- for loop() {
- leaf = root.treePolicy(dp)
-
- result, curState := 0, leaf.nodeState
- for {
- if result = curState.IsTerminal(); result != 0 {
- break // 结束状态
- }
-
- // 根据该节点状态生成所有合法动作
- all := curState.GetActions()
- // 按照某种策略选出1个动作,不同于expand的顺序取出
- one := curState.ActionPolicy(all)
- // 执行该动作,重复该过程,直到结束
- curState = curState.Action(one)
- }
-
- // 根据结束状态计算结果,将该结果反向传播
- leaf.backPropagate(curState.Result(result))
- }
- return root.chooseBestChild(dp).nodeState // 选择最优子节点
- }
-
- func (cur *McTreeNode) chooseBestChild(c float64) *McTreeNode {
- var (
- idx = 0
- maxValue = -math.MaxFloat64
- childValue float64
- )
- for i, child := range cur.children {
- childValue = (child.score / child.visitCount) +
- c*math.Sqrt(math.Log(cur.visitCount)/child.visitCount)
- if childValue > maxValue {
- maxValue = childValue
- idx = i // 选择分值最高的子节点
- }
- }
- return cur.children[idx]
- }
-
- func (cur *McTreeNode) backPropagate(result float64) {
- nodeCursor := cur
- for nodeCursor.parent != nil {
- nodeCursor.score += result
- nodeCursor.visitCount++ // 反向传播,增加访问次数,更新分数
- nodeCursor = nodeCursor.parent
- }
- nodeCursor.visitCount++
- }
-
- func (cur *McTreeNode) expand() *McTreeNode {
- res := cur.untriedActions[0] // 返回1个未经尝试动作
- cur.untriedActions = cur.untriedActions[1:]
-
- child := &McTreeNode{
- parent: cur, // 当前节点按顺序弹出1个动作,执行动作生成子节点
- nodeState: cur.nodeState.Action(res),
- }
- cur.children = append(cur.children, child)
- return child
- }
-
- func (cur *McTreeNode) treePolicy(discountParamC float64) *McTreeNode {
- nodeCursor := cur // 一直循环直到结束
- for nodeCursor.nodeState.IsTerminal() == 0 {
- if nodeCursor.untriedActions == nil {
- // 只会初始化1次,找出该节点所有动作
- nodeCursor.untriedActions = nodeCursor.nodeState.GetActions()
- }
- if len(nodeCursor.untriedActions) > 0 {
- return nodeCursor.expand() // 存在未处理动作则添加子节点
- }
- // 处理完动作,选择最好子节点继续往下处理
- nodeCursor = nodeCursor.chooseBestChild(discountParamC)
- }
- return nodeCursor
- }