当前位置:   article > 正文

十五数码难题 A*算法及深度优先算法实现_十五数码问题初始状态

十五数码问题初始状态

一、问题描述

二、算法分析

在搜索的每一步都利用估价函数 f(n)= g(n)+h(n)对 Open 表中的节点进行排序表中的节点进行排序, 找出一个最有希望的节点作为下一次扩展的节点。且满足条 件:h(n)≤h*(n)。其中 g(n) 是在状态空间中从初始状态到状态 n 的实际代价, h(n) 是从状态 n 到目标状态的最佳路径的估计代价。

算法过程如下:

读入初始状态和目标状态,并计算初始状态评价函数值 f;

初始化两个 open 表和 closed 表,将初始状态放入 open 表中

如果 open 表为空,则查找失败;

否则:

①在 open 表中找到评价值最小的节点,作为当前结点,并放入 closed 表中;

② 判断当前结点状态和目标状态是否一致,若一致,跳出循环;否则跳转到③;

③ 对当前结点,分别按照上、下、左、右方向移动空格位置来扩展新的状态结 点,并计算新扩展结点的评价值 f 并记录其父节点;

④ 对于新扩展的状态结点,进行如下操作: A.新节点既不在 open 表中,也不在 closed 表中,则添加进 OPEN 表; B.新节点在 open 表中,则计算评价函数的值,取最小的。 C.新节点在 closed 表中,则计算评价函数的值,取最小的。

⑤ 把当前结点从 open 表中移除;

 三.深度优先算法:

(1)从图中某顶点 v 出发,访问顶点 v;

(2)依次从 v 的未被访问的邻接点出发,对图进行深度优先遍历;直至图中 和 v 有路径相通的顶点都被访问;

(3)若此时图中尚有顶点未被访问,则从一个未被访问的顶点出发,重新进行 深度优先遍历,直到图中所有顶点均被访问过为止。

