赞
踩
通过定义EightPuzzle类来实现八数码问题中的以下关键问题:1.两种估价函数计算方法(heuristic1和heuristic2);2.A*搜索算法函数;3.子节点的创建;4.路径的构建。
通过定义Node类来定义八数码中出现的节点数据类型Node,Node存放当前状态、父节点、当前节点的g值和f值(节点已有代价g(n)和总计代价f(n),f(n) = g(n) + h(n))等,同时在Node类中定义方法it来辅助判断该Node在优先队列中的位置
对于A*算法的实现,通过定义优先队列open[]表,将节点Node中的代价估计值(f)最小的节点放在表头,越小的节点放的越靠前。对于open表每个节点,每次弹出第一个节点(即f值最小的节点)并将其与最终结果(goal元组)进行比较,比较结果为True则找到结果并计算其路径(extract_path);若比较结果为False则将该节点放入closed表中,记录为已经扩展过此状态,并将该节点放入neighbours()函数中得到不同的子节点。对于每个子节点,都将其与closed表中状态进行比较,若不在closed表中,则按计算子节点f值大小并升序按插入open[]优先队列中,等待被扩展;若在closed表中,则视为重复扩展状态,不加入open表。
本程序使用两种不同的估价函数:曼哈顿距离和海明距离,分别计算两个点在标准坐标系上的绝对坐标轴距离和不在位的数的个数。当使用BFS搜索时,令启发函数的值始终为0即可使优先队列退化为普通先进先出队列,此时搜索算法扩展节点的顺序即为层次顺序(树形结构层次顺序)
- import heapq
- import time
-
- class EightPuzzle:
- def __init__(self, init_state, goal_state):
- self.init_state = init_state
- self.goal_state = goal_state
- self.actions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
-
- def heuristic1(self, node): # 计算给定节点与目标状态之间的汉明距离,即计算状态中每个元素与目标状态相应位置元素不一致的数量
- misplaced = 0
- for i in range(9):
- if node.state[i] != self.goal_state[i]:
- misplaced += 1
- return misplaced
-
- def heuristic2(self, node): # 计算给定节点与目标状态之间的曼哈顿距离,即计算状态中每个非零元素到其在目标状态中对应位置的水平和垂直距离之和
- dist = 0
- for i in range(9):
- x1, y1 = i // 3, i % 3
- if node.state[i] == 0:
- continue
- x2, y2 = (node.state[i] - 1) // 3, (node.state[i] - 1) % 3
- dist += abs(x1 - x2) + abs(y1 - y2)
- return dist
-
- def neighbors(self, node): # 遍历给定节点的所有可能移动(根据self.actions),生成新的状态作为邻居节点。新状态通过交换当前位置(空格所在位置)与相邻位置的元素得到
- state = node.state
- x, y = state.index(0) // 3, state.index(0) % 3
- for dx, dy in self.actions:
- new_x, new_y = x + dx, y + dy
- if 0 <= new_x <= 2 and 0 <= new_y <= 2:
- new_state = list(state)
- new_state[x * 3 + y], new_state[new_x * 3 + new_y] = state[new_x * 3 + new_y], state[x * 3 + y] #
- # 交换当前位置与相邻位置的元素
- yield Node(tuple(new_state), node) # 将新状态作为邻居节点
-
- def astar_search(self, heuristic): # A*搜索算法,给出初始状态和目标状态,返回路径
- start = Node(self.init_state, None)
- end = Node(self.goal_state, None)
- openlist = [start] # open 表
- closedlist = set() # closed 表
- steps = 0 # 扩展节点数
- nodes = 1 # 生成节点数
-
- while openlist:
- node = heapq.heappop(openlist) # 使用优先队列,从open表弹出一个节点,该节点是f值最小的节点
- if node.state == end.state: # 判断是否找到目标状态
- print('States Expanded:', steps)
- print('Nodes Generated:', nodes)
- return self.extract_path(node) # 返回路径
-
- closedlist.add(node.state) # 将当前节点加入closed表
- steps += 1
- for child in self.neighbors(node):
- if child.state not in closedlist: # 子节点不在closed表内,将其加入open表
- nodes += 1 # 生成节点数加1
- cost = node.g + 1 # 计算当前子节点的g值,即父节点的g值加1
- heuristic_cost = heuristic(child) # 计算当前节点的h值,即当前节点与目标状态之间的汉明距离或曼哈顿距离
- child.g = cost
- child.f = child.g + heuristic_cost
- child.parent = node
- heapq.heappush(openlist, child) # 将子节点加入open表
- return None
-
- def extract_path(self, end): # 提取路径
- path = []
- node = end # 从end节点开始,不断向前回溯,直到到达根节点
- while node.parent is not None: # 将节点加入路径列表
- path.append(node.state)
- node = node.parent
- path.reverse() # 将路径列表反转,以便从根节点到目标节点的顺序
- return path
-
-
- class Node:
- def __init__(self, state, parent):
- self.state = state
- self.parent = parent
- self.g = 0
- self.f = 0
-
- def __lt__(self, other): # __lt__方法,使Node实例可以被正确地放入heapq实现的优先队列中
- return self.f < other.f # 比较两个节点的f值,以确定优先级
-
-
- initial = (7, 3, 0, 1, 6, 4, 2, 8, 5)
- goal = (1, 2, 3, 8, 0, 4, 7, 6, 5)
- puzzle = EightPuzzle(initial, goal)
-
- print("\nUsing Heuristic 1")
- start_time = time.time()
- path = puzzle.astar_search(puzzle.heuristic1)
- end_time = time.time()
- for i in path:
- print("Path:", i)
- print("Time: {:.6f} seconds".format(end_time - start_time))
-
- print("\nUsing Heuristic 2")
- start_time = time.time()
- path = puzzle.astar_search(puzzle.heuristic2)
- end_time = time.time()
- for i in path:
- print("Path:", i)
- print("Time: {:.6f} seconds".format(end_time - start_time))
-
- print("\nBreadth First Search")
- start_time = time.time()
- # 按照BFS的方式,按层次依次遍历搜索空间中的节点,直到找到目标状态
- path = puzzle.astar_search(lambda node: 0) # 将heuristic函数设置为一个常数函数,即对所有节点返回相同的值(返回0)
- end_time = time.time()
- for i in path:
- print("Path:", i)
- print("Time: {:.6f} seconds".format(end_time - start_time))

运行截图:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。