当前位置:   article > 正文

使用Pytorch实现强化学习——DQN算法_dqn pytorch

dqn pytorch

目录

一、强化学习的主要构成

二、基于python的强化学习框架

三、gym

四、DQN算法

1.DQN算法两个特点

(1)经验回放

(2)目标网络

2.DQN算法的流程

五、使用pytorch实现DQN算法

1.replay memory

2.神经网络部分

3.Agent

4.模型训练函数

5.训练模型

6.实验结果

六、补充说明


一、强化学习的主要构成

强化学习主要由两部分组成:智能体(agent)和环境(env)。在强化学习过程中,智能体与环境一直在交互。智能体在环境里面获取某个状态s_{t}后,它会利用该状态输出一个动作(action)a_{t}。然后这个动作会在环境之中被执行,环境会根据智能体采取的动作,输出下一个状态s_{t+1}以及当前这个动作带来的奖励r_{t}。智能体的目的就是尽可能多地从环境中获取奖励

二、基于python的强化学习框架

本次我使用到的框架是pytorch,因为DQN算法的实现包含了部分的神经网络,这部分对我来说使用pytorch会更顺手,所以就选择了这个。

三、gym

gym 定义了一套接口,用于描述强化学习中的环境这一概念,同时在其官方库中,包含了一些已实现的环境。

四、DQN算法

传统的强化学习算法使用的是Q表格存储状态价值函数或者动作价值函数,但是实际应用时,问题在的环境可能有很多种状态,甚至数不清,所以这种情况下使用离散的Q表格存储价值函数会非常不合理,所以DQN(Deep Q-learning)算法,使用神经网络拟合动作价值函数Q(s, a)

通常DQN算法只能处理动作离散,状态连续的情况,使用神经网络拟合出动作价值函数Q(s, a), 然后针对动作价值函数,选择出当状态state固定的Q值最大的动作a。

1.DQN算法两个特点

(1)经验回放

每一次的样本都放到样本池中,所以可以多次反复的使用一个样本,重复利用。训练时一次随机抽取多个数据样本来进行训练。

(2)目标网络

