当前位置:   article > 正文

python实现A*算法解决N数码问题

a*算法解决n数码问题

相关技术

  • A算法是BFS的一个变种,它把原来的BFS算法的无启发式的搜索改成了启发式的搜索,可以有效的减低节点的搜索个数。A算法和BFS十分类似,两者的主要区别在于BFS的候选队列是盲目的,而A*算法也使用了类似于BFS的候选队列,但是在选择的时候,是先选择出候选队列中代价最小的优先搜索,这个候选队列一般使用堆来表示。
  • 启发式搜索(Heuristically Search)又称为有信息搜索(Informed Search),它是利用问题拥有的启发信息来引导搜索,达到减少搜索范围、降低问题复杂度的目的,这种利用启发信息的搜索过程称为启发式搜索。

实现原理

    整个程序实现的原理是基于A算法(A-star algorithm)来寻找一个网格(或图)上从初始状态到目标状态的最短路径。

程序实现的原理可以分为以下几个步骤,下述步骤根据下面的代码进行阐述

  1. 定义问题空间和状态:
    • 程序首先定义了问题空间,即一个N*N的网格或图。
    • 每个网格位置可以有一个状态或编号,表示不同的属性或状态。
    • 初始状态和目标状态分别由BLOCK和GOAL表示。
  2. 初始化搜索算法:
    • 程序设置了一个OPEN列表,用于存储待扩展的节点(即待访问的网格位置)。
    • 初始节点(即BLOCK表示的初始状态)被添加到OPEN列表中。
    • 可能还设置了一个BLOCK列表或集合,用于记录已经访问过的节点,以避免重复访问。
  3. A*搜索过程:
    • 在每次迭代中,程序从OPEN列表中选择一个代价最小的节点进行扩展。
    • 扩展节点意味着生成该节点的所有可能后继节点,并计算它们的预估代价(通常使用启发式函数,如曼哈顿距离)。
    • 后继节点被添加到OPEN列表中,并根据它们的总代价(已知代价+预估代价)进行排序。
    • 如果后继节点中包含了目标节点(即GOAL状态),则搜索结束,返回找到的最短路径。
  4. 路径重构:
    • 当找到目标节点后,程序通过回溯父节点来重构从初始节点到目标节点的最短路径。
    • 这通常是通过在搜索过程中维护每个节点的父节点引用来实现的。
  5. 性能优化和资源管理:
    • 为了提高搜索效率,程序可能使用了一些优化技术,如剪枝(避免访问明显不可能到达目标状态的节点)或限制搜索深度。
    • 同时,程序也需要注意资源管理,如及时关闭打开的文件、避免内存泄漏等。

实现步骤

  1. 变量定义
  2. 定义状态节点
  3. 定义曼哈顿距离计算函数
  4. 生成子节点函数
  5. 定义输出路径函数
  6. 定义A*算法
  7. 读取数据作为原始状态
  8. 查看结果

流程梳理

[开始]  
  ↓  
[读取文件]  
  |  
  | 如果文件不存在或无法打开  
  |   ↓  
  | [打印错误信息]  
  |   ↓  
  | [退出程序]  
  |  
  ↓  
[解析文件内容]  
  |  
  | 解析第一行得到 NUMBER  
  | 解析第二行得到初始状态 BLOCK  
  |  
  ↓  
[生成目标状态 GOAL]  
  |  
  ↓  
[调用 A_start 函数]  
  |   输入: BLOCK, GOAL, manhattan_dis, generate_child, 时间限制  
  |  
  ↓  
[A* 算法搜索]  
  |   初始化 OPEN 列表,包含初始节点  
  |   初始化 BLOCK 列表或集合(可选)  
  |   循环直到 OPEN 列表为空或找到目标状态  
  |     |   选择 OPEN 列表中代价最小的节点  
  |     |   扩展节点,生成后继节点  
  |     |   计算后继节点的预估代价和总代价  
  |     |   将后继节点添加到 OPEN 列表中,并根据总代价排序  
  |     |   如果后继节点中包含目标状态,结束循环  
  |   重构最短路径(通过回溯父节点)  
  |  
  ↓  
[输出结果]  
  |   打印路径长度  
  |   打印算法运行时间  
  |   打印搜索过程中访问的节点数量(如果可用)  
  |  
  ↓  
[资源管理]  
  |   关闭打开的文件(如果尚未关闭)  
  |   释放其他资源(如内存)  
  |  
  ↓  
