
Reinforcement Learning: The PPO Algorithm


PPO is a simplification of the TRPO algorithm: it is considerably simpler than TRPO while being at least as effective in practice.

PPO optimizes a single clipped surrogate objective (the code below minimizes its negative as the loss):

L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\big(r_t(\theta)\,A_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,A_t\big)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

Here r_t(\theta) is the probability of the action a_t chosen in state s_t under the current policy, divided by the probability of the same action in the same state under the old policy; this ratio is called the importance sampling ratio.
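As a minimal sketch of how this ratio is formed in practice (the tensors below are made up purely for illustration and are not part of the article's program), the probability of each sampled action is picked out of the policy output with gather and the new probability is divided by the old one:

import torch

# made-up example batch: probabilities over 2 actions for 3 states
old_probs = torch.tensor([[0.6, 0.4], [0.3, 0.7], [0.5, 0.5]])
new_probs = torch.tensor([[0.7, 0.3], [0.2, 0.8], [0.5, 0.5]])
actions = torch.tensor([[0], [1], [1]])          # the action taken in each state

old_pa = old_probs.gather(dim=1, index=actions)  # prob. of the taken action, old policy
new_pa = new_probs.gather(dim=1, index=actions)  # prob. of the taken action, new policy
ratios = new_pa / old_pa                         # importance sampling ratios r
print(ratios)                                    # tensor([[1.1667], [1.1429], [1.0000]])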

A_t is called the advantage estimate, and it is produced by a temporal-difference (TD) procedure. The one-step TD error is

\delta_t = R_t + \gamma\,V(s_{t+1}) - V(s_t)

where R_t is the reward at step t, and the advantage accumulates these errors with a decaying weight (generalized advantage estimation):

A_t = \sum_{l \ge 0} (\gamma\lambda)^l\,\delta_{t+l}

A_t measures the difference between the value of taking action a_t in state s_t and the average value of following the current policy from that state, i.e. how action a_t compares with the baseline: if A_t > 0 the action is better than average, otherwise it is worse.
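To make the recursion concrete, here is a minimal sketch, with made-up TD errors, of the discounted accumulation that the get_advantages function in the listing below implements, using \gamma = 0.98 and \lambda = 0.95 as in the code:

# made-up TD errors delta_t for a 4-step episode
deltas = [0.5, -0.2, 0.3, 0.1]

gamma, lam = 0.98, 0.95
advantages = []
s = 0.0
for delta in reversed(deltas):      # accumulate from the last step backwards
    s = gamma * lam * s + delta     # A_t = delta_t + (gamma * lam) * A_{t+1}
    advantages.append(s)
advantages.reverse()                # back to chronological order

print(advantages)
# roughly [0.65, 0.17, 0.39, 0.10]: each advantage is its own TD error
# plus a decayed share of all the TD errors that follow it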

clip is a clamping function; its main job is to limit how far the ratio between the new policy and the old policy can move within one update, which keeps performance stable. It has the following form:

\mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon) =
\begin{cases}
1-\epsilon, & r_t(\theta) < 1-\epsilon \\
r_t(\theta), & 1-\epsilon \le r_t(\theta) \le 1+\epsilon \\
1+\epsilon, & r_t(\theta) > 1+\epsilon
\end{cases}

In other words the ratio is clamped to the interval [1-\epsilon, 1+\epsilon]; the training code below uses \epsilon = 0.2.

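A tiny numerical sketch (the ratios and advantages are made up purely for illustration, not produced by the article's program) shows how clipping caps a sample's contribution once the ratio leaves [1-\epsilon, 1+\epsilon], here with \epsilon = 0.2 as in the training loop below:

import torch

ratios = torch.tensor([0.5, 1.0, 1.8])       # made-up importance sampling ratios
advantages = torch.tensor([1.0, 1.0, 1.0])   # positive advantage for every sample

surr1 = ratios * advantages                          # unclipped surrogate
surr2 = torch.clamp(ratios, 0.8, 1.2) * advantages   # ratio clipped to [1-0.2, 1+0.2]
objective = torch.min(surr1, surr2)                  # PPO takes the element-wise minimum

print(objective)  # tensor([0.5000, 1.0000, 1.2000])
# the sample with ratio 1.8 is capped at 1.2, so pushing the new policy even
# further away from the old one yields no extra objective (and no gradient)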
Next, both the policy and the value function used for advantage estimation are implemented as small neural networks, and we use them to play the CartPole balancing game.

import gym
from matplotlib import pyplot as plt
import torch
import random
import numpy as np
from IPython import display

# Create the environment (classic gym < 0.26 API: reset() returns the observation,
# step() returns a 4-tuple, and render() accepts a mode argument)
env = gym.make('CartPole-v1')
env.reset()

# Render the current frame of the game
def show():
    plt.imshow(env.render(mode='rgb_array'))
    plt.axis('off')
    plt.show()

# Define the models
# Policy network (policy gradient): state (4) -> action probabilities (2)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=1),
)

# Value network for temporal difference: state (4) -> state value (1)
model_td = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 1),
)

# Sample an action from the current policy
def get_action(state):
    state = torch.FloatTensor(state).reshape(1, 4)
    prob = model(state)
    # renormalize to guard against floating-point drift before sampling
    prob_normalized = prob[0].tolist()
    prob_sum = sum(prob_normalized)
    prob_normalized = [p / prob_sum for p in prob_normalized]
    action = np.random.choice(range(2), p=prob_normalized, size=1)[0]
    return action

# Play one episode and collect the trajectory
def get_data():
    states = []
    rewards = []
    actions = []
    next_states = []
    overs = []

    state = env.reset()
    over = False
    while not over:
        action = get_action(state)
        next_state, reward, over, _ = env.step(action)

        states.append(state)
        rewards.append(reward)
        actions.append(action)
        next_states.append(next_state)
        overs.append(over)

        state = next_state

    states = torch.FloatTensor(np.array(states)).reshape(-1, 4)
    rewards = torch.FloatTensor(rewards).reshape(-1, 1)
    actions = torch.LongTensor(actions).reshape(-1, 1)  # store action indices as LongTensor
    next_states = torch.FloatTensor(np.array(next_states)).reshape(-1, 4)
    overs = torch.FloatTensor(overs).reshape(-1, 1)

    return states, rewards, actions, next_states, overs

# Evaluate the current policy; optionally render while playing
def test(play):
    state = env.reset()
    reward_sum = 0
    over = False
    while not over:
        action = get_action(state)
        state, reward, over, _ = env.step(action)
        reward_sum += reward
        if play and random.random() < 0.2:
            display.clear_output(wait=True)
            show()
    return reward_sum

# Advantage function (GAE with gamma = 0.98, lambda = 0.95)
def get_advantages(deltas):
    advantages = []
    # traverse the TD errors in reverse
    s = 0.0
    for delta in deltas[::-1]:
        s = 0.98 * 0.95 * s + delta
        advantages.append(s)
    # restore chronological order
    advantages.reverse()
    return advantages

def train():
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer_td = torch.optim.Adam(model_td.parameters(), lr=1e-2)
    loss_fn = torch.nn.MSELoss()

    # play N episodes; train M times on each episode's data
    for epoch in range(500):
        states, rewards, actions, next_states, overs = get_data()

        # compute values and TD targets
        values = model_td(states)
        targets = model_td(next_states).detach()
        targets = targets * 0.98
        targets = targets * (1 - overs)
        targets += rewards

        # TD errors and GAE advantages
        deltas = (targets - values).squeeze(dim=1).tolist()
        advantages = get_advantages(deltas)
        advantages = torch.FloatTensor(advantages).reshape(-1, 1)

        # probabilities of the taken actions under the old policy
        old_probs = model(states)
        old_probs = old_probs.gather(dim=1, index=actions)
        old_probs = old_probs.detach()

        for _ in range(10):
            new_probs = model(states)
            new_probs = new_probs.gather(dim=1, index=actions)
            ratios = new_probs / old_probs

            # compute the clipped and unclipped surrogate losses and take the minimum
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 0.8, 1.2) * advantages
            loss = -torch.min(surr1, surr2)
            loss = loss.mean()

            # recompute the values and the temporal-difference loss
            values = model_td(states)
            loss_td = loss_fn(values, targets)

            # update the parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            optimizer_td.zero_grad()
            loss_td.backward()
            optimizer_td.step()

        if epoch % 50 == 0:
            test_result = sum([test(play=False) for _ in range(10)]) / 10
            print(epoch, test_result)

train()

test(play=True)
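One practical note: the listing above targets the classic gym API (gym earlier than 0.26), where reset() returns only the observation and step() returns four values. If you run it with gymnasium or a recent gym release, the environment calls change shape; a rough sketch of the adjustments (not part of the original article) is:

import gymnasium as gym  # pip install gymnasium

env = gym.make('CartPole-v1', render_mode='rgb_array')

state, info = env.reset()                        # reset() now returns (obs, info)
action = env.action_space.sample()
next_state, reward, terminated, truncated, info = env.step(action)
over = terminated or truncated                   # the episode ends on either flag

frame = env.render()                             # render() uses the mode set in make()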
