gym.make is the command that creates an OpenAI Gym environment. env.reset() initializes the environment and returns the initial state, which is stored in the variable observation. env.step(action) advances the environment by one step; action=0 pushes the cart to the left and action=1 pushes it to the right. Under the newer Gym API it returns five values: observation, reward, terminated, truncated, and info. reward is the immediate reward, +1 for each step on which the cart stays within [-2.4, 2.4] and the pole stays within about ±12° (0.209 rad) of vertical; when either limit is exceeded the episode ends. terminated (unpacked as done in the code below) is True when the episode has ended, truncated is True when the time limit has been reached, and info holds information useful for debugging.
The state is stored in observation, an array of four values: the cart position, the cart velocity, the pole angle, and the pole angular velocity.
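For reference, here is a minimal sketch of this API (gym 0.26+ is assumed, matching the five-value unpacking used in the code below):

import gym

env = gym.make('CartPole-v0')
observation, info = env.reset()                     # reset returns (observation, info)
cart_pos, cart_v, pole_angle, pole_v = observation  # the four state variables
observation, reward, terminated, truncated, info = env.step(0)  # 0 = push cart left
done = terminated or truncated                      # the episode ends on either condition
env.close()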
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import gym
import numpy as np

# Animation display function
def display_frames_as_gif(frames):
    plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('./image/movie_cartpole.mp4')
    plt.close()  # prevent a second figure from being displayed
    return HTML(anim.to_jshtml())

# Move CartPole randomly
env = gym.make('CartPole-v0', render_mode='rgb_array')
observation, info = env.reset()
frames = []
for step in range(200):
    frames.append(env.render())
    action = np.random.choice(2)
    observation, reward, done, _, info = env.step(action)
    if done:
        observation, info = env.reset()  # restart the episode once the pole falls

env.close()  # close the environment
display_frames_as_gif(frames)
Next, the continuous observation is discretized into one of a finite number of states:

ENV = 'CartPole-v0'
NUM_DIZITIZED = 6

env = gym.make(ENV)
observation = env.reset()

# Compute the threshold values used for discretization
def bins(clip_min, clip_max, num):
    return np.linspace(clip_min, clip_max, num + 1)[1:-1]

def digitize_state(observation):
    cart_pos, cart_v, pole_angle, pole_v = observation
    digitized = [
        np.digitize(cart_pos, bins=bins(-2.4, 2.4, NUM_DIZITIZED)),
        np.digitize(cart_v, bins=bins(-3.0, 3.0, NUM_DIZITIZED)),
        np.digitize(pole_angle, bins=bins(-0.5, 0.5, NUM_DIZITIZED)),
        np.digitize(pole_v, bins=bins(-2.0, 2.0, NUM_DIZITIZED)),
    ]
    return sum([x * (NUM_DIZITIZED ** i) for i, x in enumerate(digitized)])
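Each state variable falls into one of NUM_DIZITIZED = 6 bins, and the four bin indices are combined as the digits of a base-6 number, giving 6**4 = 1296 discrete states in total. A quick check with a made-up observation (the values here are illustrative only):

# Made-up observation: [cart_pos, cart_v, pole_angle, pole_v]
sample_obs = np.array([0.1, 0.5, -0.05, 0.3])
state = digitize_state(sample_obs)
print(state)  # an integer state index in the range [0, 6**4), i.e. [0, 1296)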
We implement three classes: Agent, Brain, and Environment.
The Agent class represents the pole-balancing cart. It consists of two methods: update_Q_function, which updates the Q-function, and get_action, which determines the next action; it holds an instance of the Brain class as a member variable.
The Brain class can be thought of as the Agent's brain. It has four methods: bins and digitize_state discretize the observed observation, update_Q_table updates the Q-table, and decide_action determines an action from the Q-table. Agent and Brain are kept separate so that moving to deep reinforcement learning later only requires changing the Brain.
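Concretely, update_Q_table applies the standard tabular Q-learning update, written here with the code's own constants (ETA = 0.5 is the learning rate, GAMMA = 0.99 the discount factor):

    Q(s, a) <- Q(s, a) + ETA * (reward + GAMMA * max_a' Q(s', a') - Q(s, a))

decide_action is epsilon-greedy with epsilon = 0.5 / (episode + 1), so the agent explores heavily in early episodes and becomes almost purely greedy as training progresses.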
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import gym
import numpy as np

# Animation display function
def display_frames_as_gif(frames):
    plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('./image/movie_cartpole.mp4')
    plt.close()  # prevent a second figure from being displayed
    return HTML(anim.to_jshtml())

ENV = 'CartPole-v0'
NUM_DIZITIZED = 6    # number of bins per state variable
GAMMA = 0.99         # discount factor
ETA = 0.5            # learning rate
MAX_STEPS = 200      # steps per episode
NUM_EPISODES = 1000  # number of episodes

class Brain:
    '''The agent's brain, which performs Q-learning'''

    def __init__(self, num_states, num_actions) -> None:
        self.num_actions = num_actions
        # Q-table: one row per discretized state, one column per action
        self.q_table = np.random.uniform(low=0, high=1, size=(NUM_DIZITIZED ** num_states, num_actions))

    def bins(self, clip_min, clip_max, num):
        '''Compute the threshold values used for discretization'''
        return np.linspace(clip_min, clip_max, num + 1)[1:-1]

    def digitize_state(self, observation):
        '''Convert a continuous observation into a discrete state index'''
        cart_pos, cart_v, pole_angle, pole_v = observation
        digitized = [
            np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),
            np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)),
            np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)),
            np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED))
        ]
        return sum([x * (NUM_DIZITIZED ** i) for i, x in enumerate(digitized)])

    def update_Q_table(self, observation, action, reward, observation_next):
        '''Q-learning update of the Q-table'''
        state = self.digitize_state(observation)
        state_next = self.digitize_state(observation_next)
        Max_Q_next = max(self.q_table[state_next][:])
        self.q_table[state, action] = self.q_table[state, action] + ETA * (reward + GAMMA * Max_Q_next - self.q_table[state, action])

    def decide_action(self, observation, episode):
        '''Epsilon-greedy action selection with decaying epsilon'''
        state = self.digitize_state(observation)
        epsilon = 0.5 * 1 / (episode + 1)
        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state][:])
        else:
            action = np.random.choice(self.num_actions)
        return action

class Agent:
    '''The CartPole agent, a cart carrying a pole'''

    def __init__(self, num_states, num_actions) -> None:
        self.brain = Brain(num_states, num_actions)  # create a brain for the agent to make decisions

    def update_Q_function(self, observation, action, reward, observation_next):
        '''Update the Q-function'''
        self.brain.update_Q_table(observation, action, reward, observation_next)

    def get_action(self, observation, step):
        '''Determine the action'''
        action = self.brain.decide_action(observation, step)
        return action

class Environment:
    '''The environment in which CartPole runs'''

    def __init__(self) -> None:
        self.env = gym.make(ENV, render_mode='rgb_array')
        num_states = self.env.observation_space.shape[0]
        num_actions = self.env.action_space.n
        self.agent = Agent(num_states, num_actions)

    def run(self):
        '''Run the environment'''
        complete_episodes = 0     # number of consecutive successful episodes
        is_episode_final = False  # whether this is the final (recorded) episode
        frames = []               # frames for the video of the final episode

        for episode in range(NUM_EPISODES):
            observation = self.env.reset()
            # reset returns (observation, info); use only the observation array
            obs_values = observation[0] if isinstance(observation, tuple) else observation

            for step in range(MAX_STEPS):
                if is_episode_final:
                    frames.append(self.env.render())

                action = self.agent.get_action(obs_values, episode)
                observation_next, _, done, _, _ = self.env.step(action)  # the env reward is not used
                obs_values_next = observation_next[0] if isinstance(observation_next, tuple) else observation_next

                # Hand-crafted reward: penalize falling early, reward surviving 195+ steps
                if done:
                    if step < 195:
                        reward = -1
                        complete_episodes = 0  # reset the success counter
                    else:
                        reward = 1
                        complete_episodes += 1
                else:
                    reward = 0

                self.agent.update_Q_function(obs_values, action, reward, obs_values_next)
                obs_values = obs_values_next  # move on to the next state

                if done:
                    print(f'{episode} Episode: Finished after {step + 1} time steps')
                    break

            if is_episode_final:
                display_frames_as_gif(frames)
                break

            if complete_episodes >= 10:
                print('10 consecutive successful episodes')
                is_episode_final = True
cartpole_env = Environment()
cartpole_env.run()
0 Episode: Finished after 15 time steps
1 Episode: Finished after 28 time steps
2 Episode: Finished after 10 time steps
3 Episode: Finished after 10 time steps
4 Episode: Finished after 44 time steps
5 Episode: Finished after 72 time steps
6 Episode: Finished after 9 time steps
7 Episode: Finished after 29 time steps
8 Episode: Finished after 42 time steps
9 Episode: Finished after 9 time steps
……
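Once training has converged, the learned Q-table can also be used purely greedily, with no exploration. A minimal sketch, assuming run() above has completed and reusing the Brain held by cartpole_env:

# Greedy rollout with the learned Q-table (no epsilon-greedy exploration)
env = gym.make(ENV, render_mode='rgb_array')
observation = env.reset()
obs_values = observation[0] if isinstance(observation, tuple) else observation
brain = cartpole_env.agent.brain

for step in range(MAX_STEPS):
    state = brain.digitize_state(obs_values)
    action = np.argmax(brain.q_table[state][:])  # always pick the best-valued action
    observation, _, done, _, _ = env.step(action)
    obs_values = observation[0] if isinstance(observation, tuple) else observation
    if done:
        print(f'Greedy policy survived {step + 1} time steps')
        break
env.close()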