当前位置:   article > 正文

DQN:深度Q-网络_dqn网络

dqn网络

基本概念

DQN

DQN全名叫Deep Q-Leaning Network,中文名叫深度Q-网络。DQN算法的基本思路来源于Q-Learning,不同于Q-learning,DQN的Q值不是直接通过状态值s和动作a来计算的,而是通过神经网络来计算的。

DQN算法本质上属于Q-Learning算法,在策略选择上和Q-Learning保持一致,采用 ϵ − g r e e d y \epsilon-greedy ϵgreedy策略。在Q-learning的基础上,DQN提出了两个技巧使得Q网络的更新迭代更稳定:

1、经验回放:DQN使用经验池对多条经验 ( s , a , r , s ′ ) (s, a, r, s') (s,a,r,s)进行保存,在训练的时候,随机从经验池中抽取一定数量的数据来进行训练,这样就可以不停的优化网络模型。

2、固定Q目标 Fixed-Q-Target:主要解决算法训练不稳定的问题。复制一个和原来Q网络结构一样的Target Q网络,用于计算Q目标值。DQN中有两个结构相同但是参数不同的网络,当前值( p r e d i c t Q predictQ predictQ)网络用于预测估计的Q值,目标值( t a r g e t Q targetQ targetQ)网络用于预测现实的Q值。当前值网络使用最新的参数,目标值网络会使用很久之前的参数。

其中, t a r g e t Q targetQ targetQ值的计算公式: t a r g e t Q = r + γ ∗ m a x Q ( s ′ , a ∗ ; θ ) targetQ=r+γ∗maxQ(s',a^*;θ) targetQ=r+γmaxQ(s,a;θ
p r e d i c t Q predictQ predictQ的计算公式: p r e d i c t Q = Q ( s , a ; θ ) predictQ=Q(s,a;\theta) predictQ=Q(s,a;θ)

如下图所示,使用均方差损失函数 1 m ∑ j = 1 m ( t a r g e t Q − p r e d i c t Q ) 2 \frac 1 m \sum_{j=1}^{m}(targetQ- predictQ)^2 m1j=1m(targetQpredictQ)2,通过神经网络的梯度反向传播来更新 p r e d i c t Q predictQ predictQ网络的所有参数 θ \theta θ。并且每隔N时间步长,拷贝 p r e d i c t Q predictQ predictQ网络的所有参数到 t a r g e t Q targetQ targetQ网络中。

在这里插入图片描述
简而言之,DQN使用 ϵ − g r e e d y \epsilon-greedy ϵgreedy策略来选择动作并执行,采用经验回收机制,使用经验池存储(状态,动作,价值,下一个状态)信息,存储完成之后,以批量的形式获取数据,使用均方差损失函数,采用梯度随机下降法更新当前值( p r e d i c t Q predictQ predictQ)网络的参数,进行当前值网络的训练,并每隔N时间步长,将参数同步到目标值( t a r g e t Q targetQ targetQ)网络。

DQN与Q-Learning的区别:

整体来说,DQN 与Q-Learning的目标价值以及价值的更新方式都非常相似。但是,DQN将Q-Learning与深度学习结合,用深度网络来近似动作价值函数,而Q-Learning则是采用表格存储;DQN 采用了经验回放的训练方法,从历史数据中随机采样,而Q-Learning直接采用下一个状态的数据进行学习。

DQN的算法如下图所示。
在这里插入图片描述
在这里插入图片描述

在上述代码中, Q ( ϕ j , a j ; θ ) Q(\phi_j,a_j;\theta) Q(ϕj,aj;θ)为当前值( p r e d i c t Q predictQ predictQ)网络预测的Q值, y i = r j + γ m a x a ′ Q ( ϕ j + 1 , a ′ ; θ ) y_i=r_j + \gamma max_{a'}Q(\phi_{j+1},a';\theta) yi=rj+γmaxaQ(ϕj+1,a;θ)为目标值( t a r g e t Q targetQ targetQ)网络预测的Q值。

测试代码

下面是一个DQN贪吃蛇的例子

import random
import sys
from collections import deque

import numpy as np
import pygame as pg
import tensorflow as tf
import cv2 as cv

# 参数
# 游戏帧率
FPS = 5
# 窗口宽度、高度
WINDOW_WIDTH, WINDOW_HEIGHT = 640, 480
# 组成大小
CELL_SIZE = 40
CELL_WIDTH, CELL_HEIGHT = WINDOW_WIDTH // CELL_SIZE, WINDOW_HEIGHT // CELL_SIZE
# 常用颜色
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
DARK_GREEN = (0, 155, 0)
GREEN = (0, 255, 0)
DARK_GRAY = (60, 60, 60)
RED = (255, 0, 0)
# 方向
UP = "up"
DOWN = "down"
LEFT = "left"
RIGHT = "right"
# 神经网络的输出
MOVE_UP = [1, 0, 0, 0]
MOVE_DOWN = [0, 1, 0, 0]
MOVE_LEFT = [0, 0, 1, 0]
MOVE_RIGHT = [0, 0, 0, 1]


def check_for_key_press():
    if len(pg.event.get(pg.QUIT)) > 0:
        pg.quit()
        sys.exit()
    key_up_events = pg.event.get(pg.KEYUP)
    if len(key_up_events) == 0:
        return None
    if key_up_events[0].key == pg.K_ESCAPE:
        pg.quit()
        sys.exit()
    return key_up_events[0].key


def show_start_screen():
    title_font = pg.font.Font("freesansbold.ttf", 100)
    title_surface1 = title_font.render("snake", True, WHITE, DARK_GREEN)
    title_surface2 = title_font.render("snake", True, GREEN)
    degree1 = 0
    degree2 = 0
    press_key_font = pg.font.Font("freesansbold.ttf", 18)
    press_key_surface = press_key_font.render("press a key to play", True, DARK_GRAY)
    while True:
        screen.fill(BLACK)
        # draw snake word
        rotated_surface1 = pg.transform.rotate(title_surface1, degree1)
        rotated_rect1 = rotated_surface1.get_rect()
        rotated_rect1.center = (WINDOW_WIDTH / 2, WINDOW_HEIGHT / 2)
        screen.blit(rotated_surface1, rotated_rect1)
        rotated_surface2 = pg.transform.rotate(title_surface2, degree2)
        rotated_rect2 = rotated_surface2.get_rect()
        rotated_rect2.center = (WINDOW_WIDTH / 2, WINDOW_HEIGHT / 2)
        screen.blit(rotated_surface2, rotated_rect2)
        # draw press key word
        press_key_rect = press_key_surface.get_rect()
        press_key_rect.topleft = (WINDOW_WIDTH - 200, WINDOW_HEIGHT - 30)
        screen.blit(press_key_surface, press_key_rect)
        if check_for_key_press():
            pg.event.get()
            return
        pg.display.update()
        clock.tick(FPS)
        degree1 += 3
        degree2 += 3


def test_not_ok(temp, worm):
    for body in worm:
        if temp['x'] == body['x'] and temp['y'] == body['y']:
            return True
    return False


def get_random_location(worm):
    temp = {'x': random.randint(0, CELL_WIDTH - 1), 'y': random.randint(0, CELL_HEIGHT - 1)}
    while test_not_ok(temp, worm):
        temp = {'x': random.randint(0, CELL_WIDTH - 1), 'y': random.randint(0, CELL_HEIGHT - 1)}
    return temp


# 检查贪吃蛇是否出现180度掉头
def examine_direction(pre_direction):
    if direction == UP and pre_direction == DOWN:
        return False
    if direction == DOWN and pre_direction == UP:
        return False
    if direction == LEFT and pre_direction == RIGHT:
        return False
    if direction == RIGHT and pre_direction == LEFT:
        return False
    return True


def draw_grid():
    for x in range(0, WINDOW_WIDTH, CELL_SIZE):
        pg.draw.line(screen, DARK_GRAY, (x, 0), (x, WINDOW_HEIGHT))
    for y in range(0, WINDOW_HEIGHT, CELL_SIZE):
        pg.draw.line(screen, DARK_GRAY, (0, y), (WINDOW_WIDTH, y))


def draw_worm_coord():
    for body in worm_coord:
        x = body['x'] * CELL_SIZE
        y = body['y'] * CELL_SIZE
        body_rect = pg.Rect(x, y, CELL_SIZE, CELL_SIZE)
        pg.draw.rect(screen, DARK_GREEN, body_rect)
        body_inner_rect = pg.Rect(x + 4, y + 4, CELL_SIZE - 8, CELL_SIZE - 8)
        pg.draw.rect(screen, GREEN, body_inner_rect)


def draw_apple():
    x = apple['x'] * CELL_SIZE
    y = apple['y'] * CELL_SIZE
    apple_rect = pg.Rect(x, y, CELL_SIZE, CELL_SIZE)
    pg.draw.rect(screen, WHITE, apple_rect)


def run_game(action=None):
    global direction, worm_coord, head, apple
    pre_direction = direction
    if action == MOVE_UP and direction != DOWN:
        direction = UP
    elif action == MOVE_DOWN and direction != UP:
        direction = DOWN
    elif action == MOVE_LEFT and direction != RIGHT:
        direction = LEFT
    elif action == MOVE_RIGHT and direction != LEFT:
        direction = RIGHT
    for event in pg.event.get():
        if event.type == pg.QUIT:
            pg.quit()
            sys.exit()
        elif event.type == pg.KEYUP:
            if (event.key == pg.K_LEFT or event.key == pg.K_a) and direction != RIGHT:
                direction = LEFT
            elif (event.key == pg.K_RIGHT or event.key == pg.K_d) and direction != LEFT:
                direction = RIGHT
            elif (event.key == pg.K_UP or event.key == pg.K_w) and direction != DOWN:
                direction = UP
            elif (event.key == pg.K_DOWN or event.key == pg.K_s) and direction != UP:
                direction = DOWN
            elif event.key == pg.K_ESCAPE:
                pg.quit()
                sys.exit()

    reward = 0
    # 检测贪吃蛇有没有碰到墙壁
    if worm_coord[head]['x'] == -1 or worm_coord[head]['x'] == CELL_WIDTH \
            or worm_coord[head]['y'] == -1 or worm_coord[head]['y'] == CELL_HEIGHT:
        worm_coord = [{'x': start_x, 'y': start_y},
                      {'x': start_x - 1, 'y': start_y},
                      {'x': start_x - 2, 'y': start_y}]
        direction = RIGHT
        screen_image = pg.surfarray.array3d(pg.display.get_surface())
        reward = -1
        return reward, screen_image
    # 检测贪吃蛇有没有碰到自己
    for worm_body in worm_coord[1:]:
        if worm_body['x'] == worm_coord[head]['x'] and worm_body['y'] == worm_coord[head]['y']:
            worm_coord = [{'x': start_x, 'y': start_y},
                          {'x': start_x - 1, 'y': start_y},
                          {'x': start_x - 2, 'y': start_y}]
            direction = RIGHT
            screen_image = pg.surfarray.array3d(pg.display.get_surface())
            reward = -1
            return reward, screen_image
    # 检测贪吃蛇有没有吃到苹果
    # 如果吃到苹果,不用删除末尾,相当于增加一节
    if worm_coord[head]['x'] == apple['x'] and worm_coord[head]['y'] == apple['y']:
        reward = 1
        apple = get_random_location(worm_coord)
    # 如果没有吃到苹果,删除末尾一节
    else:
        del worm_coord[-1]
    # 贪吃蛇移动逻辑
    # 如果贪吃蛇出现180度旋转,则方向和原来方向保持不变
    if not examine_direction(pre_direction):
        direction = pre_direction
    # 根据贪吃蛇方向确定新头部的位置
    new_head = {}
    if direction == UP:
        new_head = {'x': worm_coord[head]['x'], 'y': worm_coord[head]['y'] - 1}
    elif direction == DOWN:
        new_head = {'x': worm_coord[head]['x'], 'y': worm_coord[head]['y'] + 1}
    elif direction == LEFT:
        new_head = {'x': worm_coord[head]['x'] - 1, 'y': worm_coord[head]['y']}
    elif direction == RIGHT:
        new_head = {'x': worm_coord[head]['x'] + 1, 'y': worm_coord[head]['y']}
    worm_coord.insert(0, new_head)
    screen.fill(BLACK)
    draw_grid()
    draw_apple()
    draw_worm_coord()
    pg.display.update()
    clock.tick(FPS)
    screen_image = pg.surfarray.array3d(pg.display.get_surface())
    return reward, screen_image


def run():
    global screen, clock
    pg.init()
    screen = pg.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
    clock = pg.time.Clock()
    show_start_screen()
    # while True:
    #     run_game()
    #     clock.tick(FPS)
    #     show_game_over_screen(screen)


start_x, start_y = 5, 5
head = 0
worm_coord = [{'x': start_x, 'y': start_y},
              {'x': start_x - 1, 'y': start_y},
              {'x': start_x - 2, 'y': start_y}]
direction = RIGHT
apple = get_random_location(worm_coord)
# run()

# 训练参数
LEARNING_RATE = 0.99
INITIAL_EPSILON = 1.0
FINAL_EPSILON = 0.05
EXPLORE = 50000
OBSERVE = 100
REPLAY_MEMORY = 1024
BATCH = 14

tf.compat.v1.disable_eager_execution()
input_image = tf.compat.v1.placeholder("float", [None, 160, 120, 4])
action = tf.compat.v1.placeholder("float", [None, 4])


def convolutional_neural_network(input_image):
    weights = {"w_conv1": tf.Variable(tf.zeros([8, 8, 4, 32])),
               "w_conv2": tf.Variable(tf.zeros([4, 4, 32, 64])),
               "w_conv3": tf.Variable(tf.zeros([3, 3, 64, 64])),
               "w_fc4": tf.Variable(tf.zeros([128, 64])),
               "w_out": tf.Variable(tf.zeros([64, 4]))}
    bias = {"b_conv1": tf.Variable(tf.zeros([32])),
            "b_conv2": tf.Variable(tf.zeros([64])),
            "b_conv3": tf.Variable(tf.zeros([64])),
            "b_fc4": tf.Variable(tf.zeros([64])),
            "b_out": tf.Variable(tf.zeros([4]))}
    conv1 = tf.nn.relu(tf.nn.conv2d(input_image, weights["w_conv1"], strides=[1, 4, 4, 1], padding="VALID")
                       + bias["b_conv1"])
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, weights["w_conv2"], strides=[1, 2, 2, 1], padding="VALID")
                       + bias["b_conv2"])
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    conv3 = tf.nn.relu(tf.nn.conv2d(conv2, weights["w_conv3"], strides=[1, 1, 1, 1], padding="VALID")
                       + bias["b_conv3"])
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    conv3_flat = tf.reshape(conv3, [-1, 128])
    fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights["w_fc4"]) + bias["b_fc4"])
    out = tf.matmul(fc4, weights["w_out"] + bias["b_out"])
    return out


