赞
踩
本文使用 Python 语言实现 A* 算法。
算法流程和原理在此不再赘述。
代码文件结构：
point.py
import sys
class Point:
    """
    A single grid cell handled by the A* search.

    Attributes:
        x, y: integer grid coordinates of the cell.
        cost: cost value filled in later by the search; it starts at
            ``sys.maxsize`` as a "not assigned yet" sentinel.
        parent: the Point this cell was reached from, or ``None``.
    """

    def __init__(self, x: int, y: int):
        self.x, self.y = x, y
        # sentinel: no cost has been assigned to this cell yet
        self.cost = sys.maxsize
        self.parent = None
map.py
from typing import Tuple, List, Optional

from point import Point


class Map(object):
    """
    Rectangular grid map with a fixed set of blocked cells.

    ``self.obstacles`` is treated as read-only after construction: lookups go
    through a coordinate set built once in ``__init__``, so mutating the list
    afterwards is not reflected by ``is_obstacle``.
    """

    def __init__(self, width: int, height: int,
                 obstacles: Optional[List[Tuple[int, int]]] = None):
        """
        :param width: number of cells along x
        :param height: number of cells along y
        :param obstacles: ``(x, y)`` coordinates of blocked cells; ``None``
            (the default) means no obstacles
        """
        self.width = width
        self.height = height
        # None instead of a shared mutable [] default argument
        if obstacles is None:
            obstacles = []
        self.obstacles = [Point(x=osc[0], y=osc[1]) for osc in obstacles]
        # O(1) membership instead of scanning the whole list per query;
        # callers query every grid cell and every expanded neighbour
        self._obstacle_coords = {(p.x, p.y) for p in self.obstacles}

    def is_obstacle(self, i: int, j: int) -> bool:
        """
        :param i: x coordinate
        :param j: y coordinate
        :return: True if cell ``(i, j)`` is one of the obstacles
        """
        return (i, j) in self._obstacle_coords
a_star.py
包含可视化代码：搜索的每一步都会保存为一帧图片（位于 ./outputs 目录），找到路径后这些图片会被合成为视频，随后中间生成的图片会被自动删除。
import glob
import os
import sys
import time
from typing import Tuple, List, Optional

import cv2
from matplotlib.patches import Rectangle

from point import Point
from map import Map


