当前位置:   article > 正文

Q Learning算法实现_qlearn算法

qlearn算法

1.算法思想

Q learning 算法是一种value-based的强化学习算法,Q是quality的缩写,Q函数 Q(state,action)表示在状态state下执行动作action的quality, 也就是能获得的Q value是多少。算法的目标是最大化Q值,通过在状态state下所有可能的动作中选择最好的动作来达到最大化期望reward。

Q learning算法使用Q table来记录不同状态下不同动作的预估Q值。在探索环境之前,Q table会被随机初始化,当agent在环境中探索的时候,它会用贝尔曼方程来迭代更新Q(s,a), 随着迭代次数的增多,agent会对环境越来越了解,Q 函数也能被拟合得越来越好,直到收敛或者到达设定的迭代结束次数。

2.算法设计

1)应用设计:

给出一个4x4的方格,给出起点和终点,红色代表物体,黑色代表陷阱,黄色代表终点。要求物体在不掉进陷阱的情况下尽可能最短路径到达终点。我们的最终目的就是得到一个训练得比较好的Q table,Q table里面的不同位置的值就是我们要训练的结果。当模型训练好了以后,agent会学会怎么去玩这个游戏,我们就可以应用此模型了。当开始一局的新的游戏,agent会根据Q table去查找到达目的地的最优路径。

 

2)Q table:

我们使用Q table来存储agent在不同state下选择不同动作可以获得的Q value。state是指物体所在的位置,action是物体在这个位置上所有能选择的动作。表的每一行表示一个state,每一列表示一个action。表中的值表示在这个state和action的最大期望未来reward。Q table最开始的时候会被初始化,比如初始化为0。如下图所示

3) 选择action:

这里会采用一个exploitation-exploration的方法,它用的Ƴ-greedy 策略选择action。

exploration:探索未知的领域,比如在某个state下随机选择一个action。exploitation :根据当前的信息,由训练的模型做出最佳的决策,即选择Q value最大的动作

做exploitation和exploration的目的是获得一种长期收益最高的策略,这个过程可能对short-term reward有损失。如果exploitation太多,那么模型比较容易陷入局部最优,但是exploration太多,模型收敛速度太慢。这就是exploitation-exploration权衡。

比如我们设Ƴ=0.9,随机化一个[0,1]的值,如果它小于,则进行exploration,随机选择动作;如果它大于,则进行exploitation,选择Q value最大的动作。

在训练过程中,在刚开始的时候会被设得比较大,让agent充分探索,然后逐步减少,agent会开始慢慢选择Q value最大的动作

由于刚开始,Ƴ比较大,agent随机选择一个action。假如在start位置时,agent选择了往右走的动作,到达(0,1)的位置。

4)Q value更新

agent从start位置执行一个right动作,走到(0,1)位置,得到了一个实时奖励 + 1分,然后我们更新Q table里第一行第二列的值。更新的方法是用贝尔曼方程(Bellman equation,下面是Q learning算法更新的方法:

      取。代入贝尔曼方程(Bellman equation):

从而更新Q table如下:

agent在每一个step的时候都会用上面的方法迭代更新一次Q table,直到Q table不在更新或者到达游戏设置的结束局数。

3.代码设计

4.代码实现

1)算法实现:brain.py

  1. import numpy as np
  2. import pandas as pd
  3. class QLearningTable:
  4. def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
  5. self.actions = actions # a list
  6. self.lr = learning_rate
  7. self.gamma = reward_decay
  8. self.epsilon = e_greedy
  9. self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
  10. def choose_action(self, observation):
  11. self.check_state_exist(observation)
  12. # action selection
  13. if np.random.uniform() < self.epsilon:
  14. # choose best action
  15. state_action = self.q_table.loc[observation, :]
  16. # some actions may have the same value, randomly choose on in these actions
  17. action = np.random.choice(state_action[state_action == np.max(state_action)].index)
  18. else:
  19. # choose random action
  20. action = np.random.choice(self.actions)
  21. return action
  22. def learn(self, s, a, r, s_):
  23. self.check_state_exist(s_)
  24. q_predict = self.q_table.loc[s, a]
  25. if s_ != 'terminal':
  26. q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal
  27. else:
  28. q_target = r # next state is terminal
  29. self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update
  30. def check_state_exist(self, state):
  31. if state not in self.q_table.index:
  32. # append new state to q table
  33. self.q_table = self.q_table.append(
  34. pd.Series(
  35. [0]*len(self.actions),
  36. index=self.q_table.columns,
  37. name=state,
  38. )
  39. )

2)表格绘制:

Draw.py其中,Death:[reward -1]

Not death[reward +1]

