The Actor-Critic algorithm combines the strengths of policy-gradient and value-function methods. It consists of two parts: the Actor (policy network) and the Critic (value network).
The AC algorithm aims to alleviate the high-variance problem of the policy gradient method. To do so, we introduce the advantage function $A^{\pi}(s_t,a_t)$, which measures how much better the current state-action pair is than the average level:
$$A^{\pi}(s_t,a_t)=Q^{\pi}(s_t,a_t)-V^{\pi}(s_t)$$
Subtracting this average level reduces the variance. Note, however, that the quantity being subtracted is $V^{\pi}(s_t)$, the value of state $s_t$, i.e. the mean return starting from state $s_t$, not the mean return over all states $s$.
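As a tiny numerical illustration (the numbers below are made up for this example only), subtracting the state value as a baseline shrinks the magnitude of the learning signal without changing which action looks better:

import numpy as np

q_values = np.array([10.2, 11.0])   # hypothetical Q(s_t, a) for two actions
v_value = 10.6                      # hypothetical V(s_t): the average return from s_t
advantages = q_values - v_value     # roughly [-0.4, 0.4]: centered around zero
print(advantages)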
The policy gradient can accordingly be rewritten as:
$$\nabla_\theta J(\theta)\propto\mathbb{E}_{\pi_\theta}\left[A^\pi(s_t,a_t)\nabla_\theta\log\pi_\theta(a_t\mid s_t)\right]$$
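In PyTorch this gradient is usually realized by minimizing a surrogate loss whose gradient matches the negated expression above. A minimal sketch with made-up tensors (`probs` and `advantage` here are placeholders of our own, not variables from the code later in this post):

import torch

probs = torch.tensor([0.7, 0.3], requires_grad=True)  # hypothetical π_θ(a_t|s_t)
advantage = torch.tensor([0.4, -0.4])                  # hypothetical A^π(s_t, a_t)

# Minimizing this loss performs gradient ascent on J(θ);
# the advantage is treated as a constant, so only the policy receives gradients.
actor_loss = -(advantage * torch.log(probs)).mean()
actor_loss.backward()
print(probs.grad)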
This is the A2C (Advantage Actor-Critic) algorithm. It is derived from the A3C algorithm, which adds multiple parallel worker processes, each with its own copy of the network and its own environment for training.
Temporal-difference (TD) estimation effectively reduces variance but is biased, while Monte Carlo estimation is unbiased but suffers from high variance. The two are therefore usually combined into a new estimator, the TD(λ) estimate, which mixes multi-step returns; applied to the advantage function, this is known as generalized advantage estimation (GAE):
$$A^{\mathrm{GAE}(\gamma,\lambda)}(s_t,a_t)=\sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}=\sum_{l=0}^{\infty}(\gamma\lambda)^l\left(r_{t+l}+\gamma V^{\pi}(s_{t+l+1})-V^{\pi}(s_{t+l})\right)$$
where $\delta_{t+l}$ is the TD error at time step $t+l$:
$$\delta_{t+l}=r_{t+l}+\gamma V^{\pi}(s_{t+l+1})-V^{\pi}(s_{t+l})$$
When $\lambda=0$, GAE degenerates to the one-step TD error:
$$A^{\mathrm{GAE}(\gamma,0)}(s_t,a_t)=\delta_t=r_t+\gamma V^{\pi}(s_{t+1})-V^{\pi}(s_t)$$
When $\lambda=1$, it becomes the Monte Carlo estimate:
$$A^{\mathrm{GAE}(\gamma,1)}(s_t,a_t)=\sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}=\sum_{l=0}^{\infty}\gamma^l\delta_{t+l}$$
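The GAE sum can be computed efficiently with a backward recursion over a finite trajectory. The sketch below is our own helper (the name compute_gae and its arguments are assumptions, not part of the code that follows), assuming per-step rewards, value estimates and done flags are available as arrays:

import numpy as np

def compute_gae(rewards, values, next_values, dones, gamma=0.98, lam=0.95):
    # One entry per time step of a single trajectory.
    advantages = np.zeros(len(rewards))
    gae = 0.0
    # Work backwards so each step reuses the discounted sum of later TD errors.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_values[t] * (1 - dones[t]) - values[t]
        gae = delta + gamma * lam * (1 - dones[t]) * gae
        advantages[t] = gae
    return advantages

With lam=0 this returns exactly the one-step TD errors, and with lam=1 it reduces to the discounted sum of all future TD errors, matching the two special cases above.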
import gymnasium as gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils
# Define the policy network
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

# Define the value network: it outputs a single state value
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
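As a quick sanity check (the 4-dimensional dummy input below is our own assumption, matching CartPole's observation size), the two networks can be exercised like this:

policy_net = PolicyNet(state_dim=4, hidden_dim=128, action_dim=2)
value_net = ValueNet(state_dim=4, hidden_dim=128)
dummy_state = torch.zeros(1, 4)     # a batch containing one observation
print(policy_net(dummy_state))      # action probabilities that sum to 1
print(value_net(dummy_state))       # a single state-value estimate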
Now define the main body of the A2C algorithm, including two functions: one for taking actions and one for updating the network parameters.
class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # Policy network (Actor)
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        # Value network (Critic)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        # Optimizer for the policy network
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        # Optimizer for the value network
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # TD target
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        # TD error, used here as the advantage estimate
        td_delta = td_target - self.critic(states)
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # Mean squared error loss for the value network
        critic_loss = torch.mean(
            F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()  # compute gradients for the policy network
        critic_loss.backward()  # compute gradients for the value network
        self.actor_optimizer.step()  # update the policy network parameters
        self.critic_optimizer.step()  # update the value network parameters

actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                    gamma, device)

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()
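Here rl_utils is the helper module from the Hands-on RL repository; roughly speaking, train_on_policy_agent collects one episode at a time into a transition_dict and then calls agent.update. A rough equivalent written directly against the gymnasium API (where reset returns (state, info) and step returns five values) might look like the sketch below; it is an illustration of the loop structure, not the actual contents of rl_utils:

for episode in range(num_episodes):
    transition_dict = {'states': [], 'actions': [], 'next_states': [],
                       'rewards': [], 'dones': []}
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.take_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        transition_dict['states'].append(state)
        transition_dict['actions'].append(action)
        transition_dict['next_states'].append(next_state)
        transition_dict['rewards'].append(reward)
        transition_dict['dones'].append(done)
        state = next_state
    agent.update(transition_dict)

Running the original script prints training progress like the output below.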
Iteration 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.55it/s, episode=100, return=20.400]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.48it/s, episode=200, return=51.200]
Iteration 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:14<00:00, 6.91it/s, episode=300, return=151.500]
Iteration 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00, 3.88it/s, episode=400, return=256.700]
Iteration 4: 53%|███████████████████████████████████████████████████████████████████████████████▌ | 53/100 [00:17<00:10, 4.51it/s, episode=450, return=235.500]