
Reinforcement Learning Classic Algorithm Notes (15): Soft Actor-Critic Implementation

Algorithm Overview

Soft Actor-Critic (SAC) is an off-policy algorithm. Compared with on-policy methods such as PPO, it is more sample efficient; compared with DDPG and its variant D4PG, SAC learns a stochastic rather than deterministic policy.

SAC is built within the Maximum Entropy Reinforcement Learning framework, which deliberately keeps the policy stochastic. This makes it particularly well suited to robot control problems, to the point that it can be used directly in real environments.

Maximizing policy entropy also means the policy and trajectory spaces are explored more thoroughly than with deterministic algorithms. For states where more than one action is optimal, SAC outputs a probability distribution over actions rather than committing to a single one.
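
Concretely, the maximum-entropy objective from the SAC paper augments the expected return with an entropy bonus; with temperature coefficient α (implicitly fixed to 1 in the implementation below) it can be written as:

J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \Big[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \Big]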

In summary, there are three benefits:

  1. The learned policy can serve as an initialization for more complex, more specific tasks.
  2. Stronger exploration. This makes it easier to find better modes under a multimodal reward, for example when a robot is required both to walk well and to save energy.
  3. Better robustness and generalization. Because the agent explores many different ways of reaching near-optimal behavior, it can adjust more easily when faced with disturbances.

For a more detailed interpretation of SAC, see 最前沿:深度解读Soft Actor-Critic算法, which walks through the full background very thoroughly.

PyTorch Implementation

import argparse
import pickle
from collections import namedtuple
from itertools import count

import os
import numpy as np

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal,MultivariateNormal
from tensorboardX import SummaryWriter


'''
Implementation of soft actor critic, dual Q network version 
Original paper: https://arxiv.org/abs/1801.01290
'''

device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()


parser.add_argument("--env_name", default="LunarLanderContinuous-v2")  # OpenAI gym environment name Pendulum-v0
parser.add_argument('--tau',  default=0.005, type=float) # target smoothing coefficient
parser.add_argument('--target_update_interval', default=1, type=int)
parser.add_argument('--epoch', default=1, type=int) # number of gradient updates per sampled batch

parser.add_argument('--learning_rate', default=3e-4, type=float)
parser.add_argument('--gamma', default=0.99, type=float) # discount factor
parser.add_argument('--capacity', default=10000, type=int) # replay buffer size
parser.add_argument('--num_episode', default=2000, type=int) #  num of  games
parser.add_argument('--batch_size', default=128, type=int) # mini batch size
parser.add_argument('--max_frame', default=500, type=int) # max frame
parser.add_argument('--seed', default=1, type=int)

# optional parameters
parser.add_argument('--hidden_size', default=64, type=int)
parser.add_argument('--render', default=False, type=bool) # show UI or not (note: with type=bool, any non-empty string on the command line parses as True)
parser.add_argument('--log_interval', default=20, type=int) # save the model every 20 episodes
parser.add_argument('--load', default=False, type=bool) # load saved model
args = parser.parse_args()
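
With these arguments the script can be launched directly from the command line, for example (sac.py here is just an assumed file name for this listing):

python sac.py --env_name LunarLanderContinuous-v2 --num_episode 2000 --batch_size 128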

class NormalizedActions(gym.ActionWrapper):
    def _action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        action = low + (action + 1.0) * 0.5 * (high - low)
        action = np.clip(action, low, high)

        return action

    def _reverse_action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        action = 2 * (action - low) / (high - low) - 1
        action = np.clip(action, low, high)

        return action
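
As a quick sanity check of the rescaling, take an illustrative action space with low = -2 and high = 2: a policy output of 0.5 in [-1, 1] maps to -2 + (0.5 + 1.0) * 0.5 * (2 - (-2)) = 1.0, and _reverse_action sends 1.0 back to 2 * (1.0 - (-2)) / (2 - (-2)) - 1 = 0.5. The wrapper is left commented out below because LunarLanderContinuous-v2 already uses actions in [-1, 1].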


# env = NormalizedActions(gym.make(args.env_name))
env = gym.make(args.env_name)
# Set seeds
env.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
min_Val = torch.tensor(1e-7).float().to(device)
Transition = namedtuple('Transition', ['s', 'a', 'r', 's_', 'd'])

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size, min_log_std=-10, max_log_std=2):
        super(Actor, self).__init__()
        self.h_size = hidden_size
        
        self.fc1 = nn.Linear(state_dim, self.h_size)
        self.fc2 = nn.Linear(self.h_size, self.h_size)
        self.mu_head = nn.Linear(self.h_size, action_dim)
        self.log_std_head = nn.Linear(self.h_size, action_dim)
        self.max_action = max_action

        self.min_log_std = min_log_std
        self.max_log_std = max_log_std

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = self.mu_head(x)
        # clamp the raw output directly; an activation such as ReLU here would prevent
        # log_std from going negative and make the lower clamp bound unreachable
        log_std_head = torch.clamp(self.log_std_head(x), self.min_log_std, self.max_log_std)
        return mu, log_std_head


class Critic(nn.Module):
    def __init__(self, state_dim, hidden_size):
        super(Critic, self).__init__()
        self.h_size = hidden_size
        self.fc1 = nn.Linear(state_dim, self.h_size)
        self.fc2 = nn.Linear(self.h_size, self.h_size)
        self.fc3 = nn.Linear(self.h_size, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Q(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size):
        super(Q, self).__init__()
        self.h_size = hidden_size
        self.fc1 = nn.Linear(state_dim + action_dim, self.h_size)
        self.fc2 = nn.Linear(self.h_size, self.h_size)
        self.fc3 = nn.Linear(self.h_size, 1)

    def forward(self, s, a):
        s = s.reshape(-1, state_dim)
        a = a.reshape(-1, action_dim)
        x = torch.cat((s, a), -1) # concatenate state and action
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
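
A minimal shape check for the three networks, assuming LunarLanderContinuous-v2's dimensions (state_dim = 8, action_dim = 2) and hidden_size = 64; this is only an illustrative sketch, not part of the training script:

# illustrative sketch: check output shapes of Actor, Critic and Q on a dummy batch
# (Q.forward reshapes with the global state_dim/action_dim, so this assumes they are 8 and 2)
dummy_s = torch.randn(4, 8)   # batch of 4 states
dummy_a = torch.randn(4, 2)   # batch of 4 actions
actor, critic, q = Actor(8, 2, 64), Critic(8, 64), Q(8, 2, 64)
mu, log_std = actor(dummy_s)  # both have shape (4, 2)
v = critic(dummy_s)           # shape (4, 1)
q_val = q(dummy_s, dummy_a)   # shape (4, 1)
print(mu.shape, log_std.shape, v.shape, q_val.shape)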


class SAC():
    def __init__(self):
        super(SAC, self).__init__()

        self.policy_net       = Actor(state_dim, action_dim, args.hidden_size).to(device)
        self.value_net        = Critic(state_dim, args.hidden_size).to(device)
        self.Target_value_net = Critic(state_dim, args.hidden_size).to(device)
        self.Q_net1           = Q(state_dim, action_dim, args.hidden_size).to(device)
        self.Q_net2           = Q(state_dim, action_dim, args.hidden_size).to(device)

        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=args.learning_rate)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=args.learning_rate)
        self.Q1_optimizer = optim.Adam(self.Q_net1.parameters(), lr=args.learning_rate)
        self.Q2_optimizer = optim.Adam(self.Q_net2.parameters(), lr=args.learning_rate)

        self.replay_buffer = [Transition] * args.capacity
        self.num_transition = 0 # pointer of replay buffer
        self.num_training = 1
        self.writer = SummaryWriter('./exp-SAC_dual_Q_network')

        self.value_criterion = nn.MSELoss()
        self.Q1_criterion = nn.MSELoss()
        self.Q2_criterion = nn.MSELoss()

        for target_param, param in zip(self.Target_value_net.parameters(), self.value_net.parameters()):
            target_param.data.copy_(param.data)

        os.makedirs('./SAC_model/', exist_ok=True)

    def select_action(self, state):
        state = torch.FloatTensor(state).to(device)
        mu, log_sigma = self.policy_net(state)
        sigma = torch.exp(log_sigma)
        dist = Normal(mu, sigma)
        z = dist.sample()
        
        # tanh squashes the action into [-1, 1]; this matches LunarLanderContinuous-v2's
        # action bounds, but other environments may need rescaling (e.g. via NormalizedActions)
        action = torch.tanh(z).detach().cpu().numpy()

        return action  # numpy array of shape (action_dim,)

    def store(self, s, a, r, s_, d):
        index = self.num_transition % args.capacity
        transition = Transition(s, a, r, s_, d)
        self.replay_buffer[index] = transition
        self.num_transition += 1

    def evaluate(self, state):
        # The action log-probability computed here is an approximation. The action is a
        # multi-dimensional random vector, so strictly it should be treated as a multivariate
        # Gaussian, i.e. MultivariateNormal(mu, diag(sigma)), which yields a single probability.
        # With Normal, log_prob returns a vector with the same shape as the action.
        # Also, the probability of batch_mu + batch_sigma * z under dist can be approximated
        # by the probability of z under the standard-normal noise distribution.
        # As an approximation, the log-probability below is averaged over the action dimensions
        # to stand in for the joint log-probability; this was verified to work on
        # LunarLanderContinuous-v2.
        batch_mu, batch_log_sigma = self.policy_net(state)
        batch_sigma = torch.exp(batch_log_sigma)
        
        dist = Normal(batch_mu, batch_sigma)
        noise = Normal(0, 1)  # standard normal noise for the reparameterization trick
        z = noise.sample()    # note: a single scalar sample, broadcast over the whole batch
        action_tmp = batch_mu + batch_sigma * z.to(device)
        action = torch.tanh(action_tmp)
        log_prob = dist.log_prob(batch_mu + batch_sigma * z.to(device)).mean(-1) - torch.log(1 - action.pow(2) + min_Val).mean(-1)

        return action, log_prob.reshape(-1,1), z, batch_mu, batch_log_sigma
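
    # For reference, the exact change-of-variables for the tanh squashing a = tanh(u),
    # u ~ N(mu, sigma), would be:
    #   log pi(a|s) = log N(u; mu, sigma) - sum_j log(1 - tanh(u_j)^2)
    # The implementation above replaces both sums over action dimensions with means,
    # which scales the term by 1/action_dim.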

    def normal(self, x, mu, sigma_sq):    # Gaussian density of action x under the policy net's distribution
        a = ( -1 * (x-mu).pow(2) / (2*sigma_sq) ).exp()
        b = 1 / ( 2 * sigma_sq * torch.FloatTensor([np.pi]).expand_as(sigma_sq) ).sqrt()      # expand_as broadcasts the scalar pi to the shape of sigma_sq
        return a*b

    def update(self):
        if self.num_training % 500 == 0:
            print("Training ... {} times ".format(self.num_training))
        s = torch.tensor([t.s for t in self.replay_buffer]).float().to(device)
        a = torch.tensor([t.a for t in self.replay_buffer]).to(device)
        r = torch.tensor([t.r for t in self.replay_buffer]).to(device)
        s_= torch.tensor([t.s_ for t in self.replay_buffer]).float().to(device)
        d = torch.tensor([t.d for t in self.replay_buffer]).float().to(device)

        for _ in range(args.epoch):
            #for index in BatchSampler(SubsetRandomSampler(range(args.capacity)), args.batch_size, False):
            index = np.random.choice(range(args.capacity), args.batch_size, replace=False)
            bn_s = s[index]
            bn_a = a[index].reshape(-1, action_dim)
            bn_r = r[index].reshape(-1, 1)
            bn_s_= s_[index]
            bn_d = d[index].reshape(-1, 1)

            target_value = self.Target_value_net(bn_s_)
            next_q_value = bn_r + (1 - bn_d) * args.gamma * target_value

            expected_value = self.value_net(bn_s)
            expected_Q1 = self.Q_net1(bn_s, bn_a)
            expected_Q2 = self.Q_net2(bn_s, bn_a)
            sample_action, log_prob, z, batch_mu, batch_log_sigma = self.evaluate(bn_s)
            expected_new_Q = torch.min(self.Q_net1(bn_s, sample_action), self.Q_net2(bn_s, sample_action))
            next_value = expected_new_Q - log_prob

            # !!!Note that the actions here are sampled from the current policy,
            # not drawn from the replay buffer. (From original paper)
            V_loss = self.value_criterion(expected_value, next_value.detach()).mean()  # J_V
            # Dual Q net
            Q1_loss = self.Q1_criterion(expected_Q1, next_q_value.detach()).mean() # J_Q
            Q2_loss = self.Q2_criterion(expected_Q2, next_q_value.detach()).mean()

            pi_loss = (log_prob - expected_new_Q).mean() # according to original paper

            self.writer.add_scalar('Loss/V_loss', V_loss, global_step=self.num_training)
            self.writer.add_scalar('Loss/Q1_loss', Q1_loss, global_step=self.num_training)
            self.writer.add_scalar('Loss/Q2_loss', Q2_loss, global_step=self.num_training)
            self.writer.add_scalar('Loss/policy_loss', pi_loss, global_step=self.num_training)

            # mini batch gradient descent
            self.value_optimizer.zero_grad()
            V_loss.backward(retain_graph=True)
            nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)
            self.value_optimizer.step()

            self.Q1_optimizer.zero_grad()
            Q1_loss.backward(retain_graph = True)
            nn.utils.clip_grad_norm_(self.Q_net1.parameters(), 0.5)
            self.Q1_optimizer.step()

            self.Q2_optimizer.zero_grad()
            Q2_loss.backward(retain_graph = True)
            nn.utils.clip_grad_norm_(self.Q_net2.parameters(), 0.5)
            self.Q2_optimizer.step()

            self.policy_optimizer.zero_grad()
            pi_loss.backward(retain_graph = True)
            nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
            self.policy_optimizer.step()

            # soft update of the target value network
            for target_param, param in zip(self.Target_value_net.parameters(), self.value_net.parameters()):
                target_param.data.copy_(target_param * (1 - args.tau) + param * args.tau)

            self.num_training += 1

    def save(self):
        torch.save(self.policy_net.state_dict(), './SAC_model/policy_net.pth')
        torch.save(self.value_net.state_dict(), './SAC_model/value_net.pth')
        torch.save(self.Q_net1.state_dict(), './SAC_model/Q_net1.pth')
        torch.save(self.Q_net2.state_dict(), './SAC_model/Q_net2.pth')
        print("====================================")
        print("Model has been saved...")
        print("====================================")

    def load(self):
        self.policy_net.load_state_dict(torch.load('./SAC_model/policy_net.pth'))
        self.value_net.load_state_dict(torch.load( './SAC_model/value_net.pth'))
        self.Q_net1.load_state_dict(torch.load('./SAC_model/Q_net1.pth'))
        self.Q_net2.load_state_dict(torch.load('./SAC_model/Q_net2.pth'))
        print("model has been load")

def main():

    agent = SAC()
    if args.load: agent.load()
    if args.render: env.render()
    print("====================================")
    print("Collection Experience...")
    print("====================================")

    ep_r = 0
    for i in range(args.num_episode):
        state = env.reset()
        for t in range(args.max_frame):
            action = agent.select_action(state)
            # print(action)
            next_state, reward, done, info = env.step(action)
            ep_r += reward
            if args.render: env.render()
            agent.store(state, action, reward, next_state, done)

            if agent.num_transition >= args.capacity and t%5==0:
                agent.update()

            state = next_state
            if done or t == args.max_frame-1:
                if i % 10 == 0:
                    print("Ep_i {}, the ep_r is {}, the t is {}".format(i, ep_r, t))
                break
        if i % args.log_interval == 0:
            agent.save()
        agent.writer.add_scalar('ep_r', ep_r, global_step=i)
        ep_r = 0


if __name__ == '__main__':
    main()
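
Training curves (V_loss, Q1_loss, Q2_loss, policy_loss and ep_r) are written by SummaryWriter to ./exp-SAC_dual_Q_network, so progress can be inspected with TensorBoard, e.g. tensorboard --logdir ./exp-SAC_dual_Q_network (assuming a TensorBoard installation is available alongside tensorboardX).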