当前位置:   article > 正文

A* 算法的 Python 实现_python实现a*算法

python实现a*算法

本文使用 Python 语言实现 A* 算法。
算法的流程和原理在此不再赘述。

代码文件结构:
文件结构

point.py

import sys


class Point(object):
    """
    A single grid cell handled by the path search.

    :ivar x:  horizontal grid coordinate of the cell
    :ivar y:  vertical grid coordinate of the cell
    :ivar cost:  search cost attached to the cell; starts at ``sys.maxsize``
                 as an "infinite / not reached yet" marker
    :ivar parent:  the Point this cell was reached from, or None
    """

    def __init__(self, x: int, y: int):
        self.x, self.y = x, y
        self.cost, self.parent = sys.maxsize, None

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

map.py

from typing import List, Optional, Tuple

from point import Point


class Map(object):
    """
    A rectangular grid of ``width`` x ``height`` cells, some of which are blocked.
    """

    def __init__(self, width: int, height: int, obstacles: Optional[List[Tuple[int, int]]] = None):
        """
        initialise

        :param width:  number of cells along x
        :param height:  number of cells along y
        :param obstacles:  (x, y) coordinates of blocked cells; None means no obstacles
        """
        self.width = width
        self.height = height

        # None rather than a mutable [] default, which Python would share across calls
        if obstacles is None:
            obstacles = []
        self.obstacles = [Point(x=osc[0], y=osc[1]) for osc in obstacles]

    def is_obstacle(self, i: int, j: int):
        """
        return True if the cell (i, j) is one of the obstacles, else False
        """
        return any(i == p.x and j == p.y for p in self.obstacles)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

a_star.py
包含可视化代码:搜索过程中的每一步都会保存为一张图片,最终将这些图片合成为视频,合成后会自动删除中间生成的图片。

import os
import sys
import time
from typing import Tuple, List

from matplotlib.patches import Rectangle
import cv2
import glob

from point import Point
from map import Map


class AStar(object):
    """
    A* path search on a 4-connected grid, with step-by-step visualisation.

    Every expanded cell is drawn on the supplied matplotlib axes and saved as a
    PNG frame; when the search ends the frames are merged into an AVI video and
    the frames are deleted.
    """
    def __init__(self, map: Map, origin: Tuple[int, int], target: Tuple[int, int],
                 output_dir: str = './outputs', fps: int = 5):
        """
        initialise

        :param map:  map
        :param origin:  starting point coordinates
        :param target:  ending point coordinates
        :param output_dir:  folder for the intermediate frames and the final video
        :param fps:  frame rate of the generated video
        """

        self.map = map

        self.origin = Point(x=origin[0], y=origin[1])
        self.target = Point(x=target[0], y=target[1])

        self.open_points = []
        self.close_points = []

        self.output_dir = output_dir
        self.fps = fps
        # frame files saved by this instance, in save order (merged in exactly this order)
        self._frames = []

    def _basic_cost(self, point: Point):
        """
        cost g of the best path found so far from origin to point

        Each move to a neighbouring cell costs 1; the accumulated value is kept
        in ``point.cost`` (see ``_process_point``). Using the straight-line
        Manhattan distance from the origin instead would ignore detours around
        obstacles and break A*'s shortest-path guarantee.
        """
        return point.cost

    def _heuristic_cost(self, point: Point):
        """
        estimated cost to target (Manhattan distance; never overestimates on a
        4-connected grid with unit step cost)
        """
        return abs(point.x - self.target.x) + abs(point.y - self.target.y)

    def _total_cost(self, point: Point):
        """
        total cost f = g + h
        """
        return self._basic_cost(point) + self._heuristic_cost(point)

    def _is_valid_point(self, x: int, y: int):
        """
        True if (x, y) lies inside the map and is not an obstacle
        """
        if x < 0 or y < 0:
            return False
        if x >= self.map.width or y >= self.map.height:
            return False
        if self.map.is_obstacle(x, y):
            return False

        return True

    def _find_point(self, x: int, y: int, points: List[Point]):
        """
        return the point at (x, y) in points, or None if there is none
        """
        for p in points:
            if p.x == x and p.y == y:
                return p

        return None

    def _in_point_list(self, point: Point, points: List[Point]):
        return self._find_point(point.x, point.y, points) is not None

    def _in_open_list(self, point: Point):
        return self._in_point_list(point, self.open_points)

    def _in_close_list(self, point: Point):
        return self._in_point_list(point, self.close_points)

    def run(self, ax, plt):
        """
        run algorithm and visualise

        :param ax:  matplotlib.axes._subplots.AxesSubplot
        :param plt:  matplotlib.pyplot
        :return:  list of points from origin to target, or None if no path exists
        """

        tms = time.time()

        # reset search state so the same instance can be run more than once
        self.open_points = []
        self.close_points = []
        self._frames = []

        self.origin.cost = 0
        self.origin.parent = None
        self.open_points.append(self.origin)
        while True:
            idx = self._select_from_open_list()
            if idx < 0:
                print("No path found, algorithm failed!")
                # still turn the exploration frames into a video so they are not left behind
                self._merge_video()
                return None
            point = self.open_points[idx]

            rectangle = Rectangle(xy=(point.x, point.y), width=1, height=1, color='cyan')
            ax.add_patch(rectangle)
            self._save_image(plt)

            if point.x == self.target.x and point.y == self.target.y:
                return self._build_path(point=point, tms=tms, ax=ax, plt=plt)

            del self.open_points[idx]
            self.close_points.append(point)

            # neighbours: left, down, right, up
            for dx, dy in ((-1, 0), (0, -1), (1, 0), (0, 1)):
                self._process_point(x=point.x + dx, y=point.y + dy, parent=point)

    def _save_image(self, plt):
        """
        save the current figure as the next frame in the output folder
        """

        os.makedirs(self.output_dir, exist_ok=True)
        # millisecond stamp + running index: unique even if two frames are saved
        # within the same millisecond
        millisecond = int(round(time.time() * 1000))
        file_name = os.path.join(
            self.output_dir, '{}_{:06d}.png'.format(millisecond, len(self._frames)))
        plt.savefig(file_name)
        self._frames.append(file_name)

    def _process_point(self, x: int, y: int, parent: Point):
        """
        process a neighbour of the point being expanded

        :param x:  x coordinate
        :param y:  y coordinate
        :param parent:  current point's parent point
        """

        # do nothing for invalid point
        if not self._is_valid_point(x, y):
            return

        # do nothing for visited point (its cost is already final because the
        # Manhattan heuristic is consistent on this grid)
        if self._find_point(x, y, self.close_points) is not None:
            return

        # every step to a neighbouring cell costs 1
        new_cost = parent.cost + 1
        print("process point [{}, {}],  cost: {}".format(x, y, new_cost))

        point = self._find_point(x, y, self.open_points)
        if point is None:
            point = Point(x, y)
            point.parent = parent
            point.cost = new_cost
            self.open_points.append(point)
        elif new_cost < point.cost:
            # found a cheaper route to a cell that is already waiting in the open list
            point.parent = parent
            point.cost = new_cost

    def _select_from_open_list(self) -> int:
        """
        select the point with least total cost from the open list

        :return idx_select:  the index of the selected point in the open list,
                             or -1 if the open list is empty
        """

        idx_select = -1
        min_cost = sys.maxsize
        for idx, point in enumerate(self.open_points):
            cost = self._total_cost(point)
            # strict '<' keeps the earliest point on ties
            if cost < min_cost:
                min_cost = cost
                idx_select = idx

        return idx_select

    def _build_path(self, point: Point, tms: float, ax, plt):
        """
        build the whole path after algorithm terminates

        :param point:  ending point
        :param tms:  start time
        :param ax:  matplotlib.axes._subplots.AxesSubplot
        :param plt:  matplotlib.pyplot
        :return:  list of points from origin to the ending point
        """

        # follow parent links back to the origin (whose parent is None)
        path = []
        while point is not None:
            path.append(point)
            point = point.parent
        path.reverse()

        # visualise
        for p in path:
            rec = Rectangle(xy=(p.x, p.y), width=1, height=1, color='green')
            ax.add_patch(rec)
            plt.draw()
            self._save_image(plt)

        self._merge_video()

        tme = time.time()
        print("Algorithm finishes in {} s".format(int(tme - tms)))

        return path

    def _merge_video(self):
        """
        merge the frames saved by this run into a video, then delete those frames

        Frames are taken from the list recorded by ``_save_image`` (in save order)
        rather than by globbing the folder, whose order is arbitrary and which
        may contain stale frames from other runs.

        :return:  path of the video, or None if there was nothing to merge
        """

        frames, self._frames = self._frames, []
        if not frames:
            return None

        video_path = os.path.join(self.output_dir, '{}.avi'.format(round(time.time())))
        fourcc = cv2.VideoWriter_fourcc(*'DIVX')
        video = None
        try:
            for file_name in frames:
                image = cv2.imread(filename=file_name)
                if image is None:
                    # unreadable or missing frame: skip it instead of crashing
                    continue
                if video is None:
                    # the video size is taken from the first readable frame
                    height, width = image.shape[:2]
                    video = cv2.VideoWriter(video_path, fourcc, self.fps, (width, height))
                video.write(image)
        finally:
            if video is not None:
                video.release()

        # delete the intermediate image files
        for file_name in frames:
            try:
                os.remove(file_name)
            except FileNotFoundError:
                pass

        return video_path if video is not None else None

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221