class AStar(object):
    """
    A* shortest-path search on a 4-connected grid with unit step cost.

    Every search step is drawn onto the given matplotlib axes and saved as a
    frame; once a path is found the frames are merged into a video.
    """

    # directory where frame images and the final video are written
    OUTPUT_DIR = './outputs'

    def __init__(self, map: Map, origin: Tuple[int, int], target: Tuple[int, int]):
        """
        initialise
        :param map: map (name kept for compatibility although it shadows the builtin)
        :param origin: starting point coordinates
        :param target: ending point coordinates
        """
        self.map = map
        self.origin = Point(x=origin[0], y=origin[1])
        self.target = Point(x=target[0], y=target[1])
        self.open_points = []
        self.close_points = []
        # frame images saved by the current run, in the order they were saved
        self._frame_files = []

    def _basic_cost(self, point: Point):
        """
        g(n): cost of the best known path from the origin to ``point``.

        It is accumulated in ``point.cost`` (one per step) while neighbours
        are processed. The Manhattan distance from the origin is only a lower
        bound of g and under-estimates it whenever obstacles force a detour,
        which made the previous version non-optimal.
        """
        return point.cost

    def _heuristic_cost(self, point: Point):
        """
        h(n): Manhattan distance to the target; admissible and consistent
        for 4-connected moves with unit cost
        """
        return abs(point.x - self.target.x) + abs(point.y - self.target.y)

    def _total_cost(self, point: Point):
        """ f(n) = g(n) + h(n) """
        return self._basic_cost(point) + self._heuristic_cost(point)

    def _is_valid_point(self, x: int, y: int):
        """ True if (x, y) lies inside the map and is not an obstacle """
        if x < 0 or y < 0:
            return False
        if x >= self.map.width or y >= self.map.height:
            return False
        if self.map.is_obstacle(x, y):
            return False
        return True

    def _find_in_point_list(self, x: int, y: int, points: List[Point]) -> Optional[Point]:
        """ return the point at (x, y) in ``points``, or None if absent """
        for p in points:
            if p.x == x and p.y == y:
                return p
        return None

    def _in_point_list(self, point: Point, points: List[Point]):
        return self._find_in_point_list(point.x, point.y, points) is not None

    def _in_open_list(self, point: Point):
        return self._in_point_list(point, self.open_points)

    def _in_close_list(self, point: Point):
        return self._in_point_list(point, self.close_points)

    def run(self, ax, plt):
        """
        run algorithm and visualise
        :param ax: matplotlib.axes._subplots.AxesSubplot
        :param plt: matplotlib.pyplot
        """
        tms = time.time()
        # savefig does not create missing directories
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        self.origin.cost = 0
        self.open_points.append(self.origin)

        while True:
            idx = self._select_from_open_list()
            if idx < 0:
                print("No path found, algorithm failed!")
                return

            point = self.open_points[idx]
            rectangle = Rectangle(xy=(point.x, point.y), width=1, height=1, color='cyan')
            ax.add_patch(rectangle)
            self._save_image(plt)

            if point.x == self.target.x and point.y == self.target.y:
                return self._build_path(point=point, tms=tms, ax=ax, plt=plt)

            del self.open_points[idx]
            self.close_points.append(point)

            # 4-connected neighbours
            self._process_point(x=point.x - 1, y=point.y, parent=point)
            self._process_point(x=point.x, y=point.y - 1, parent=point)
            self._process_point(x=point.x + 1, y=point.y, parent=point)
            self._process_point(x=point.x, y=point.y + 1, parent=point)

    def _save_image(self, plt):
        """
        save the current figure as the next numbered frame in the outputs folder

        A per-run counter is used instead of the current millisecond so two
        saves within the same millisecond cannot overwrite each other.
        :param plt: matplotlib.pyplot
        """
        file_name = os.path.join(self.OUTPUT_DIR, '{:06d}.png'.format(len(self._frame_files)))
        plt.savefig(file_name)
        self._frame_files.append(file_name)

    def _process_point(self, x: int, y: int, parent: Point):
        """
        process a neighbour of ``parent``: add it to the open list, or
        re-parent it if the path through ``parent`` is cheaper
        :param x: x coordinate
        :param y: y coordinate
        :param parent: current point's parent point
        """
        # do nothing for invalid point
        if not self._is_valid_point(x, y):
            return
        # closed points already have their optimal cost (consistent heuristic)
        if self._find_in_point_list(x, y, self.close_points) is not None:
            return

        new_cost = parent.cost + 1  # each move to a neighbouring cell costs 1
        print("process point [{}, {}], cost: {}".format(x, y, new_cost))

        point = self._find_in_point_list(x, y, self.open_points)
        if point is None:
            point = Point(x, y)
            point.parent = parent
            point.cost = new_cost
            self.open_points.append(point)
        elif new_cost < point.cost:
            # a cheaper path to an already-open point was found
            point.parent = parent
            point.cost = new_cost

    def _select_from_open_list(self) -> int:
        """
        select the point with least total cost from the open list
        :return idx_select: index of the selected point in the open list,
            or -1 if the open list is empty (ties go to the earliest entry)
        """
        idx_select = -1
        min_cost = sys.maxsize
        for idx, point in enumerate(self.open_points):
            cost = self._total_cost(point)
            if cost < min_cost:
                min_cost = cost
                idx_select = idx
        return idx_select

    def _build_path(self, point: Point, tms: float, ax, plt):
        """
        build the whole path after algorithm terminates
        :param point: ending point
        :param tms: start time
        :param ax: matplotlib.axes._subplots.AxesSubplot
        :param plt: matplotlib.pyplot
        """
        # walk the parent chain back to the origin, then reverse it
        path = []
        node = point
        while node is not None:
            path.append(node)
            if node.x == self.origin.x and node.y == self.origin.y:
                break
            node = node.parent
        path.reverse()

        # visualise
        for p in path:
            rec = Rectangle(xy=(p.x, p.y), width=1, height=1, color='green')
            ax.add_patch(rec)
            plt.draw()
            self._save_image(plt)

        self._merge_video()

        tme = time.time()
        print("Algorithm finishes in {} s".format(int(tme - tms)))

    def _merge_video(self):
        """
        merge this run's frames, in save order, into an AVI video in the
        outputs folder, then delete the frame images

        Only frames recorded by ``_save_image`` are used; a directory glob
        returns files in arbitrary order and would also pick up unrelated PNGs.
        """
        frames = self._frame_files
        self._frame_files = []
        if not frames:
            return

        images = []
        for file_name in frames:
            image = cv2.imread(filename=file_name)
            if image is not None:  # imread returns None for unreadable files
                images.append(image)
        if not images:
            print("No readable frames, video not generated")
            return

        height, width = images[0].shape[:2]
        size = (width, height)

        # generate video
        tm = time.time()
        video_path = os.path.join(self.OUTPUT_DIR, f'{round(tm)}.avi')
        fourcc = cv2.VideoWriter_fourcc(*'DIVX')
        video = cv2.VideoWriter(video_path, fourcc, 5, size)
        try:
            for image in images:
                video.write(image)
        finally:
            video.release()

        # delete original image files
        for file_name in frames:
            os.remove(file_name)
main.py（主程序）
import os

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

from map import Map
from a_star import AStar

# --- map settings ---
width, height = 10, 15
origin, target = (0, 0), (width - 1, height - 1)
# three vertical walls: at x ~ width/4 and x ~ 3*width/4 covering rows
# [0, 2*height/3), and at x ~ width/2 covering rows [height/3, height)
obstacles = (
    [(round(width * (1 / 4)), j) for j in range(round(height * (2 / 3)))]
    + [(round(width * (1 / 2)), j) for j in range(round(height * (1 / 3)), height)]
    + [(round(width * (3 / 4)), j) for j in range(round(height * (2 / 3)))]
)
map_ = Map(width=width, height=height, obstacles=obstacles)

# --- visual settings ---
plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim([0, map_.width])
ax.set_ylim([0, map_.height])

# obstacles are filled gray; free cells are white with a gray border
for i in range(map_.width):
    for j in range(map_.height):
        if map_.is_obstacle(i, j):
            rectangle = Rectangle(xy=(i, j), width=1, height=1, color='gray')
        else:
            rectangle = Rectangle(xy=(i, j), width=1, height=1, edgecolor='gray', facecolor='white')
        ax.add_patch(rectangle)

rectangle = Rectangle(xy=origin, width=1, height=1, facecolor='blue')
ax.add_patch(rectangle)
rectangle = Rectangle(xy=target, width=1, height=1, facecolor='red')
ax.add_patch(rectangle)

plt.axis('equal')  # set equal scaling
plt.axis('off')  # turn off axis lines and labels
plt.tight_layout()

# --- algorithm ---
# AStar saves its frames under ./outputs; savefig will not create the folder
os.makedirs('./outputs', exist_ok=True)
# reuse origin/target so the search always matches the highlighted cells
a_star = AStar(map=map_, origin=origin, target=target)
a_star.run(ax, plt)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。