当前位置:   article > 正文

[强化学习总结6] actor-critic算法_actorcritic算法

actorcritic算法
  • actor:策略
  • critic:评估价值

Actor-Critic 是囊括一系列算法的整体架构,目前很多高效的前沿算法都属于 Actor-Critic 算法,本章接下来将会介绍一种最简单的 Actor-Critic 算法。需要明确的是,Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

1 核心

  • 在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报(trajectory return),用于指导策略(policy, π(s | a) )的更新。REINFOCE 算法用蒙特卡洛方法来估计q(s, a)。
    • 其实就是用回报作为策略的加权值,所以这里可以推广出一个一般形式,只要一个值能作为aciton的好坏的判断,就可以做为权重。
    • 所以critic学的就是一个权重,输出是一个值。

    • 权重可以有以下这些:

actor-critic优势:

  • 事实上,用q值或者v值本质上也是用奖励来进行指导,但是用神经网络进行估计的方法可以减小方差、提高鲁棒性。除此之外,REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,这同时也要求任务具有有限的步数,而 Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。

2 Actor-Critic

我们将 Actor-Critic 分为两个部分:Actor(策略网络)和 Critic(价值网络)。

  • Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。

2.1 code

说的再多,不如看看代码。Actor-Critic 算法

网络

  1. class PolicyNet(torch.nn.Module):
  2. def __init__(self, state_dim, hidden_dim, action_dim):
  3. super(PolicyNet, self).__init__()
  4. self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
  5. self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
  6. def forward(self, x):
  7. x = F.relu(self.fc1(x))
  8. return F.softmax(self.fc2(x), dim=1)
  9. class ValueNet(torch.nn.Module):
  10. def __init__(self, state_dim, hidden_dim):
  11. super(ValueNet, self).__init__()
  12. self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
  13. self.fc2 = torch.nn.Linear(hidden_dim, 1) ## 输出是1个值
  14. def forward(self, x):
  15. x = F.relu(self.fc1(x))
  16. return self.fc2(x)
  1. class ActorCritic:
  2. def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
  3. gamma, device):
  4. # 策略网络
  5. self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
  6. self.critic = ValueNet(state_dim, hidden_dim).to(device) # 价值网络
  7. # 策略网络优化器
  8. self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
  9. lr=actor_lr)
  10. self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
  11. lr=critic_lr) # 价值网络优化器
  12. self.gamma = gamma
  13. self.device = device
  14. def take_action(self, state):
  15. state = torch.tensor([state], dtype=torch.float).to(self.device)
  16. probs = self.actor(state)
  17. action_dist = torch.distributions.Categorical(probs)
  18. action = action_dist.sample()
  19. return action.item()
  20. def update(self, transition_dict):
  21. states = torch.tensor(transition_dict['states'],
  22. dtype=torch.float).to(self.device)
  23. actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
  24. self.device)
  25. rewards = torch.tensor(transition_dict['rewards'],
  26. dtype=torch.float).view(-1, 1).to(self.device)
  27. next_states = torch.tensor(transition_dict['next_states'],
  28. dtype=torch.float).to(self.device)
  29. dones = torch.tensor(transition_dict['dones'],
  30. dtype=torch.float).view(-1, 1).to(self.device)
  31. # 时序差分目标
  32. td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
  33. td_delta = td_target - self.critic(states) # 时序差分误差
  34. log_probs = torch.log(self.actor(states).gather(1, actions))
  35. actor_loss = torch.mean(-log_probs * td_delta.detach())
  36. # 均方误差损失函数
  37. critic_loss = torch.mean(
  38. F.mse_loss(self.critic(states), td_target.detach()))
  39. self.actor_optimizer.zero_grad()
  40. self.critic_optimizer.zero_grad()
  41. actor_loss.backward() # 计算策略网络的梯度
  42. critic_loss.backward() # 计算价值网络的梯度
  43. self.actor_optimizer.step() # 更新策略网络的参数
  44. self.critic_optimizer.step() # 更新价值网络的参数

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/552053
推荐阅读
相关标签
  

闽ICP备14008679号