A*算法

  1. #-*-coding:utf-8-*-
  2. import heapq
  3. import copy
  4. import time
  5. import math
  6. import argparse
  7. # 初始状态
  8. # S0 = [[11, 9, 4, 15],
  9. # [1, 3, 0, 12],
  10. # [7, 5, 8, 6],
  11. # [13, 2, 10, 14]]
  12. S0 = [[5, 1, 2, 4],
  13. [9, 6, 3, 8],
  14. [13, 15, 10, 11],
  15. [0, 14, 7, 12]]
  16. # 目标状态
  17. SG = [[1, 2, 3, 4],
  18. [5, 6, 7, 8],
  19. [9, 10, 11, 12],
  20. [13, 14, 15, 0]]
  21. # 上下左右四个方向移动
  22. MOVE = {'up': [1, 0],
  23. 'down': [-1, 0],
  24. 'left': [0, -1],
  25. 'right': [0, 1]}
  26. # OPEN
  27. OPEN = []
  28. # 节点的总数
  29. SUM_NODE_NUM = 0
  30. # 状态节点
  31. class State(object):
  32. def __init__(self, deepth=0, rest_dis=0.0, state=None, hash_value=None, father_node=None):
  33. '''
  34. 初始化
  35. :参数 deepth: 从初始节点到目前节点所经过的步数
  36. :参数 rest_dis: 启发距离
  37. :参数 state: 节点存储的状态 4*4的列表
  38. :参数 hash_value: 哈希值,用于判重
  39. :参数 father_node: 父节点指针
  40. '''
  41. self.deepth = deepth
  42. self.rest_dis = rest_dis
  43. self.fn = self.deepth + self.rest_dis
  44. self.child = [] # 孩子节点
  45. self.father_node = father_node # 父节点
  46. self.state = state # 局面状态
  47. self.hash_value = hash_value # 哈希值
  48. def __lt__(self, other): # 用于堆的比较,返回距离最小的
  49. return self.fn < other.fn
  50. def __eq__(self, other): # 相等的判断
  51. return self.hash_value == other.hash_value
  52. def __ne__(self, other): # 不等的判断
  53. return not self.__eq__(other)
  54. def cal_M_distence(cur_state):
  55. '''
  56. 计算曼哈顿距离
  57. :参数 state: 当前状态,4*4的列表, State.state
  58. :返回: M_cost 每一个节点计算后的曼哈顿距离总和
  59. '''
  60. M_cost = 0
  61. for i in range(4):
  62. for j in range(4):
  63. if cur_state[i][j] == SG[i][j]:
  64. continue
  65. num = cur_state[i][j]
  66. if num == 0:
  67. x, y = 3, 3
  68. else:
  69. x = num / 4 # 理论横坐标
  70. y = num - 4 * x - 1 # 理论的纵坐标
  71. M_cost += (abs(x - i) + abs(y - j))
  72. return M_cost
  73. def cal_E_distence(cur_state):
  74. '''
  75. 计算曼哈顿距离
  76. :参数 state: 当前状态,4*4的列表, State.state
  77. :返回: M_cost 每一个节点计算后的曼哈顿距离总和
  78. '''
  79. E_cost = 0
  80. for i in range(4):
  81. for j in range(4):
  82. if cur_state[i][j] == SG[i][j]:
  83. continue
  84. num = cur_state[i][j]
  85. if num == 0:
  86. x, y = 3, 3
  87. else:
  88. x = num / 4 # 理论横坐标
  89. y = num - 4 * x - 1 # 理论的纵坐标
  90. E_cost += math.sqrt((x - i)*(x - i) + (y - j)*(y - j))
  91. return E_cost
  92. def generate_child(sn_node, sg_node, hash_set, open_table, cal_distence):
  93. '''
  94. 生成子节点函数
  95. :参数 sn_node: 当前节点
  96. :参数 sg_node: 最终状态节点
  97. :参数 hash_set: 哈希表,用于判重
  98. :参数 open_table: OPEN表
  99. :参数 cal_distence: 距离函数
  100. :返回: None
  101. '''
  102. if sn_node == sg_node:
  103. heapq.heappush(open_table, sg_node)
  104. print('已找到终止状态!')
  105. return
  106. for i in range(0, 4):
  107. for j in range(0, 4):
  108. if sn_node.state[i][j] != 0:
  109. continue
  110. for d in ['up', 'down', 'left', 'right']: # 四个偏移方向
  111. x = i + MOVE[d][0]
  112. y = j + MOVE[d][1]
  113. if x < 0 or x >= 4 or y < 0 or y >= 4: # 越界了
  114. continue
  115. state = copy.deepcopy(sn_node.state) # 复制父节点的状态
  116. state[i][j], state[x][y] = state[x][y], state[i][j] # 交换位置
  117. h = hash(str(state)) # 哈希时要先转换成字符串
  118. if h in hash_set: # 重复了
  119. continue
  120. hash_set.add(h) # 加入哈希表
  121. # 记录扩展节点的个数
  122. global SUM_NODE_NUM
  123. SUM_NODE_NUM += 1
  124. deepth = sn_node.deepth + 1 # 已经走的距离函数
  125. rest_dis = cal_distence(state) # 启发的距离函数
  126. node = State(deepth, rest_dis, state, h, sn_node) # 新建节点
  127. sn_node.child.append(node) # 加入到孩子队列
  128. heapq.heappush(open_table, node) # 加入到堆中
  129. # show_block(state, deepth) # 打印每一步的搜索过程
  130. def show_block(block, step):
  131. print("------", step, "--------")
  132. for b in block:
  133. print(b)
  134. def print_path(node):
  135. '''
  136. 输出路径
  137. :参数 node: 最终的节点
  138. :返回: None
  139. '''
  140. print("最终搜索路径为:")
  141. steps = node.deepth
  142. stack = [] # 模拟栈
  143. while node.father_node is not None:
  144. stack.append(node.state)
  145. node = node.father_node
  146. stack.append(node.state)
  147. step = 0
  148. while len(stack) != 0:
  149. t = stack.pop()
  150. show_block(t, step)
  151. step += 1
  152. return steps
  153. def A_start(start, end, distance_fn, generate_child_fn):
  154. '''
  155. A*算法
  156. :参数 start: 起始状态
  157. :参数 end: 终止状态
  158. :参数 distance_fn: 距离函数,可以使用自定义的
  159. :参数 generate_child_fn: 产生孩子节点的函数
  160. :返回: 最优路径长度
  161. '''
  162. root = State(0, 0, start, hash(str(S0)), None) # 根节点
  163. end_state = State(0, 0, end, hash(str(SG)), None) # 最后的节点
  164. if root == end_state:
  165. print("start == end !")
  166. OPEN.append(root)
  167. heapq.heapify(OPEN)
  168. node_hash_set = set() # 存储节点的哈希值
  169. node_hash_set.add(root.hash_value)
  170. while len(OPEN) != 0:
  171. top = heapq.heappop(OPEN)
  172. if top == end_state: # 结束后直接输出路径
  173. return print_path(top)
  174. # 产生孩子节点,孩子节点加入OPEN
  175. generate_child_fn(sn_node=top, sg_node=end_state, hash_set=node_hash_set,
  176. open_table=OPEN, cal_distence=distance_fn)
  177. print("无搜索路径!") # 没有路径
  178. return -1
  179. if __name__ == '__main__':
  180. # 可配置式运行文件
  181. parser = argparse.ArgumentParser(description='选择距离计算方法')
  182. parser.add_argument('--method', '-m', help='method 选择距离计算方法(cal_E_distence or cal_M_distence)', default = 'cal_M_distence')
  183. args = parser.parse_args()
  184. method = args.method
  185. time1 = time.time()
  186. if method == 'cal_E_distence':
  187. length = A_start(S0, SG, cal_E_distence, generate_child)
  188. else:
  189. length = A_start(S0, SG, cal_M_distence, generate_child)
  190. time2 = time.time()
  191. if length != -1:
  192. if method == 'cal_E_distence':
  193. print("采用欧式距离计算启发函数")
  194. else:
  195. print("采用曼哈顿距离计算启发函数")
  196. print("搜索最优路径长度为", length)
  197. print("搜索时长为", (time2 - time1), "s")
  198. print("共检测节点数为", SUM_NODE_NUM)

