当前位置:   article > 正文

强化学习——DQN算法_强化学习dqn

强化学习dqn

1、DQN算法介绍

        DQN算与sarsa算法和Q-learning算法类似,对于sarsa和Q-learning,我们使用一个Q矩阵,记录所有的state(状态)和action(动作)的价值,不断学习更新,最后使得机器选择在某种状态下,价值最高的action进行行动。但是当state和action的数量特别大的时候,甚至有限情况下不可数时,这时候再用Q矩阵去记录对应价值就会有很大的局限性,而DQN就是在这一点上进行了改进。在DQN中,对于每种state下采取的action所对应价值,我们将会使用神经网络来进行计算。

2、平衡车游戏的实例

        平衡车游戏在gym库中,这里需要下载gym库,这个游戏中,状态由四个数字来进行表示(我也不知道这四个数字代表什么,但是无伤大雅),接着只会有两种行动,并且reward并不需要我们进行设置,这个游戏进行过程中会自己返回reward。现在先搭建环境

  1. env = gym.make('CartPole-v1')
  2. env.reset()
  3. #打印游戏
  4. def show():
  5. plt.imshow(env.render(mode='rgb_array'))
  6. plt.axis('off')
  7. plt.show()
  8. #show()

 搭建环境后再创建两个神经网络,并且要让两个神经网络的参数一致,在后续的过程中,一个神经网络会延迟更新。这两个神经网络会以四个状态参数作为输入,然后以两个动作的评分作为输出

  1. #创建神经网络
  2. #计算动作模型,也就是真正需要使用的模型
  3. model = torch.nn.Sequential(
  4. torch.nn.Linear(4,128),
  5. torch.nn.ReLU(),
  6. torch.nn.Linear(128,2),
  7. )
  8. #经验网络,用于评估状态分数
  9. next_model = torch.nn.Sequential(
  10. torch.nn.Linear(4,128),
  11. torch.nn.ReLU(),
  12. torch.nn.Linear(128,2)
  13. )
  14. #把两个神经网络的参数统一一下
  15. next_model.load_state_dict(model.state_dict())
  16. #print(model,next_model)

接下来我们需要创建一个样本池,神经网络会在这个样本池中进行学习,随后需要不断更新我们的样本池,当我们有一个新的行动时,我们应该添加新的样本,删除旧的样本,保持样本池最大数量不变

  1. #想样本池中添加一些数据,删除一些古老的数据
  2. def update_date():
  3. old_count = len(datas)
  4. while len(datas) - old_count<200:
  5. #初始化
  6. state = env.reset()
  7. over = False
  8. while not over:
  9. #获取当前状态得到一个动作
  10. action = get_action(state)
  11. #执行动作,得到反馈
  12. next_state,reward,over,_ = env.step(action)
  13. #记录样本
  14. datas.append((state,action,reward,next_state,over))
  15. #更新状态
  16. state = next_state
  17. update_count = len(datas) - old_count
  18. drop_count = max(len(datas)-10000,0)
  19. while len(datas)>10000:
  20. datas.pop(0)
  21. return update_count,drop_count

        接下来需要进行采样,并将样本格式转换为所需要的格式

  1. #获取一批数据样本
  2. def get_sample():
  3. # 从样本池中采样
  4. samples = random.sample(datas, 64)
  5. state = np.array([i[0] for i in samples])
  6. action = np.array([i[1] for i in samples])
  7. reward = np.array([i[2] for i in samples])
  8. next_state = np.array([i[3] for i in samples])
  9. over = np.array([i[4] for i in samples])
  10. state = torch.FloatTensor(state).reshape(-1, 4)
  11. action = torch.LongTensor(action).reshape(-1, 1)
  12. reward = torch.FloatTensor(reward).reshape(-1, 1)
  13. next_state = torch.FloatTensor(next_state).reshape(-1, 4)
  14. over = torch.LongTensor(over).reshape(-1, 1)
  15. return state, action, reward, next_state, over

        接着是价值函数(直接交给神经网络就可以了)

  1. def get_value(state,action):
  2. value = model(state)
  3. value = value.gather(dim=1,index=action)
  4. return value

        接下来是获取target函数,这个函数的意义在于,我们是不知道游戏的全貌的,这样的话在一个状态下所采取的行动,不仅仅至于它本身有关,更和接下来所到达的状态和接下来应该采取的行动有关,价值value应该要想target靠近

  1. def get_target(reward,next_state,over):
  2. with torch.no_grad():
  3. target = next_model(next_state)
  4. target = target.max(dim = 1)[0]
  5. target = target.reshape(-1,1)
  6. target *= 0.98
  7. target *= (1-over)#游戏结束了就不用玩了
  8. target += reward
  9. return target

然后就是测试函数并开始训练了

  1. def train():
  2. model.train()
  3. optimizer = torch.optim.Adam(model.parameters(),lr=2e-3)
  4. loss_fn = torch.nn.MSELoss()
  5. for epoch in range(500):
  6. update_count,drop_count = update_date()
  7. for i in range(200):
  8. state,action,reward,next_state,over = get_sample()
  9. value = get_value(state,action)
  10. target = get_target(reward,next_state,over)
  11. loss = loss_fn(value,target)
  12. optimizer.zero_grad()
  13. loss.backward()
  14. optimizer.step()
  15. if (i+1)%10==0:
  16. next_model.load_state_dict(model.state_dict())
  17. if epoch%50==0:
  18. test_result = sum([tes(play=False) for _ in range(20)])/20
  19. print(f"Epoch: {epoch}, Data Size: {len(datas)}, Update: {update_count}, Drop: {drop_count}, Test Reward: {test_result}")