def train(input_image):
    tf.compat.v1.disable_eager_execution()
    predict_action = convolutional_neural_network(input_image)
    argmax = tf.compat.v1.placeholder("float", [None, 4])
    gt = tf.compat.v1.placeholder("float", [None])
    # 定义均方差损失函数的计算过程
    action = tf.reduce_sum(tf.multiply(predict_action, argmax))
    cost = tf.reduce_mean(tf.square(action - gt))
    # 定义机器学习过程
    optimizer = tf.compat.v1.train.AdamOptimizer(1e-2).minimize(cost)
    run()
    D = deque()
    _, image = run_game()
    image = cv.cvtColor(cv.resize(image, (120, 160)), cv.COLOR_BGR2GRAY)
    ret, image = cv.threshold(image, 1, 255, cv.THRESH_BINARY)
    input_image_data = np.stack((image, image, image, image), axis=2)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.initialize_all_variables())
        # saver = tf.train.Saver()
        n = 0
        epsilon = INITIAL_EPSILON
        while True:
            action_t = predict_action.eval(feed_dict={input_image: [input_image_data]})[0]
            argmax_t = np.zeros([4], dtype=np.int)
            # 每个状态以epsilon的概率进行探索
            if random.random() <= epsilon:
                max_index = random.randrange(4)
            else:
                max_index = np.argmax(action_t)
            argmax_t[max_index] = 1
            if epsilon > FINAL_EPSILON:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
            reward, image = run_game(list(argmax_t))
            image = cv.cvtColor(cv.resize(image, (120, 160)), cv.COLOR_BGR2GRAY)
            ret, image = cv.threshold(image, 1, 255, cv.THRESH_BINARY)
            image = np.reshape(image, (160, 120, 1))
            input_image_data1 = np.append(image, input_image_data[:, :, 0: 3], axis=2)
            D.append((input_image_data, argmax_t, reward, input_image_data1))
            if len(D) > REPLAY_MEMORY:
                D.popleft()
            if n > OBSERVE:
                min_batch = random.sample(D, BATCH)
                input_image_data_batch = [d[0] for d in min_batch]
                argmax_batch = [d[1] for d in min_batch]
                reward_batch = [d[2] for d in min_batch]
                input_image_data1_batch = [d[3] for d in min_batch]
                gt_batch = []
                out_batch = predict_action.eval(feed_dict={input_image: input_image_data1_batch})
                for i in range(0, len(min_batch)):
                    gt_batch.append(reward_batch[i] + LEARNING_RATE * np.max(out_batch[i]))
                # 利用梯度反向传播更新模型参数
                optimizer.run(feed_dict={gt: gt_batch, argmax: argmax_batch, input_image: input_image_data_batch})
            input_image_data = input_image_data1
            n = n + 1
            print(n, "epsilon:", epsilon, " ", "action:", max_index, " ", "reward:", reward)


train(input_image)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334

测试结果

在这里插入图片描述
测试中,贪吃蛇进行了50000轮训练,每一次训练,贪吃蛇通过策略函数选取合适的动作,并将结果存储到经验池中,也就是上述代码中的双向队列Q。贪吃蛇已基本具备躲避边缘和寻找最优路径吃苹果的能力。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/601520
推荐阅读
相关标签
  

闽ICP备14008679号