
DQN: Principles and Code Implementation

Reference: 王树森 (Wang Shusen), Reinforcement Learning: the book, the course, and the accompanying code.


1. Basic Concepts

Discounted return:
$$U_t = R_t + \gamma\cdot R_{t+1} + \gamma^2\cdot R_{t+2} + \cdots + \gamma^{n-t}\cdot R_n.$$
Action-value function:
$$Q_\pi(s_t, a_t) = \mathbb{E}\big[\,U_t \mid S_t = s_t, A_t = a_t\,\big].$$
Optimal action-value function (maximizing over policies removes the dependence on $\pi$):
$$Q^\star(s_t, a_t) = \max_{\pi} Q_\pi(s_t, a_t), \quad \forall\, s_t \in \mathcal{S},\ a_t \in \mathcal{A}.$$
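As a quick sanity check of the discounted-return definition, here is a minimal Python sketch (the function name discounted_return and the sample rewards are illustrative, not part of the referenced code):

def discounted_return(rewards, gamma=0.99):
    """Compute U_t = R_t + gamma*R_{t+1} + ... + gamma^{n-t}*R_n for the first step."""
    u = 0.0
    for r in reversed(rewards):       # accumulate from the last reward backwards
        u = r + gamma * u
    return u

# Example: rewards [1, 2, 3] with gamma = 0.9 give 1 + 0.9*2 + 0.81*3 = 5.23.
print(discounted_return([1.0, 2.0, 3.0], gamma=0.9))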


2. The Q-learning Algorithm
$$
\begin{aligned}
&\text{Observe a transition } (s_t, a_t, r_t, s_{t+1}).\\
&\text{TD target: } y_t = r_t + \gamma\cdot\max_{a} Q^\star(s_{t+1}, a).\\
&\text{TD error: } \delta_t = Q^\star(s_t, a_t) - y_t.\\
&\text{Update: } Q^\star(s_t, a_t) \leftarrow Q^\star(s_t, a_t) - \alpha\cdot\delta_t.
\end{aligned}
$$
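Before moving to DQN, a minimal tabular sketch of this update rule may help (the toy Q-table and the function q_learning_step are assumptions for illustration, not from the referenced code):

import numpy as np

def q_learning_step(Q, s, a, r, s_next, alpha=0.1, gamma=0.99):
    """One tabular Q-learning update; Q has shape [num_states, num_actions]."""
    y = r + gamma * Q[s_next].max()    # TD target
    delta = Q[s, a] - y                # TD error
    Q[s, a] -= alpha * delta           # move Q(s, a) toward the TD target
    return Q

# Toy example: 5 states, 2 actions.
Q = np.zeros((5, 2))
Q = q_learning_step(Q, s=0, a=1, r=1.0, s_next=2)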

3. The DQN Algorithm
$$
\begin{aligned}
&\text{Observe a transition } (s_t, a_t, r_t, s_{t+1}).\\
&\text{TD target: } y_t = r_t + \gamma\cdot\max_{a} Q(s_{t+1}, a;\, \boldsymbol{w}).\\
&\text{TD error: } \delta_t = Q(s_t, a_t;\, \boldsymbol{w}) - y_t.\\
&\text{Update: } \boldsymbol{w} \leftarrow \boldsymbol{w} - \alpha\cdot\delta_t\cdot\frac{\partial Q(s_t, a_t;\, \boldsymbol{w})}{\partial \boldsymbol{w}}.
\end{aligned}
$$
Note: DQN approximates the optimal action-value function $Q^\star(s, a)$ with a neural network $Q(s, a;\, \boldsymbol{w})$.
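Written as one PyTorch gradient step on a single transition, the update looks roughly as follows (a simplified sketch with no target network or replay buffer, and a hypothetical helper dqn_step; the complete implementation in Section 5 is the one actually used):

import torch

def dqn_step(Q, optimizer, s, a, r, s_next, done, gamma=0.99):
    """One gradient step on a single transition; Q maps a state tensor to per-action values."""
    with torch.no_grad():
        y = r + gamma * Q(s_next).max() * (1.0 - done)   # TD target y_t
    delta = Q(s)[a] - y                                  # TD error delta_t
    loss = 0.5 * delta ** 2                              # squared TD error
    optimizer.zero_grad()
    loss.backward()                                      # gradient of delta^2/2 w.r.t. w
    optimizer.step()                                     # w <- w - alpha * delta * dQ/dw
    return loss.item()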


4. Experience Replay (a technique for training DQN)

Advantages: collected transitions can be reused for many updates instead of being discarded after one step; and since consecutive transitions are strongly correlated, sampling uniformly at random from the replay buffer breaks this correlation.

Hyperparameter: the capacity (length) of the replay buffer.
$$
\begin{aligned}
&\bullet\ \text{Find } \boldsymbol{w} \text{ by minimizing } L(\boldsymbol{w}) = \frac{1}{T}\sum_{t=1}^{T}\frac{\delta_t^2}{2}.\\
&\bullet\ \text{Stochastic gradient descent (SGD):}\\
&\qquad\bullet\ \text{Randomly sample a transition } (s_i, a_i, r_i, s_{i+1}) \text{ from the buffer.}\\
&\qquad\bullet\ \text{Compute the TD error } \delta_i.\\
&\qquad\bullet\ \text{Stochastic gradient: } g_i = \frac{\partial\, \delta_i^2/2}{\partial \boldsymbol{w}} = \delta_i\cdot\frac{\partial Q(s_i, a_i;\, \boldsymbol{w})}{\partial \boldsymbol{w}}.\\
&\qquad\bullet\ \text{SGD: } \boldsymbol{w} \leftarrow \boldsymbol{w} - \alpha\cdot g_i.
\end{aligned}
$$
Note: in practice, minibatch SGD is used instead; each update samples several transitions and averages the stochastic gradient over the minibatch, as the buffer sketch below illustrates.
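A concept-level sketch of the replay buffer and minibatch sampling (the class SimpleReplayBuffer is a simplified, hypothetical stand-in built on collections.deque; Section 5 contains the buffer used by the actual training script):

import random
from collections import deque
import numpy as np

class SimpleReplayBuffer:
    """Fixed-capacity buffer; once full, the oldest transitions are discarded."""

    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, done, next_state):
        self.buffer.append((state, action, reward, done, next_state))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, dones, next_states = map(np.array, zip(*batch))
        return states, actions, rewards, dones, next_states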


5. Implementation: DQN Trained with Experience Replay
"""4.3节DQN算法实现。
"""
import argparse
from collections import defaultdict
import os
import random
from dataclasses import dataclass, field
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class QNet(nn.Module):
    """QNet.
    Input: feature
    Output: num_act of values
    """

    def __init__(self, dim_state, num_action):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_action)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class DQN:
    def __init__(self, dim_state=None, num_action=None, discount=0.9):
        self.discount = discount
        self.Q = QNet(dim_state, num_action)
        self.target_Q = QNet(dim_state, num_action)
        self.target_Q.load_state_dict(self.Q.state_dict())

    def get_action(self, state):
        qvals = self.Q(state)
        return qvals.argmax()

    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
        # Q-values of the actions actually taken: Q(s, a; w).
        qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
        # Max Q-values of the next states, from the target network (detached from the graph).
        next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
        # Mean squared TD error between Q(s, a; w) and the TD target.
        loss = F.mse_loss(qvals, r_batch + self.discount * next_qvals * (1 - d_batch))
        return loss


def soft_update(target, source, tau=0.01):
    """
    update target by target = tau * source + (1 - tau) * target.
    """
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


@dataclass
class ReplayBuffer:
    maxsize: int
    size: int = 0
    state: list = field(default_factory=list)
    action: list = field(default_factory=list)
    next_state: list = field(default_factory=list)
    reward: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def push(self, state, action, reward, done, next_state):
        """Store one transition, overwriting the oldest entry once the buffer is full.

        :param state: current state
        :param action: action taken
        :param reward: reward received
        :param done: whether the episode ended at this step
        :param next_state: next state
        """
        if self.size < self.maxsize:
            self.state.append(state)
            self.action.append(action)
            self.reward.append(reward)
            self.done.append(done)
            self.next_state.append(next_state)
        else:
            position = self.size % self.maxsize
            self.state[position] = state
            self.action[position] = action
            self.reward[position] = reward
            self.done[position] = done
            self.next_state[position] = next_state
        self.size += 1

    def sample(self, n):
        total_number = self.size if self.size < self.maxsize else self.maxsize
        indices = np.random.randint(total_number, size=n)
        state = [self.state[i] for i in indices]
        action = [self.action[i] for i in indices]
        reward = [self.reward[i] for i in indices]
        done = [self.done[i] for i in indices]
        next_state = [self.next_state[i] for i in indices]
        return state, action, reward, done, next_state


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.no_cuda:
        torch.cuda.manual_seed(args.seed)


def train(args, env, agent):
    replay_buffer = ReplayBuffer(10_000)
    optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
    optimizer.zero_grad()

    epsilon = 1
    epsilon_max = 1
    epsilon_min = 0.1
    episode_reward = 0
    episode_length = 0
    max_episode_reward = -float("inf")
    log = defaultdict(list)
    log["loss"].append(0)

    agent.Q.train()
    state, _ = env.reset(seed=args.seed)
    for i in range(args.max_steps):
        if np.random.rand() < epsilon or i < args.warmup_steps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(torch.from_numpy(state).to(args.device))
            action = action.item()
        env.render()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        episode_length += 1

        replay_buffer.push(state, action, reward, done, next_state)
        state = next_state

        if done:
            log["episode_reward"].append(episode_reward)
            log["episode_length"].append(episode_length)

            print(
                f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log['loss'][-1]:.1e}, epsilon={epsilon:.3f}")

            # Save the model whenever a new best episode reward is reached.
            if episode_reward > max_episode_reward:
                save_path = os.path.join(args.output_dir, "model.bin")
                torch.save(agent.Q.state_dict(), save_path)
                max_episode_reward = episode_reward

            episode_reward = 0
            episode_length = 0
            epsilon = max(epsilon - (epsilon_max - epsilon_min) * args.epsilon_decay, 1e-1)
            state, _ = env.reset()

        if i > args.warmup_steps:
            bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
            bs = torch.tensor(np.array(bs), dtype=torch.float32, device=args.device)
            ba = torch.tensor(ba, dtype=torch.long, device=args.device)
            br = torch.tensor(br, dtype=torch.float32, device=args.device)
            bd = torch.tensor(bd, dtype=torch.float32, device=args.device)
            bns = torch.tensor(np.array(bns), dtype=torch.float32, device=args.device)

            loss = agent.compute_loss(bs, ba, br, bd, bns)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            log["loss"].append(loss.item())

            soft_update(agent.target_Q, agent.Q)

    # Plot the training loss and episode-reward curves.
    plt.plot(log["loss"])
    plt.yscale("log")
    plt.savefig(f"{args.output_dir}/loss.png", bbox_inches="tight")
    plt.close()

    plt.plot(np.cumsum(log["episode_length"]), log["episode_reward"])
    plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
    plt.close()


def eval(args, env, agent):
    # Rebuild the agent and load the best checkpoint saved during training.
    agent = DQN(args.dim_state, args.num_action)
    model_path = os.path.join(args.output_dir, "model.bin")
    agent.Q.load_state_dict(torch.load(model_path))

    episode_length = 0
    episode_reward = 0
    state, _ = env.reset()
    for i in range(5000):
        episode_length += 1
        action = agent.get_action(torch.from_numpy(state)).item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        env.render()
        episode_reward += reward

        state = next_state
        if done:
            print(f"episode reward={episode_reward}, episode length={episode_length}")
            state, _ = env.reset()
            episode_length = 0
            episode_reward = 0


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    parser.add_argument("--max_steps", default=50000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--warmup_steps", default=5000, type=int, help="Warmup steps without training.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--epsilon_decay", default=1 / 1000, type=float, help="Epsilon-greedy algorithm decay coefficient.")
    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
    args = parser.parse_args()

    # Force both training and evaluation regardless of the CLI flags above.
    args.do_train = True
    args.do_eval = True

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)

    env = gym.make(args.env, render_mode="human")
    set_seed(args)
    agent = DQN(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
    agent.Q.to(args.device)
    agent.target_Q.to(args.device)

    if args.do_train:
        train(args, env, agent)

    if args.do_eval:
        eval(args, env, agent)


if __name__ == "__main__":
    main()
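Assuming the script is saved as dqn.py (the filename is not given in the original), training and evaluation can be launched with:

python dqn.py --do_train --do_eval

The best-performing checkpoint is written to output/model.bin, and the loss and episode-reward curves are saved as loss.png and episode_reward.png in the same directory.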