接下来是完整代码

  1. #这里会使用神经网络
  2. import gym
  3. import torch
  4. import random
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from IPython import display
  8. from matplotlib.animation import FuncAnimation
  9. env = gym.make('CartPole-v1')
  10. env.reset()
  11. #打印游戏
  12. def show():
  13. plt.imshow(env.render(mode='rgb_array'))
  14. plt.axis('off')
  15. plt.show()
  16. #show()
  17. #创建神经网络
  18. #计算动作模型,也就是真正需要使用的模型
  19. model = torch.nn.Sequential(
  20. torch.nn.Linear(4,128),
  21. torch.nn.ReLU(),
  22. torch.nn.Linear(128,2),
  23. )
  24. #经验网络,用于评估状态分数
  25. next_model = torch.nn.Sequential(
  26. torch.nn.Linear(4,128),
  27. torch.nn.ReLU(),
  28. torch.nn.Linear(128,2)
  29. )
  30. #把两个神经网络的参数统一一下
  31. next_model.load_state_dict(model.state_dict())
  32. #print(model,next_model)
  33. #定义动作函数
  34. def get_action(state):
  35. if random.random()<0.01:
  36. return random.choice([0,1])
  37. state = torch.FloatTensor(state).reshape(1,4)
  38. return model(state).argmax().item()
  39. #样本池
  40. datas = []
  41. #想样本池中添加一些数据,删除一些古老的数据
  42. def update_date():
  43. old_count = len(datas)
  44. while len(datas) - old_count<200:
  45. #初始化
  46. state = env.reset()
  47. over = False
  48. while not over:
  49. #获取当前状态得到一个动作
  50. action = get_action(state)
  51. #执行动作,得到反馈
  52. next_state,reward,over,_ = env.step(action)
  53. #记录样本
  54. datas.append((state,action,reward,next_state,over))
  55. #更新状态
  56. state = next_state
  57. update_count = len(datas) - old_count
  58. drop_count = max(len(datas)-10000,0)
  59. while len(datas)>10000:
  60. datas.pop(0)
  61. return update_count,drop_count
  62. #获取一批数据样本
  63. def get_sample():
  64. # 从样本池中采样
  65. samples = random.sample(datas, 64)
  66. state = np.array([i[0] for i in samples])
  67. action = np.array([i[1] for i in samples])
  68. reward = np.array([i[2] for i in samples])
  69. next_state = np.array([i[3] for i in samples])
  70. over = np.array([i[4] for i in samples])
  71. state = torch.FloatTensor(state).reshape(-1, 4)
  72. action = torch.LongTensor(action).reshape(-1, 1)
  73. reward = torch.FloatTensor(reward).reshape(-1, 1)
  74. next_state = torch.FloatTensor(next_state).reshape(-1, 4)
  75. over = torch.LongTensor(over).reshape(-1, 1)
  76. return state, action, reward, next_state, over
  77. def get_value(state,action):
  78. value = model(state)
  79. value = value.gather(dim=1,index=action)
  80. return value
  81. def get_target(reward,next_state,over):
  82. with torch.no_grad():
  83. target = next_model(next_state)
  84. target = target.max(dim = 1)[0]
  85. target = target.reshape(-1,1)
  86. target *= 0.98
  87. target *= (1-over)
  88. target += reward
  89. return target
  90. #测试函数
  91. def tes(play):
  92. #初始化
  93. state = env.reset()
  94. #记录reward之和,越大越好
  95. reward_sum = 0
  96. over = False
  97. while not over:
  98. #获取动作
  99. action = get_action(state)
  100. #执行动作
  101. state,reward,over,_ = env.step(action)
  102. reward_sum +=reward
  103. #打印动画
  104. if play and random.random()<0.2:#跳帧
  105. display.clear_output(wait=True)
  106. show()
  107. return reward_sum
  108. def train():
  109. model.train()
  110. optimizer = torch.optim.Adam(model.parameters(),lr=2e-3)
  111. loss_fn = torch.nn.MSELoss()
  112. for epoch in range(500):
  113. update_count,drop_count = update_date()
  114. for i in range(200):
  115. state,action,reward,next_state,over = get_sample()
  116. value = get_value(state,action)
  117. target = get_target(reward,next_state,over)
  118. loss = loss_fn(value,target)
  119. optimizer.zero_grad()
  120. loss.backward()
  121. optimizer.step()
  122. if (i+1)%10==0:
  123. next_model.load_state_dict(model.state_dict())
  124. if epoch%50==0:
  125. test_result = sum([tes(play=False) for _ in range(20)])/20
  126. print(f"Epoch: {epoch}, Data Size: {len(datas)}, Update: {update_count}, Drop: {drop_count}, Test Reward: {test_result}")
  127. train()

为什么需要两个神经网络?

这个两个神经网络其实是一致的,只是其中一个会延迟更新参数。

这个原因在于假如在一个游戏中,我们的目标状态并不是固定,可能是一直变换的,就如这个游戏中,平衡的状态是多种多样的,那么我们一直跟踪这个目标就会变得困难,这时我们不妨固定住某一个曾经是目标的状态,让机器先尝试去达到这种状态,再去跟踪下一个固定目标状态,这样的方式会使得机器更容易找到目标状态。这也是为什么需要一个一样的,但是延迟更新的神经网络。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/296014
推荐阅读
相关标签
  

闽ICP备14008679号