赞
踩
Q学习和Sarsa的区别在于动作价值函数的更新公式不同。
Sarsa:
$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \eta \left( R_{t+1} + \gamma\, Q(s_{t+1},a_{t+1}) - Q(s_t,a_t) \right)$$
Q学习:
$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \eta \left( R_{t+1} + \gamma \max_{a} Q(s_{t+1},a) - Q(s_t,a_t) \right)$$
Sarsa 更新时需要先求出状态 $s_{t+1}$ 下的下一步动作才能进行更新,而 Q 学习使用状态 $s_{t+1}$ 下动作价值函数的最大值来进行更新。由于 Sarsa 使用下一个动作 $a_{t+1}$ 来更新动作价值函数 $Q$,因此 Sarsa 算法的特征之一是 $Q$ 的更新依赖于求取 $a_{t+1}$ 的策略,即"策略依赖型"。而在 Q 学习中,$\epsilon$-贪婪法产生的随机性不进入更新公式,因此 Q 学习的收敛性优于 Sarsa。
和Sarsa类似
#导入所使用的包
import numpy as np
import pylab as plt
%matplotlib inline
# Rows are states S0-S7; columns are the moves [up, right, down, left].
# np.nan marks a wall (move not allowed), 1 marks a possible move.
theta_0 = np.array([[np.nan, 1, 1, np.nan],       # S0
                    [np.nan, 1, np.nan, 1],       # S1
                    [np.nan, np.nan, 1, 1],       # S2
                    [1, 1, 1, np.nan],            # S3
                    [np.nan, np.nan, 1, 1],       # S4
                    [1, np.nan, np.nan, np.nan],  # S5
                    [1, np.nan, np.nan, np.nan],  # S6
                    [1, 1, np.nan, np.nan],       # S7
                    ])  # S8 is the goal, so it has no policy row

[a, b] = theta_0.shape
# Random initial action-value table; multiplying by theta_0 keeps the
# wall entries as nan so they are never selected by nanargmax/nanmax.
Q = np.random.rand(a, b) * theta_0


def simple_convert_into_pi_from_theta(theta):
    """Normalise each row of ``theta`` into move probabilities.

    Valid moves in a row share equal probability; nan (wall) entries
    become 0 in the returned policy matrix.
    """
    rows, cols = theta.shape
    pi = np.zeros((rows, cols))
    for r in range(rows):
        pi[r, :] = theta[r, :] / np.nansum(theta[r, :])
    return np.nan_to_num(pi)  # nan -> 0


# Derive the initial (uniform over valid moves) random policy.
pi_0 = simple_convert_into_pi_from_theta(theta_0)
print(f'初始策略:\n{pi_0}')
初始策略:
[[0. 0.5 0.5 0. ]
[0. 0.5 0. 0.5 ]
[0. 0. 0.5 0.5 ]
[0.33333333 0.33333333 0.33333333 0. ]
[0. 0. 0.5 0.5 ]
[1. 0. 0. 0. ]
[1. 0. 0. 0. ]
[0.5 0.5 0. 0. ]]
# Choose an action a for state s with the epsilon-greedy policy.
def get_action(s, Q, epsilon, pi_0):
    """Return the action (0=up, 1=right, 2=down, 3=left) for state ``s``.

    With probability ``epsilon`` an action is sampled from the random
    policy ``pi_0``; otherwise the greedy action (argmax over the valid,
    non-nan entries of ``Q[s, :]``) is taken.
    """
    direction = ['up', 'right', 'down', 'left']
    if np.random.rand() < epsilon:
        # Explore: sample according to the uniform-over-valid-moves policy.
        next_direction = np.random.choice(direction, p=pi_0[s, :])
    else:
        # Exploit: pick the valid action with the highest Q value.
        next_direction = direction[np.nanargmax(Q[s, :])]
    # BUGFIX: the original also assigned ``s_next = s - 3`` in the 'up'
    # branch — dead code copied from get_s_next; it has been removed.
    return direction.index(next_direction)


# Determine the state reached one step after taking action a in state s.
def get_s_next(s, a, Q, epsilon, pi_0):
    """Return the next state after taking action ``a`` in state ``s``.

    States 0-8 lie on a 3x3 grid, so up/down changes the index by -3/+3
    and right/left by +1/-1.  ``Q``, ``epsilon`` and ``pi_0`` are unused
    but kept for interface compatibility with the original code.
    """
    # Offsets indexed by action: [up, right, down, left].
    deltas = [-3, 1, 3, -1]
    return s + deltas[a]
