Hands-on RL, Chapter 19: Goal-Oriented Reinforcement Learning - Training Code

Based on Hands-on-RL/第19章-目标导向的强化学习.ipynb at main · boyu-ai/Hands-on-RL · GitHub

Theory: Goal-Oriented Reinforcement Learning
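
For reference, the environment and the HER relabeling in the listing below both use the same sparse, goal-conditioned reward: 0 once the agent's (x, y) position is within distance_threshold of the goal, and -1 otherwise. A minimal sketch of that rule follows; goal_reward is a hypothetical helper for illustration, not part of the official notebook:

import numpy as np

def goal_reward(achieved_xy, goal_xy, distance_threshold=0.15):
    # Sparse goal-conditioned reward: 0 inside the goal region, -1 everywhere else.
    dis = np.linalg.norm(np.asarray(achieved_xy) - np.asarray(goal_xy))
    return 0.0 if dis <= distance_threshold else -1.0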

This chapter needs little modification: the official code essentially runs as-is and should produce only a single warning.

Runtime environment

  1. Debian GNU/Linux 12
  2. Python 3.9.19
  3. torch 2.0.1
  4. gym 0.26.2

Code

HER.py

#!/usr/bin/env python
import torch
import torch.nn.functional as F
import numpy as np
import random
from tqdm import tqdm
import collections
import matplotlib.pyplot as plt


class WorldEnv:
    def __init__(self):
        self.distance_threshold = 0.15
        self.action_bound = 1

    def reset(self):  # reset the environment
        # sample a goal whose coordinates lie in [3.5, 4.5] x [3.5, 4.5]
        self.goal = np.array(
            [4 + random.uniform(-0.5, 0.5), 4 + random.uniform(-0.5, 0.5)])
        self.state = np.array([0, 0])  # initial state
        self.count = 0
        return np.hstack((self.state, self.goal))

    def step(self, action):
        action = np.clip(action, -self.action_bound, self.action_bound)
        x = max(0, min(5, self.state[0] + action[0]))
        y = max(0, min(5, self.state[1] + action[1]))
        self.state = np.array([x, y])
        self.count += 1

        dis = np.sqrt(np.sum(np.square(self.state - self.goal)))
        reward = -1.0 if dis > self.distance_threshold else 0
        if dis <= self.distance_threshold or self.count == 50:
            done = True
        else:
            done = False

        return np.hstack((self.state, self.goal)), reward, done


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound  # largest action magnitude the environment accepts

    def forward(self, x):
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return torch.tanh(self.fc3(x)) * self.action_bound


class QValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)  # concatenate state and action
        x = F.relu(self.fc2(F.relu(self.fc1(cat))))
        return self.fc3(x)


class DDPG:
    ''' DDPG algorithm '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, sigma, tau, gamma, device):
        self.action_dim = action_dim
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim,
                               action_bound).to(device)
        self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim,
                                      action_bound).to(device)
        self.target_critic = QValueNet(state_dim, hidden_dim,
                                       action_dim).to(device)
        # initialize the target value network with the same parameters as the value network
        self.target_critic.load_state_dict(self.critic.state_dict())
        # initialize the target policy network with the same parameters as the policy network
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.sigma = sigma  # standard deviation of the Gaussian noise; the mean is fixed at 0
        self.tau = tau  # soft-update coefficient for the target networks
        self.action_bound = action_bound
        self.device = device

    def take_action(self, state):
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        action = self.actor(state).detach().cpu().numpy()[0]
        # add noise to the action to encourage exploration
        action = action + self.sigma * np.random.randn(self.action_dim)
        return action

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        next_q_values = self.target_critic(next_states,
                                           self.target_actor(next_states))
        q_targets = rewards + self.gamma * next_q_values * (1 - dones)
        # MSE loss for the critic
        critic_loss = torch.mean(
            F.mse_loss(self.critic(states, actions), q_targets))
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # the policy network is trained to maximize the Q value
        actor_loss = -torch.mean(self.critic(states, self.actor(states)))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.soft_update(self.actor, self.target_actor)  # soft-update the policy network
        self.soft_update(self.critic, self.target_critic)  # soft-update the value network


class Trajectory:
    ''' records one complete trajectory '''
    def __init__(self, init_state):
        self.states = [init_state]
        self.actions = []
        self.rewards = []
        self.dones = []
        self.length = 0

    def store_step(self, action, state, reward, done):
        self.actions.append(action)
        self.states.append(state)
        self.rewards.append(reward)
        self.dones.append(done)
        self.length += 1


class ReplayBuffer_Trajectory:
    ''' replay buffer that stores whole trajectories '''
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add_trajectory(self, trajectory):
        self.buffer.append(trajectory)

    def size(self):
        return len(self.buffer)

    def sample(self, batch_size, use_her, dis_threshold=0.15, her_ratio=0.8):
        batch = dict(states=[],
                     actions=[],
                     next_states=[],
                     rewards=[],
                     dones=[])
        for _ in range(batch_size):
            traj = random.sample(self.buffer, 1)[0]
            step_state = np.random.randint(traj.length)
            state = traj.states[step_state]
            next_state = traj.states[step_state + 1]
            action = traj.actions[step_state]
            reward = traj.rewards[step_state]
            done = traj.dones[step_state]

            if use_her and np.random.uniform() <= her_ratio:
                step_goal = np.random.randint(step_state + 1, traj.length + 1)
                goal = traj.states[step_goal][:2]  # relabel the goal with HER's "future" strategy
                dis = np.sqrt(np.sum(np.square(next_state[:2] - goal)))
                reward = -1.0 if dis > dis_threshold else 0
                done = False if dis > dis_threshold else True
                state = np.hstack((state[:2], goal))
                next_state = np.hstack((next_state[:2], goal))

            batch['states'].append(state)
            batch['next_states'].append(next_state)
            batch['actions'].append(action)
            batch['rewards'].append(reward)
            batch['dones'].append(done)

        batch['states'] = np.array(batch['states'])
        batch['next_states'] = np.array(batch['next_states'])
        batch['actions'] = np.array(batch['actions'])
        return batch


actor_lr = 1e-3
critic_lr = 1e-3
hidden_dim = 128
state_dim = 4
action_dim = 2
action_bound = 1
sigma = 0.1
tau = 0.005
gamma = 0.98
num_episodes = 2000
n_train = 20
batch_size = 256
minimal_episodes = 200
buffer_size = 10000
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
env = WorldEnv()
replay_buffer = ReplayBuffer_Trajectory(buffer_size)
agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr,
             critic_lr, sigma, tau, gamma, device)

return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            episode_return = 0
            state = env.reset()
            traj = Trajectory(state)
            done = False
            while not done:
                action = agent.take_action(state)
                state, reward, done = env.step(action)
                episode_return += reward
                traj.store_step(action, state, reward, done)
            replay_buffer.add_trajectory(traj)
            return_list.append(episode_return)
            if replay_buffer.size() >= minimal_episodes:
                for _ in range(n_train):
                    transition_dict = replay_buffer.sample(batch_size, True)
                    agent.update(transition_dict)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DDPG with HER on {}'.format('GridWorld'))
plt.show()

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
env = WorldEnv()
replay_buffer = ReplayBuffer_Trajectory(buffer_size)
agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr,
             critic_lr, sigma, tau, gamma, device)

return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            episode_return = 0
            state = env.reset()
            traj = Trajectory(state)
            done = False
            while not done:
                action = agent.take_action(state)
                state, reward, done = env.step(action)
                episode_return += reward
                traj.store_step(action, state, reward, done)
            replay_buffer.add_trajectory(traj)
            return_list.append(episode_return)
            if replay_buffer.size() >= minimal_episodes:
                for _ in range(n_train):
                    # the only difference from training with HER: sample without goal relabeling
                    transition_dict = replay_buffer.sample(batch_size, False)
                    agent.update(transition_dict)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DDPG without HER on {}'.format('GridWorld'))
plt.show()
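
The script above only plots episode returns. To sanity-check a trained agent, one can additionally roll out the deterministic policy (no Gaussian exploration noise) and count how often the goal region is actually reached. The sketch below is not part of the official notebook; evaluate is my own helper name, it reuses WorldEnv and the agent object from HER.py, and since the script reassigns the name agent for the second (no-HER) run, it should be called right after whichever training run you want to check.

def evaluate(agent, n_episodes=100):
    # Roll out the learned policy without exploration noise and report the
    # fraction of episodes that end within distance_threshold of the goal.
    eval_env = WorldEnv()
    successes = 0
    for _ in range(n_episodes):
        state = eval_env.reset()
        done = False
        while not done:
            with torch.no_grad():
                s = torch.tensor(np.array([state]),
                                 dtype=torch.float).to(agent.device)
                action = agent.actor(s).cpu().numpy()[0]  # greedy action, no noise
            state, reward, done = eval_env.step(action)
        # the final reward is 0 only when the last step landed inside the goal region
        successes += int(reward == 0)
    return successes / n_episodes

# Example usage, placed directly after a training loop:
# print('success rate: %.2f' % evaluate(agent))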
