
Learning Deep Reinforcement Learning by Doing: Practical PyTorch Programming, Sections 3.2-3.4: The CartPole Inverted Pendulum

gym.make creates an OpenAI Gym environment. When the program runs, env.reset() first initializes the environment and returns the initial state, which is stored in the variable observation. env.step(action) advances the environment by one step: action=0 pushes the cart to the left and action=1 pushes it to the right. Under the Gym API used in this post (gym >= 0.26), step returns five values: observation, reward, terminated, truncated, and info. reward is the immediate reward, which is 1 as long as the cart stays within [-2.4, 2.4] and the pole angle stays within about ±12° (0.209 rad), and 0 otherwise. terminated (unpacked as done in the code below) becomes True when the episode ends, and info carries information useful for debugging.
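A minimal sketch of this API (assuming gym >= 0.26; the exact observation values vary from run to run):

import gym

env = gym.make('CartPole-v0')
observation, info = env.reset()  # initial state plus an info dict
observation, reward, terminated, truncated, info = env.step(0)  # push the cart left
print(observation)  # a 4-element array, e.g. [ 0.03 -0.16  0.02  0.31]
env.close()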

3.3.1 The State of CartPole

The state is stored in observation, a list of four variables:

  1. Cart position (-2.4 to 2.4)
  2. Cart velocity (-∞ to +∞)
  3. Pole angle (about -0.418 to 0.418 rad, roughly ±24°)
  4. Pole angular velocity (-∞ to +∞)

Each variable takes continuous values. To build a Q-table we discretize each one into 6 bins labeled 0 to 5, which gives 6^4 = 1296 states in total.
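As a quick illustration of this binning (a minimal NumPy sketch that mirrors the bins() helper defined later in this post):

import numpy as np

# 5 interior thresholds split [-2.4, 2.4] into 6 bins labeled 0-5
thresholds = np.linspace(-2.4, 2.4, 6 + 1)[1:-1]
print(thresholds)                     # [-1.6 -0.8  0.   0.8  1.6]
print(np.digitize(0.5, thresholds))   # 3: 0.5 falls in bin 3
print(np.digitize(-2.0, thresholds))  # 0: the leftmost bin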
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import gym
import numpy as np

# Animation display function
def display_frames_as_gif(frames):
    plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('./image/movie_cartpole.mp4')  # requires ffmpeg and an existing ./image directory
    plt.close()  # prevent a duplicate static figure from being displayed
    return HTML(anim.to_jshtml())

# Move CartPole randomly
env = gym.make('CartPole-v0', render_mode='rgb_array')
observation, info = env.reset()
frames = []

for step in range(200):
    frames.append(env.render())
    action = np.random.choice(2)  # randomly push the cart left (0) or right (1)
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()  # restart once the pole falls

env.close()  # close the environment
display_frames_as_gif(frames)

ENV = 'CartPole-v0'
NUM_DIZITIZED = 6  # number of bins per state variable

env = gym.make(ENV)
observation, _ = env.reset()

# Compute the thresholds used for discretization
def bins(clip_min, clip_max, num):
    # num + 1 evenly spaced points define num bins; drop the two ends to get
    # the num - 1 interior boundaries that np.digitize expects
    return np.linspace(clip_min, clip_max, num + 1)[1:-1]

# Convert the 4-variable continuous observation into a single state index
def digitize_state(observation):
    cart_pos, cart_v, pole_angle, pole_v = observation
    digitized = [
        np.digitize(cart_pos, bins=bins(-2.4, 2.4, NUM_DIZITIZED)),
        np.digitize(cart_v, bins=bins(-3.0, 3.0, NUM_DIZITIZED)),
        np.digitize(pole_angle, bins=bins(-0.5, 0.5, NUM_DIZITIZED)),
        np.digitize(pole_v, bins=bins(-2.0, 2.0, NUM_DIZITIZED)),
    ]
    # Treat the four bin indices as digits of a base-6 number
    return sum([x * (NUM_DIZITIZED ** i) for i, x in enumerate(digitized)])
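For example, for a (hypothetical) observation [0.1, 0.2, 0.05, -0.1] the four variables fall into bins [3, 3, 3, 2], so the state index is 3 + 3*6 + 3*36 + 2*216 = 561, a unique number in the range 0 to 1295:

print(digitize_state([0.1, 0.2, 0.05, -0.1]))  # 561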

3.4 Implementing Q-Learning

We implement three classes: Agent, Brain, and Environment.
The Agent class represents the cart-and-pole agent. It consists of two methods, update_Q_function to update the Q-function and get_action to decide the next action, and it holds a Brain object as a member variable.
The Brain class can be seen as the Agent's brain. It has four methods: bins and digitize_state discretize the observed observation, update_Q_table updates the Q-table, and decide_action chooses an action based on the Q-table. Agent and Brain are kept separate so that when we move on to deep reinforcement learning, only Brain needs to change.
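The rule that update_Q_table implements is the standard Q-learning update:

Q(s, a) ← Q(s, a) + η * (R + γ * max_a' Q(s', a') - Q(s, a))

where η = ETA is the learning rate and γ = GAMMA is the discount factor.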

import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import gym
import numpy as np

# Animation display function
def display_frames_as_gif(frames):
    plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('./image/movie_cartpole.mp4')  # requires ffmpeg and an existing ./image directory
    plt.close()  # prevent a duplicate static figure from being displayed
    return HTML(anim.to_jshtml())

ENV = 'CartPole-v0'
NUM_DIZITIZED = 6    # number of bins per state variable
GAMMA = 0.99         # discount factor
ETA = 0.5            # learning rate
MAX_STEPS = 200      # maximum steps per episode
NUM_EPISODES = 1000  # maximum number of episodes

class Brain:
    '''The agent's brain, which performs Q-learning'''

    def __init__(self, num_states, num_actions) -> None:
        self.num_actions = num_actions
        # Q-table: one row per discretized state (6**4 rows), one column per action
        self.q_table = np.random.uniform(low=0, high=1, size=(NUM_DIZITIZED ** num_states, num_actions))

    def bins(self, clip_min, clip_max, num):
        '''Compute the thresholds used for discretization'''
        # num + 1 evenly spaced points define num bins; drop the two ends to get
        # the num - 1 interior boundaries that np.digitize expects
        return np.linspace(clip_min, clip_max, num + 1)[1:-1]

    def digitize_state(self, observation):
        '''Convert the 4-variable continuous observation into a single state index'''
        cart_pos, cart_v, pole_angle, pole_v = observation
        digitized = [
            np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),
            np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)),
            np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)),
            np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED))
        ]
        # Treat the four bin indices as digits of a base-6 number
        return sum([x * (NUM_DIZITIZED ** i) for i, x in enumerate(digitized)])

    def update_Q_table(self, observation, action, reward, observation_next):
        '''Q-learning update: Q(s,a) += ETA * (R + GAMMA * max Q(s',a') - Q(s,a))'''
        state = self.digitize_state(observation)
        state_next = self.digitize_state(observation_next)
        Max_Q_next = max(self.q_table[state_next][:])
        self.q_table[state, action] = self.q_table[state, action] + \
            ETA * (reward + GAMMA * Max_Q_next - self.q_table[state, action])

    def decide_action(self, observation, episode):
        '''Epsilon-greedy action selection, with epsilon decaying over episodes'''
        state = self.digitize_state(observation)
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state][:])  # exploit: pick the greedy action
        else:
            action = np.random.choice(self.num_actions)  # explore: pick a random action
        return action

class Agent:
    '''The CartPole agent: a cart with a pole attached'''

    def __init__(self, num_states, num_actions) -> None:
        self.brain = Brain(num_states, num_actions)  # create a brain for the agent to make decisions

    def update_Q_function(self, observation, action, reward, observation_next):
        '''Update the Q-function'''
        self.brain.update_Q_table(observation, action, reward, observation_next)

    def get_action(self, observation, step):
        '''Decide the action'''
        action = self.brain.decide_action(observation, step)
        return action

class Environment:
    '''The environment in which CartPole runs'''

    def __init__(self) -> None:
        self.env = gym.make(ENV, render_mode='rgb_array')
        num_states = self.env.observation_space.shape[0]  # 4 state variables
        num_actions = self.env.action_space.n  # 2 actions (left or right)
        self.agent = Agent(num_states, num_actions)

    def run(self):
        '''Run the episodes'''
        complete_episodes = 0     # number of consecutive successful episodes
        is_episode_final = False  # flag for the final (recorded) episode
        frames = []               # rendered frames for the animation

        for episode in range(NUM_EPISODES):
            observation, _ = self.env.reset()  # the new Gym API returns (observation, info)

            for step in range(MAX_STEPS):
                if is_episode_final:
                    frames.append(self.env.render())

                action = self.agent.get_action(observation, episode)

                # The environment's own reward is discarded; we define our own below
                observation_next, _, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                # Reward: -1 if the pole falls before step 195 (failure),
                # +1 if the episode lasts to the end (success), 0 otherwise
                if done:
                    if step < 195:
                        reward = -1
                        complete_episodes = 0  # reset the success counter
                    else:
                        reward = 1
                        complete_episodes += 1
                else:
                    reward = 0

                self.agent.update_Q_function(observation, action, reward, observation_next)
                observation = observation_next  # use the next observation in the next step

                if done:
                    print(f'{episode} Episode: Finished after {step + 1} time steps')
                    break

            if is_episode_final:
                display_frames_as_gif(frames)
                break

            if complete_episodes >= 10:
                print('Succeeded for 10 consecutive episodes')
                is_episode_final = True  # record and display the next episode

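One detail worth noting: decide_action anneals epsilon as 0.5 / (episode + 1), so exploration drops off quickly:

# epsilon-greedy schedule used in Brain.decide_action
for episode in [0, 1, 9, 99, 999]:
    print(episode, 0.5 / (episode + 1))
# 0 0.5
# 1 0.25
# 9 0.05
# 99 0.005
# 999 0.0005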


cartpole_env = Environment()
cartpole_env.run()
0 Episode: Finished after 15 time steps
1 Episode: Finished after 28 time steps
2 Episode: Finished after 10 time steps
3 Episode: Finished after 10 time steps
4 Episode: Finished after 44 time steps
5 Episode: Finished after 72 time steps
6 Episode: Finished after 9 time steps
7 Episode: Finished after 29 time steps
8 Episode: Finished after 42 time steps
9 Episode: Finished after 9 time steps
……