深度优先:

  1. #-*-coding:utf-8-*-
  2. import copy
  3. import time
  4. # 初始状态
  5. # S0 = [[11, 9, 4, 15],
  6. # [1, 3, 0, 12],
  7. # [7, 5, 8, 6],
  8. # [13, 2, 10, 14]]
  9. S0 = [[5, 1, 2, 4],
  10. [9, 6, 3, 8],
  11. [13, 15, 10, 11],
  12. [0, 14, 7, 12]]
  13. # 目标状态
  14. SG = [[1, 2, 3, 4],
  15. [5, 6, 7, 8],
  16. [9, 10, 11, 12],
  17. [13, 14, 15, 0]]
  18. # 上下左右四个方向移动
  19. MOVE = {'up': [1, 0],
  20. 'down': [-1, 0],
  21. 'left': [0, -1],
  22. 'right': [0, 1]}
  23. # OPEN
  24. OPEN = []
  25. # 节点的总数
  26. SUM_NODE_NUM = 0
  27. # 状态节点
  28. class State(object):
  29. def __init__(self, deepth=0, state=None, hash_value=None, father_node=None):
  30. '''
  31. 初始化
  32. :参数 deepth: gn是初始化到现在的距离
  33. :参数 state: 节点存储的状态
  34. :参数 hash_value: 哈希值,用于判重
  35. :参数 father_node: 父节点指针
  36. '''
  37. self.deepth = deepth
  38. self.child = [] # 孩子节点
  39. self.father_node = father_node # 父节点
  40. self.state = state # 局面状态
  41. self.hash_value = hash_value # 哈希值
  42. def __eq__(self, other): # 相等的判断
  43. return self.hash_value == other.hash_value
  44. def __ne__(self, other): # 不等的判断
  45. return not self.__eq__(other)
  46. def generate_child(sn_node, sg_node, hash_set):
  47. '''
  48. 生成子节点函数
  49. :参数 sn_node: 当前节点
  50. :参数 sg_node: 最终状态节点
  51. :参数 hash_set: 哈希表,用于判重
  52. :参数 open_table: OPEN表
  53. :返回: None
  54. '''
  55. for i in range(0, 4):
  56. for j in range(0, 4):
  57. if sn_node.state[i][j] != 0:
  58. continue
  59. for d in ['up', 'down', 'left', 'right']: # 四个偏移方向
  60. x = i + MOVE[d][0]
  61. y = j + MOVE[d][1]
  62. if x < 0 or x >= 4 or y < 0 or y >= 4: # 越界了
  63. continue
  64. state = copy.deepcopy(sn_node.state) # 复制父节点的状态
  65. state[i][j], state[x][y] = state[x][y], state[i][j] # 交换位置
  66. h = hash(str(state)) # 哈希时要先转换成字符串
  67. if h in hash_set: # 重复了
  68. continue
  69. hash_set.add(h) # 加入哈希表
  70. # 记录扩展节点的个数
  71. global SUM_NODE_NUM
  72. SUM_NODE_NUM += 1
  73. deepth = sn_node.deepth + 1 # 已经走的距离函数
  74. node = State(deepth, state, h, sn_node) # 新建节点
  75. sn_node.child.append(node) # 加入到孩子队列
  76. OPEN.insert(0, node)
  77. # show_block(state, deepth)
  78. def show_block(block, step):
  79. print("------", step, "--------")
  80. for b in block:
  81. print(b)
  82. def print_path(node):
  83. '''
  84. 输出路径
  85. :参数 node: 最终的节点
  86. :返回: None
  87. '''
  88. print("最终搜索路径为:")
  89. steps = node.deepth
  90. stack = [] # 模拟栈
  91. while node.father_node is not None:
  92. stack.append(node.state)
  93. node = node.father_node
  94. stack.append(node.state)
  95. step = 0
  96. while len(stack) != 0:
  97. t = stack.pop()
  98. show_block(t, step)
  99. step += 1
  100. return steps
  101. def DFS_max_deepth(start, end, generate_child_fn, max_deepth):
  102. '''
  103. A*算法
  104. :参数 start: 起始状态
  105. :参数 end: 终止状态
  106. :参数 generate_child_fn: 产生孩子节点的函数
  107. :参数 max_deepth: 最深搜索深度
  108. :返回: None
  109. '''
  110. root = State(0, start, hash(str(S0)), None) # 根节点
  111. end_state = State(0, end, hash(str(SG)), None) # 最后的节点
  112. if root == end_state:
  113. print("start == end !")
  114. OPEN.append(root)
  115. node_hash_set = set() # 存储节点的哈希值
  116. node_hash_set.add(root.hash_value)
  117. while len(OPEN) != 0:
  118. top = OPEN.pop(0)
  119. if top == end_state: # 结束后直接输出路径
  120. return print_path(top)
  121. if top.deepth >= max_deepth:
  122. continue
  123. # 产生孩子节点,孩子节点加入OPEN
  124. generate_child_fn(sn_node=top, sg_node=end_state, hash_set=node_hash_set)
  125. print("设置最深深度不合适,无搜索路径!") # 没有路径
  126. return -1
  127. if __name__ == '__main__':
  128. time1 = time.time()
  129. length = DFS_max_deepth(S0, SG, generate_child, 25)
  130. time2 = time.time()
  131. if length != -1:
  132. print("搜索最优路径长度为", length)
  133. print("搜索时长为", (time2 - time1), "s")
  134. print("共检测节点数为", SUM_NODE_NUM)