main.py(主程序)

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

from map import Map
from a_star import AStar


""" map settings """

width, height = 10, 15

# The search runs from one corner of the grid to the opposite corner.
origin, target = (0, 0), (width - 1, height - 1)
# Three vertical walls: at x = W/4 and x = 3W/4 from y = 0 up to 2/3 of the height,
# and at x = W/2 from 1/3 of the height up to the top edge, so the walls alternate.
obstacles = [(round(width * (1 / 4)), j) for j in range(round(height * (2 / 3)))] + [
    (round(width * (1 / 2)), j) for j in range(round(height * (1 / 3)), height)] + [
    (round(width * (3 / 4)), j) for j in range(round(height * (2 / 3)))]

map_ = Map(width=width, height=height, obstacles=obstacles)


""" visual settings """

plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim([0, map_.width])
ax.set_ylim([0, map_.height])

# Draw every cell: obstacles filled gray, free cells as white squares with gray edges.
for i in range(map_.width):
    for j in range(map_.height):
        if map_.is_obstacle(i, j):
            rectangle = Rectangle(xy=(i, j), width=1, height=1, color='gray')
            ax.add_patch(rectangle)
        else:
            rectangle = Rectangle(xy=(i, j), width=1, height=1, edgecolor='gray', facecolor='white')
            ax.add_patch(rectangle)

rectangle = Rectangle(xy=origin, width=1, height=1, facecolor='blue')
ax.add_patch(rectangle)
rectangle = Rectangle(xy=target, width=1, height=1, facecolor='red')
ax.add_patch(rectangle)

plt.axis('equal')  # set equal scaling
plt.axis('off')  # turn off axis lines and labels
plt.tight_layout()


""" algorithm """

# Reuse origin/target so the drawn start/end cells always match the searched ones.
a_star = AStar(map=map_, origin=origin, target=target)
a_star.run(ax, plt)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/561075
推荐阅读
相关标签
  

闽ICP备14008679号