DQN算法的更新目标时让Q(s, a)逼近r + \gamma max_{a^{'}}Q(s^{'}, a^{'}), 但是如果两个Q使用一个网络计算,那么Q的目标值也在不断改变, 容易造成神经网络训练的不稳定。DQN使用目标网络,训练时目标值Q使用目标网络来计算,目标网络的参数定时和训练网络的参数同步。

2.DQN算法的流程

五、使用pytorch实现DQN算法

  1. import time
  2. import random
  3. import torch
  4. from torch import nn
  5. from torch import optim
  6. import gym
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. from collections import deque, namedtuple # 队列类型
  10. from tqdm import tqdm # 绘制进度条用
  11. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  12. Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

1.replay memory

  1. class ReplayMemory(object):
  2. def __init__(self, memory_size):
  3. self.memory = deque([], maxlen=memory_size)
  4. def sample(self, batch_size):
  5. batch_data = random.sample(self.memory, batch_size)
  6. state, action, reward, next_state, done = zip(*batch_data)
  7. return state, action, reward, next_state, done
  8. def push(self, *args):
  9. # *args: 把传进来的所有参数都打包起来生成元组形式
  10. # self.push(1, 2, 3, 4, 5)
  11. # args = (1, 2, 3, 4, 5)
  12. self.memory.append(Transition(*args))
  13. def __len__(self):
  14. return len(self.memory)

2.神经网络部分

  1. class Qnet(nn.Module):
  2. def __init__(self, n_observations, n_actions):
  3. super(Qnet, self).__init__()
  4. self.model = nn.Sequential(
  5. nn.Linear(n_observations, 128),
  6. nn.ReLU(),
  7. nn.Linear(128, n_actions)
  8. )
  9. def forward(self, state):
  10. return self.model(state)

3.Agent

  1. class Agent(object):
  2. def __init__(self, observation_dim, action_dim, gamma, lr, epsilon, target_update):
  3. self.action_dim = action_dim
  4. self.q_net = Qnet(observation_dim, action_dim).to(device)
  5. self.target_q_net = Qnet(observation_dim, action_dim).to(device)
  6. self.gamma = gamma
  7. self.lr = lr
  8. self.epsilon = epsilon
  9. self.target_update = target_update
  10. self.count = 0
  11. self.optimizer = optim.Adam(params=self.q_net.parameters(), lr=lr)
  12. self.loss = nn.MSELoss()
  13. def take_action(self, state):
  14. if np.random.uniform(0, 1) < 1 - self.epsilon:
  15. state = torch.tensor(state, dtype=torch.float).to(device)
  16. action = torch.argmax(self.q_net(state)).item()
  17. else:
  18. action = np.random.choice(self.action_dim)
  19. return action
  20. def update(self, transition_dict):
  21. states = transition_dict.state
  22. actions = np.expand_dims(transition_dict.action, axis=-1) # 扩充维度
  23. rewards = np.expand_dims(transition_dict.reward, axis=-1) # 扩充维度
  24. next_states = transition_dict.next_state
  25. dones = np.expand_dims(transition_dict.done, axis=-1) # 扩充维度
  26. states = torch.tensor(states, dtype=torch.float).to(device)
  27. actions = torch.tensor(actions, dtype=torch.int64).to(device)
  28. rewards = torch.tensor(rewards, dtype=torch.float).to(device)
  29. next_states = torch.tensor(next_states, dtype=torch.float).to(device)
  30. dones = torch.tensor(dones, dtype=torch.float).to(device)
  31. # update q_values
  32. # gather(1, acitons)意思是dim=1按行号索引, index=actions
  33. # actions=[[1, 2], [0, 1]] 意思是索引出[[第一行第2个元素, 第1行第3个元素],[第2行第1个元素, 第2行第2个元素]]
  34. # 相反,如果是这样
  35. # gather(0, acitons)意思是dim=0按列号索引, index=actions
  36. # actions=[[1, 2], [0, 1]] 意思是索引出[[第一列第2个元素, 第2列第3个元素],[第1列第1个元素, 第2列第2个元素]]
  37. # states.shape(64, 4) actions.shape(64, 1), 每一行是一个样本,所以这里用dim=1很合适
  38. predict_q_values = self.q_net(states).gather(1, actions)
  39. with torch.no_grad():
  40. # max(1) 即 max(dim=1)在行向找最大值,这样的话shape(64, ), 所以再加一个view(-1, 1)扩增至(64, 1)
  41. max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
  42. q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
  43. l = self.loss(predict_q_values, q_targets)
  44. self.optimizer.zero_grad()
  45. l.backward()
  46. self.optimizer.step()
  47. if self.count % self.target_update == 0:
  48. # copy model parameters
  49. self.target_q_net.load_state_dict(self.q_net.state_dict())
  50. self.count += 1

4.模型训练函数

  1. def run_episode(env, agent, repalymemory, batch_size):
  2. state = env.reset()
  3. reward_total = 0
  4. while True:
  5. action = agent.take_action(state)
  6. next_state, reward, done, _ = env.step(action)
  7. # print(reward)
  8. repalymemory.push(state, action, reward, next_state, done)
  9. reward_total += reward
  10. if len(repalymemory) > batch_size:
  11. state_batch, action_batch, reward_batch, next_state_batch, done_batch = repalymemory.sample(batch_size)
  12. T_data = Transition(state_batch, action_batch, reward_batch, next_state_batch, done_batch)
  13. # print(T_data)
  14. agent.update(T_data)
  15. state = next_state
  16. if done:
  17. break
  18. return reward_total
  19. def episode_evaluate(env, agent, render):
  20. reward_list = []
  21. for i in range(5):
  22. state = env.reset()
  23. reward_episode = 0
  24. while True:
  25. action = agent.take_action(state)
  26. next_state, reward, done, _ = env.step(action)
  27. reward_episode += reward
  28. state = next_state
  29. if done:
  30. break
  31. if render:
  32. env.render()
  33. reward_list.append(reward_episode)
  34. return np.mean(reward_list).item()
  35. def test(env, agent, delay_time):
  36. state = env.reset()
  37. reward_episode = 0
  38. while True:
  39. action = agent.take_action(state)
  40. next_state, reward, done, _ = env.step(action)
  41. reward_episode += reward
  42. state = next_state
  43. if done:
  44. break
  45. env.render()
  46. time.sleep(delay_time)

5.训练CartPole-v0环境模型

模型训练使用到的环境时gym提供的CartPole游戏(Cart Pole - Gymnasium Documentation (farama.org)),这个环境比较经典,小车运行结束的要求有三个:

(1)杆子的角度超过\pm 12

(2)小车位置大于\pm 2.4(小车中心到达显示屏边缘)

(3)小车移动步数超过200(v1是500)

小车每走一步奖励就会+1,所以在v0版本环境中,小车一次episode的最大奖励为200

  1. if __name__ == "__main__":
  2. # print("prepare for RL")
  3. env = gym.make("CartPole-v0")
  4. env_name = "CartPole-v0"
  5. observation_n, action_n = env.observation_space.shape[0], env.action_space.n
  6. # print(observation_n, action_n)
  7. agent = Agent(observation_n, action_n, gamma=0.98, lr=2e-3, epsilon=0.01, target_update=10)
  8. replaymemory = ReplayMemory(memory_size=10000)
  9. batch_size = 64
  10. num_episodes = 200
  11. reward_list = []
  12. # print("start to train model")
  13. # 显示10个进度条
  14. for i in range(10):
  15. with tqdm(total=int(num_episodes/10), desc="Iteration %d" % i) as pbar:
  16. for episode in range(int(num_episodes / 10)):
  17. reward_episode = run_episode(env, agent, replaymemory, batch_size)
  18. reward_list.append(reward_episode)
  19. if (episode+1) % 10 == 0:
  20. test_reward = episode_evaluate(env, agent, False)
  21. # print("Episode %d, total reward: %.3f" % (episode, test_reward))
  22. pbar.set_postfix({
  23. 'episode': '%d' % (num_episodes / 10 * i + episode + 1),
  24. 'return' : '%.3f' % (test_reward)
  25. })
  26. pbar.update(1) # 更新进度条
  27. test(env, agent, 0.5) # 最后用动画观看一下效果
  28. episodes_list = list(range(len(reward_list)))
  29. plt.plot(episodes_list, reward_list)
  30. plt.xlabel('Episodes')
  31. plt.ylabel('Returns')
  32. plt.title('Double DQN on {}'.format(env_name))
  33. plt.show()

6.实验结果

六、补充说明

想要开启动画的话,这句代码里面的False更改为True。

test_reward = episode_evaluate(env, agent, False)

参考资料:

蘑菇书EasyRL (datawhalechina.github.io)

DQN 算法 (boyuai.com)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/601512
推荐阅读
相关标签
  

闽ICP备14008679号