Success [reward+10]

  1. import numpy as np
  2. import time
  3. import sys
  4. if sys.version_info.major == 2:
  5. import Tkinter as tk
  6. else:
  7. import tkinter as tk
  8. UNIT = 40 # pixels
  9. MAZE_H = 4 # grid height
  10. MAZE_W = 4 # grid width
  11. class Maze(tk.Tk, object):
  12. def __init__(self):
  13. super(Maze, self).__init__()
  14. self.action_space = ['u', 'd', 'l', 'r']
  15. self.n_actions = len(self.action_space)
  16. self.title('maze')
  17. self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))
  18. self._build_maze()
  19. def _build_maze(self):
  20. self.canvas = tk.Canvas(self, bg='white',
  21. height=MAZE_H * UNIT,
  22. width=MAZE_W * UNIT)
  23. # create grids
  24. for c in range(0, MAZE_W * UNIT, UNIT):
  25. x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
  26. self.canvas.create_line(x0, y0, x1, y1)
  27. for r in range(0, MAZE_H * UNIT, UNIT):
  28. x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
  29. self.canvas.create_line(x0, y0, x1, y1)
  30. # create origin
  31. origin = np.array([20, 20])
  32. # hell
  33. hell1_center = origin + np.array([UNIT * 2, UNIT])
  34. self.hell1 = self.canvas.create_rectangle(
  35. hell1_center[0] - 15, hell1_center[1] - 15,
  36. hell1_center[0] + 15, hell1_center[1] + 15,
  37. fill='black')
  38. # hell
  39. hell2_center = origin + np.array([UNIT, UNIT * 2])
  40. self.hell2 = self.canvas.create_rectangle(
  41. hell2_center[0] - 15, hell2_center[1] - 15,
  42. hell2_center[0] + 15, hell2_center[1] + 15,
  43. fill='black')
  44. # create oval
  45. oval_center = origin + UNIT * 2
  46. self.oval = self.canvas.create_oval(
  47. oval_center[0] - 15, oval_center[1] - 15,
  48. oval_center[0] + 15, oval_center[1] + 15,
  49. fill='yellow')
  50. # create red rect
  51. self.rect = self.canvas.create_rectangle(
  52. origin[0] - 15, origin[1] - 15,
  53. origin[0] + 15, origin[1] + 15,
  54. fill='red')
  55. # pack all
  56. self.canvas.pack()
  57. def reset(self):
  58. self.update()
  59. time.sleep(0.5)
  60. self.canvas.delete(self.rect)
  61. origin = np.array([20, 20])
  62. self.rect = self.canvas.create_rectangle(
  63. origin[0] - 15, origin[1] - 15,
  64. origin[0] + 15, origin[1] + 15,
  65. fill='red')
  66. # return observation
  67. return self.canvas.coords(self.rect)
  68. def step(self, action):
  69. s = self.canvas.coords(self.rect)
  70. base_action = np.array([0, 0])
  71. if action == 0: # up
  72. if s[1] > UNIT:
  73. base_action[1] -= UNIT
  74. elif action == 1: # down
  75. if s[1] < (MAZE_H - 1) * UNIT:
  76. base_action[1] += UNIT
  77. elif action == 2: # right
  78. if s[0] < (MAZE_W - 1) * UNIT:
  79. base_action[0] += UNIT
  80. elif action == 3: # left
  81. if s[0] > UNIT:
  82. base_action[0] -= UNIT
  83. self.canvas.move(self.rect, base_action[0], base_action[1]) # move agent
  84. s_ = self.canvas.coords(self.rect) # next state
  85. # reward function
  86. if s_ == self.canvas.coords(self.oval):
  87. reward = 1
  88. done = True
  89. s_ = 'terminal'
  90. elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
  91. reward = -1
  92. done = True
  93. s_ = 'terminal'
  94. else:
  95. reward = 0
  96. done = False
  97. return s_, reward, done
  98. def render(self):
  99. time.sleep(0.1)
  100. self.update()
  101. def update():
  102. for t in range(10):
  103. s = env.reset()
  104. while True:
  105. env.render()
  106. a = 1
  107. s, r, done = env.step(a)
  108. if done:
  109. break
  110. if __name__ == '__main__':
  111. env = Maze()
  112. env.after(100, update)
  113. env.mainloop()

3)运行:

  1. from maze_env import Maze
  2. from RL_brain import QLearningTable
  3. def update():
  4. for episode in range(100):
  5. # initial observation
  6. observation = env.reset()
  7. while True:
  8. # fresh env
  9. env.render()
  10. # RL choose action based on observation
  11. action = RL.choose_action(str(observation))
  12. # RL take action and get next observation and reward
  13. observation_, reward, done = env.step(action)
  14. # RL learn from this transition
  15. RL.learn(str(observation), action, reward, str(observation_))
  16. # swap observation
  17. observation = observation_
  18. # break while loop when end of this episode
  19. if done:
  20. break
  21. # end of game
  22. print('game over')
  23. env.destroy()
  24. if __name__ == "__main__":
  25. env = Maze()
  26. RL = QLearningTable(actions=list(range(env.n_actions)))
  27. env.after(100, update)
  28. env.mainloop();

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/1012690
推荐阅读
相关标签
  

闽ICP备14008679号