广度优先

  1. #-*-coding:utf-8-*-
  2. import heapq
  3. import copy
  4. import time
  5. # 初始状态
  6. # S0 = [[11, 9, 4, 15],
  7. # [1, 3, 0, 12],
  8. # [7, 5, 8, 6],
  9. # [13, 2, 10, 14]]
  10. S0 = [[5, 1, 2, 4],
  11. [9, 6, 3, 8],
  12. [13, 15, 10, 11],
  13. [0, 14, 7, 12]]
  14. # 目标状态
  15. SG = [[1, 2, 3, 4],
  16. [5, 6, 7, 8],
  17. [9, 10, 11, 12],
  18. [13, 14, 15, 0]]
  19. # 上下左右四个方向移动
  20. MOVE = {'up': [1, 0],
  21. 'down': [-1, 0],
  22. 'left': [0, -1],
  23. 'right': [0, 1]}
  24. # OPEN表
  25. OPEN = []
  26. # 节点的总数
  27. SUM_NODE_NUM = 0
  28. # 状态节点
  29. class State(object):
  30. def __init__(self, deepth=0, state=None, hash_value=None, father_node=None):
  31. '''
  32. 初始化
  33. :参数 deepth: 从初始节点到目前节点所经过的步数
  34. :参数 state: 节点存储的状态 4*4的列表
  35. :参数 hash_value: 哈希值,用于判重
  36. :参数 father_node: 父节点指针
  37. '''
  38. self.deepth = deepth
  39. self.child = [] # 孩子节点
  40. self.father_node = father_node # 父节点
  41. self.state = state # 局面状态
  42. self.hash_value = hash_value # 哈希值
  43. def __lt__(self, other): # 用于堆的比较,返回距离最小的
  44. return self.deepth < other.deepth
  45. def __eq__(self, other): # 相等的判断
  46. return self.hash_value == other.hash_value
  47. def __ne__(self, other): # 不等的判断
  48. return not self.__eq__(other)
  49. def generate_child(sn_node, sg_node, hash_set, open_table):
  50. '''
  51. 生成子节点函数
  52. :参数 sn_node: 当前节点
  53. :参数 sg_node: 最终状态节点
  54. :参数 hash_set: 哈希表,用于判重
  55. :参数 open_table: OPEN表
  56. :返回: None
  57. '''
  58. if sn_node == sg_node:
  59. heapq.heappush(open_table, sg_node)
  60. print('已找到终止状态!')
  61. return
  62. for i in range(0, 4):
  63. for j in range(0, 4):
  64. if sn_node.state[i][j] != 0:
  65. continue
  66. for d in ['up', 'down', 'left', 'right']: # 四个偏移方向
  67. x = i + MOVE[d][0]
  68. y = j + MOVE[d][1]
  69. if x < 0 or x >= 4 or y < 0 or y >= 4: # 越界了
  70. continue
  71. state = copy.deepcopy(sn_node.state) # 复制父节点的状态
  72. state[i][j], state[x][y] = state[x][y], state[i][j] # 交换位置
  73. h = hash(str(state)) # 哈希时要先转换成字符串
  74. if h in hash_set: # 重复了
  75. continue
  76. hash_set.add(h) # 加入哈希表
  77. # 记录扩展节点的个数
  78. global SUM_NODE_NUM
  79. SUM_NODE_NUM += 1
  80. deepth = sn_node.deepth + 1 # 已经走的距离函数
  81. node = State(deepth, state, h, sn_node) # 新建节点
  82. sn_node.child.append(node) # 加入到孩子队列
  83. heapq.heappush(open_table, node) # 加入到堆中
  84. # show_block(state, deepth) # 打印每一步的搜索过程
  85. def show_block(block, step):
  86. print("------", step, "--------")
  87. for b in block:
  88. print(b)
  89. def print_path(node):
  90. '''
  91. 输出路径
  92. :参数 node: 最终的节点
  93. :返回: None
  94. '''
  95. print("最终搜索路径为:")
  96. steps = node.deepth
  97. stack = [] # 模拟栈
  98. while node.father_node is not None:
  99. stack.append(node.state)
  100. node = node.father_node
  101. stack.append(node.state)
  102. step = 0
  103. while len(stack) != 0:
  104. t = stack.pop()
  105. show_block(t, step)
  106. step += 1
  107. return steps
  108. def A_start(start, end, generate_child_fn):
  109. '''
  110. A*算法
  111. :参数 start: 起始状态
  112. :参数 end: 终止状态
  113. :参数 generate_child_fn: 产生孩子节点的函数
  114. :返回: 最优路径长度
  115. '''
  116. root = State(0, start, hash(str(S0)), None) # 根节点
  117. end_state = State(0, end, hash(str(SG)), None) # 最后的节点
  118. if root == end_state:
  119. print("start == end !")
  120. OPEN.append(root)
  121. heapq.heapify(OPEN)
  122. node_hash_set = set() # 存储节点的哈希值
  123. node_hash_set.add(root.hash_value)
  124. while len(OPEN) != 0:
  125. top = heapq.heappop(OPEN)
  126. if top == end_state: # 结束后直接输出路径
  127. return print_path(top)
  128. # 产生孩子节点,孩子节点加入OPEN表
  129. generate_child_fn(sn_node=top, sg_node=end_state, hash_set=node_hash_set,
  130. open_table=OPEN)
  131. print("无搜索路径!") # 没有路径
  132. return -1
  133. if __name__ == '__main__':
  134. time1 = time.time()
  135. length = A_start(S0, SG, generate_child)
  136. time2 = time.time()
  137. if length != -1:
  138. print("搜索最优路径长度为", length)
  139. print("搜索时长为", (time2 - time1), "s")
  140. print("共检测节点数为", SUM_NODE_NUM)

四.运行截图

五.总结

通过对比分析,可以发现,A 星算法的搜索时长和检测节点数明显小于深度优先方法,可见 启发式信息对于搜索过程的重要性;另外,有界深度优先算法的算法性能差异较大,设置不 同的最深深度得到的结果有一定的差异,一般设置较大会造成内存爆炸的现象,所以通过该 方法进行搜索较为困难,对于任务较为复杂的情况,很难快速求解。

另外,广度优先算法, 针对较为简单问题,基本可以以最短路径给出答案,但同时搜索时间和搜索节点数一定会比 启发式搜索多一些,针对复杂问题,很难给出答案,每扩展一层,都会以指数的形式增加待 扩展节点的数量,很难得出答案。

综上所述,与深度优先算法相比,启发式搜索算法有很强的优越性,一般情况下要尽可能去 寻找启发函数,添加到代码中辅助进行算法的训练,尽可能缩短程序运行时间,提高程序效 率。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/889728
推荐阅读
相关标签
  

闽ICP备14008679号