[结束]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

注意

    本实验的相关数据节点可以点击此处去下载。

代码实现

# 导入heapq模块,用于实现优先队列(最小堆)  
import heapq  
# 导入copy模块,用于对象的深拷贝  
import copy  
# 导入re模块,用于正则表达式操作  
import re  
# 导入datetime模块,用于日期和时间操作
import datetime
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
# 初始化一个空列表,用于存储BLOCK数据  
BLOCK = []  
# 初始化一个空列表,用于存储GOAL数据 
GOAL = []  
# 定义四个方向的移动向量:[上, 右, 下, 左]  
direction = [[0, 1], [0, -1], [1, 0], [-1, 0]]  
# 初始化一个空列表,用于存储OPEN列表(即待搜索的节点列表)  
OPEN = []  
# 初始化一个变量,用于记录已访问的节点数量  
SUM_NODE_NUM = 0  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
# 定义状态节点类  
class State(object):  
    # 初始化方法  
    def __init__(self, gn=0, hn=0, state=None, hash_value=None, par=None):  
        # 初始化gn(从起点到当前状态的代价)  
        self.gn = gn  
        # 初始化hn(从当前状态到目标的启发式代价)  
        self.hn = hn  
        # 计算fn(总代价,gn和hn的和)  
        self.fn = self.gn + self.hn  
        # 初始化子节点列表  
        self.child = []  
        # 初始化父节点  
        self.par = par  
        # 初始化状态(可能是二维数组或类似结构)  
        self.state = state  
        # 初始化状态的哈希值(用于快速比较状态是否相同)  
        self.hash_value = hash_value
        
    # 定义小于比较方法,用于优先队列  
    def __lt__(self, other):  
        return self.fn < other.fn
    
    # 定义等于比较方法  
    def __eq__(self, other):  
        return self.hash_value == other.hash_value
    
    # 定义不等于比较方法  
    def __ne__(self, other):  
        return not self.__eq__(other)  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
# 定义曼哈顿距离计算函数  
def manhattan_dis(cur_node, end_node):  
    # 获取当前节点的状态  
    cur_state = cur_node.state  
    # 获取目标节点的状态  
    end_state = end_node.state  
    # 初始化距离变量  
    dist = 0  
    # 获取状态的维度大小  
    N = len(cur_state)  
    # 遍历状态的每一个位置  
    for i in range(N):  
        for j in range(N):  
            # 如果当前位置的状态与目标位置的状态相同,则跳过  
            if cur_state[i][j] == end_state[i][j]:  
                continue  
            # 获取当前位置的数字(如果是0,表示空位)  
            num = cur_state[i][j]  
            # 如果是空位,则设定空位的位置为数组的最大坐标  
            if num == 0:  
                x = N - 1  
                y = N - 1  
            else:  
                # 计算数字在数组中的实际位置(x为行,y为列)  
                x = num // N  
                y = num - N * x - 1  
            # 计算曼哈顿距离并累加到总距离中  
            dist += (abs(x - i) + abs(y - j))  
    # 返回计算得到的曼哈顿距离  
    return dist
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
# 生成子节点函数  
def generate_child(cur_node, end_node, hash_set, open_table, dis_fn):  
    # 如果当前节点已经等于目标节点,则将其加入到OPEN列表中并返回  
    if cur_node == end_node:  
        heapq.heappush(open_table, end_node)  
        return  
      
    # 获取状态的大小(假设为N*N的二维数组)  
    num = len(cur_node.state)
    # 遍历当前状态的所有位置  
    for i in range(0, num):  
        for j in range(0, num):  
            # 如果当前位置不是空位(0表示空位)  
            if cur_node.state[i][j] != 0:  
                continue
            # 遍历四个可能的移动方向  
            for d in direction:  
                # 计算新的位置  
                x = i + d[0]  
                y = j + d[1]
                # 如果新位置越界,则跳过此次移动  
                if x < 0 or x >= num or y < 0 or y >= num:  
                    continue
                # 增加已访问节点数量  
                global SUM_NODE_NUM  
                SUM_NODE_NUM += 1  
                # 复制当前状态  
                state = copy.deepcopy(cur_node.state)
                # 交换空位和移动方向上的数字  
                state[i][j], state[x][y] = state[x][y], state[i][j]
                # 计算新状态的哈希值  
                h = hash(str(state))
                # 如果新状态的哈希值已经在哈希集合中,则跳过此次生成的子节点  
                if h in hash_set:  
                    continue
                # 将新状态的哈希值加入到哈希集合中  
                hash_set.add(h)
                # 计算新的gn值(从起点到当前状态的代价)  
                gn = cur_node.gn + 1
                # 计算新的hn值(启发式代价,由传入的dis_fn函数决定)  
                hn = dis_fn(cur_node, end_node)
                # 创建新的状态节点  
                node = State(gn, hn, state, h, cur_node)
                # 将新节点添加到当前节点的子节点列表中  
                cur_node.child.append(node)
                # 将新节点加入到OPEN列表中  
                heapq.heappush(open_table, node)  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