def Q_learing(s, a, r, s_next, a_next, Q, eta, gamma):
    """Apply one Q-learning update to ``Q[s, a]`` and return ``Q``.

    Unlike Sarsa, the bootstrap term uses max_a' Q(s_next, a') rather
    than Q(s_next, a_next); ``a_next`` is unused but kept so the
    signature matches the Sarsa variant.  State 8 is the goal, which
    has no Q row, so the terminal update omits the bootstrap term.
    NOTE: the name keeps the original 'Q_learing' spelling because the
    training loop calls it by this name.
    """
    target = r if s_next == 8 else r + gamma * np.nanmax(Q[s_next, :])
    Q[s, a] += eta * (target - Q[s, a])
    return Q
# Run one episode through the maze with Q-learning, updating Q online.
def goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi):
    """Walk from S0 to the goal S8, returning [state/action history, Q].

    The history is a list of [state, action] pairs; the final entry is
    [8, nan] because no action is taken at the goal.
    """
    s = 0  # start at S0
    a = a_next = get_action(s, Q, epsilon, pi)
    s_a_history = [[0, np.nan]]  # log of the agent's moves

    while True:
        a = a_next
        # Fill in the action taken from the most recently logged state.
        s_a_history[-1][1] = a
        s_next = get_s_next(s, a, Q, epsilon, pi)
        s_a_history.append([s_next, np.nan])

        if s_next == 8:  # goal reached: terminal reward, no next action
            r = 1
            a_next = np.nan
        else:
            r = 0
            a_next = get_action(s_next, Q, epsilon, pi)

        Q = Q_learing(s, a, r, s_next, a_next, Q, eta, gamma)

        if s_next == 8:
            break
        s = s_next

    return [s_a_history, Q]
eta = 0.1      # learning rate
gamma = 0.9    # discount factor
epsilon = 0.5  # initial exploration rate for epsilon-greedy
v = np.nanmax(Q, axis=1)  # per-state value estimate: max over valid actions
is_continue = True
episode = 1
V = []  # history of state-value vectors, one entry per episode
V.append(np.nanmax(Q, axis=1))
while is_continue:
    print(f'当前回合:{episode}')
    epsilon /= 2  # anneal exploration each episode
    [s_a_history, Q] = goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0)
    # Change in state values since the previous episode.
    new_v = np.nanmax(Q, axis=1)
    print(np.sum(np.abs(new_v - v)))
    # BUGFIX: the original never updated ``v`` nor appended to ``V``, so
    # the printed delta was always measured against the initial values and
    # the animation below would have only a single frame.
    v = new_v
    V.append(v)
    episode += 1
    if episode > 100:
        break
from matplotlib import animation
from IPython.display import HTML
import matplotlib.cm as cm

fig = plt.figure(figsize=(5, 5))
ax = plt.gca()

# Draw the red walls of the maze.
for xs, ys in ([1, 1], [0, 1]), ([1, 2], [2, 2]), ([2, 2], [2, 1]), ([2, 3], [1, 1]):
    plt.plot(xs, ys, color='red', linewidth=2)

# Label the states S0-S8 on the 3x3 grid (S0 top-left, S8 bottom-right).
for idx in range(9):
    plt.text(idx % 3 + 0.5, 2.5 - idx // 3, f'S{idx}', size=14, ha='center')
plt.text(0.5, 2.3, 'START', ha='center')
plt.text(2.5, 0.3, 'GOAL', ha='center')

# Set the drawing range and hide the axis ticks/labels.
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
plt.tick_params(axis='both', which='both', bottom='off', top='off',
                labelbottom='off', right='off', left='off', labelleft='off')

# The agent's current position S0, drawn as a green circle.
line, = ax.plot([0.5], [2.5], marker="o", color='g', markersize=60)


def init():
    line.set_data([], [])
    return (line,)


def animate(i):
    # Paint each non-goal state with a color proportional to its value
    # at episode i; the goal S8 is always painted at full scale.
    for idx in range(8):
        line, = ax.plot([idx % 3 + 0.5], [2.5 - idx // 3], marker="s",
                        color=cm.jet(V[i][idx]), markersize=85)
    line, = ax.plot([2.5], [0.5], marker="s", color=cm.jet(1.0), markersize=85)
    return (line,)


# Build the animation: one frame per recorded value vector.
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(V),
                               interval=200, repeat=False)
HTML(anim.to_jshtml())
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。