# 定义输出路径函数  
def print_path(node):  
    # 初始化步数  
    num = node.gn
    # 定义内部函数用于展示二维数组的状态  
    def show_block(block):  
        print("+-------------------+")  
        for b in block:  
            print(b)
    # 初始化栈,用于逆序存储路径上的状态  
    stack = []
    # 从当前节点开始,沿着父节点回溯,直到根节点  
    while node.par is not None:  
        stack.append(node.state)  
        node = node.par
    # 将根节点的状态也加入到栈中  
    stack.append(node.state)
    # 逆序弹出栈中的状态并展示  
    while len(stack) != 0:  
        t = stack.pop()  
        show_block(t) 
    # 返回步数  
    return num
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
# 定义A*算法  
def A_start(start, end, distance_fn, generate_child_fn, time_limit=10):  
    # 初始化根节点,起始代价为0,启发式代价为0,状态为起始状态,哈希值为起始状态的哈希值,父节点为None  
    root = State(0, 0, start, hash(str(BLOCK)), None)  
    # 初始化目标节点,代价和父节点与根节点类似,状态为目标状态,哈希值为目标状态的哈希值  
    end_state = State(0, 0, end, hash(str(GOAL)), None)  
    # 如果根节点就是目标节点,则打印提示信息  
    if root == end_state:  
        print("start == end !")  
    # 将根节点添加到OPEN列表中  
    OPEN.append(root)  
    # 将OPEN列表转换为堆结构,以便可以高效地取出代价最小的节点  
    heapq.heapify(OPEN)  
    # 初始化哈希集合,用于存储已访问节点的哈希值  
    node_hash_set = set()  
    # 将根节点的哈希值添加到哈希集合中  
    node_hash_set.add(root.hash_value)  
    # 记录算法开始运行的时间  
    start_time = datetime.datetime.now()
    # 当OPEN列表不为空时,继续循环  
    while len(OPEN) != 0:  
        # 取出OPEN列表中代价最小的节点  
        top = heapq.heappop(OPEN)  
        # 如果取出的节点是目标节点,则打印路径并返回步数  
        if top == end_state:  
            return print_path(top)  
        # 生成当前节点的所有子节点  
        generate_child_fn(cur_node=top, end_node=end_state, hash_set=node_hash_set,  
                         open_table=OPEN, dis_fn=distance_fn)  
        # 记录当前时间  
        cur_time = datetime.datetime.now()  
        # 如果算法运行时间超过了设定的时间限制,则打印提示信息、节点数量和步数,并返回-1表示超时  
        if (cur_time - start_time).seconds > time_limit:  
            print("Time running out, break !")  
            print(f"Number of nodes: {SUM_NODE_NUM}")  
            return -1  
    # 如果循环结束仍未找到路径,则打印提示信息并返回-1表示无解  
    print("No road !")  
    return -1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
# 读取数据作为原始状态  
def read_block(block, line, N):  
    # 定义一个正则表达式,用于匹配一行中的数字  
    pattern = re.compile(r'\d+')  
    # 使用正则表达式查找所有匹配的数字,并返回数字列表  
    res = re.findall(pattern, line)  
    t = 0  
    # 初始化临时列表,用于暂存每一行的数字  
    tmp = []  
    # 遍历找到的数字  
    for i in res:  
        # 计数器加1  
        t += 1  
        # 将数字添加到临时列表中  
        tmp.append(int(i))  
        # 如果计数器等于N,表示一行数字已经收集完毕  
        if t == N:  
            # 将临时列表添加到block中,并清空临时列表以便收集下一行数字  
            t = 0  
            block.append(tmp)  
            tmp = []
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
# 运行并输出结果
if __name__ == '__main__':
    try:
        # 尝试打开文件infile.txt进行读取
        file = open('./infile.txt', 'r')
    except IOError:
        # 如果文件打开失败,则打印错误信息并退出程序
        print('can not open file infile.txt !')
        exit(1)
    # 再次打开文件infile.txt
    f = open('./infile.txt')
    # 读取第一行,并取倒数第二个字符(应该是数字字符)转换为整数,作为后续处理中的NUMBER
    NUMBER = int(f.readline()[-2])
    # 初始化计数器n为1,用于生成目标状态GOAL
    n = 1
    # 根据NUMBER生成目标状态GOAL,是一个NUMBER*NUMBER的二维列表
    for i in range(NUMBER):
        l = []
        for j in range(NUMBER):
            l.append(n)
            n += 1
        GOAL.append(l)
    # 将目标状态GOAL的最后一个元素设置为0,表示终点位置
    GOAL[NUMBER - 1][NUMBER - 1] = 0
    # 逐行读取文件infile.txt的内容
    for line in f:
        # 初始化OPEN列表、BLOCK列表和SUM_NODE_NUM
        OPEN = []
        BLOCK = []
        # 调用read_block函数解析当前行line,生成起始状态BLOCK
        read_block(BLOCK, line, NUMBER)
        SUM_NODE_NUM = 0
        # 记录开始时间
        start_t = datetime.datetime.now()
        # 调用A_start函数进行A*算法搜索,返回路径长度length
        length = A_start(BLOCK, GOAL, manhattan_dis, generate_child, time_limit=10)
        # 记录结束时间
        end_t = datetime.datetime.now()
        # 如果找到了路径就打印路径长度、运行时间和节点数量
        if length != -1:
            print("+-------------------+")
            print(f"length = {length}")
            print(f"time = {(end_t - start_t).total_seconds()}s")
            print(f"Nodes = {SUM_NODE_NUM}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

运行结果

+-------------------+
[1, 6, 3]
[4, 5, 2]
[8, 7, 0]
+-------------------+
[1, 6, 3]
[4, 5, 0]
[8, 7, 2]
+-------------------+
[1, 6, 3]
[4, 0, 5]
[8, 7, 2]
+-------------------+
[1, 0, 3]
[4, 6, 5]
[8, 7, 2]
+-------------------+
[0, 1, 3]
[4, 6, 5]
[8, 7, 2]
+-------------------+
[4, 1, 3]
[0, 6, 5]
[8, 7, 2]
+-------------------+
[4, 1, 3]
[8, 6, 5]
[0, 7, 2]
+-------------------+
[4, 1, 3]
[8, 6, 5]
[7, 0, 2]
+-------------------+
[4, 1, 3]
[8, 0, 5]
[7, 6, 2]
+-------------------+
[4, 1, 3]
[0, 8, 5]
[7, 6, 2]
+-------------------+
[0, 1, 3]
[4, 8, 5]
[7, 6, 2]
+-------------------+
[1, 0, 3]
[4, 8, 5]
[7, 6, 2]
+-------------------+
[1, 3, 0]
[4, 8, 5]
[7, 6, 2]
+-------------------+
[1, 3, 5]
[4, 8, 0]
[7, 6, 2]
+-------------------+
[1, 3, 5]
[4, 8, 2]
[7, 6, 0]
+-------------------+
[1, 3, 5]
[4, 8, 2]
[7, 0, 6]
+-------------------+
[1, 3, 5]
[4, 0, 2]
[7, 8, 6]
+-------------------+
[1, 3, 5]
[4, 2, 0]
[7, 8, 6]
+-------------------+
[1, 3, 0]
[4, 2, 5]
[7, 8, 6]
+-------------------+
[1, 0, 3]
[4, 2, 5]
[7, 8, 6]
+-------------------+
[1, 2, 3]
[4, 0, 5]
[7, 8, 6]
+-------------------+
[1, 2, 3]
[4, 5, 0]
[7, 8, 6]
+-------------------+
[1, 2, 3]
[4, 5, 6]
[7, 8, 0]
+-------------------+
length = 22
time = 0.160433s
Nodes = 10827
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96

不足的地方

    在提供的代码片段中,虽然缺少了一些关键的实现细节(如A_start函数的内部逻辑、manhattan_dis和generate_child函数的定义等),但整体上可以看出程序是按照A算法的原理来设计和实现的。通过读取输入文件来初始化搜索问题,然后执行A搜索算法来找到最短路径,并最后输出搜索结果。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/561070
推荐阅读
相关标签
  

闽ICP备14008679号