当前位置:   article > 正文

【转载】初探强化学习DQN的Pytorch代码解析

【转载】初探强化学习DQN的Pytorch代码解析

版权声明:本文为CSDN博主「难受啊!马飞…」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_33328642/article/details/123788966

首先上完整的代码。
这个代码是大连理工的一个小姐姐提供的。小姐姐毕竟是小姐姐,心细如丝,把理论讲的很清楚。但是代码我没怎么听懂。小姐姐在B站的视频可以给大家提供一下。不过就小姐姐这个名字,其实我是怀疑她是抠脚大汉,女装大佬。

不说了,先上完整的代码吧

1. 完整的代码

import gym
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

# Atari environment; .unwrapped strips the TimeLimit wrapper so the episode
# only ends when the game itself ends.
env = gym.make('SpaceInvaders-v0').unwrapped

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# Replay Memory

# One agent-environment interaction record: (s, a, s', r).
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-size buffer of Transitions for experience replay."""

    def __init__(self, capacity):
        # deque with maxlen silently evicts the oldest transition when full.
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition (state, action, next_state, reward)."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Return batch_size transitions drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
######################################################################
# DQN algorithm


class DQN(nn.Module):
    """Convolutional Q-network mapping a stack of 4 frames to action values.

    Three conv layers (each followed by BatchNorm + ReLU) and two fully
    connected layers; the output has one Q-value per action.
    """

    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        # Input is a stack of 4 grayscale frames: (batch, 4, h, w).
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Output spatial size of an unpadded convolution:
        # floor((size - kernel) / stride) + 1.
        def conv2d_size_out(size, kernel_size, stride):
            return (size - (kernel_size - 1) - 1) // stride + 1

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)
        linear_input_size = convw * convh * 64  # 7 * 7 * 64 = 3136 for 84x84 input
        self.l1 = nn.Linear(linear_input_size, 512)
        self.l2 = nn.Linear(512, outputs)

    def forward(self, x):
        # `device` is the module-level torch.device defined at file top.
        x = x.to(device)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.l1(x.view(x.size(0), -1)))
        return self.l2(x.view(-1, 512))
######################################################################
# Input extraction

# Preprocess a raw RGB frame: grayscale, resize to 84x84, scale to [0, 1].
resize = T.Compose([T.ToPILImage(),
                    T.Grayscale(num_output_channels=1),
                    T.Resize((84, 84), interpolation=InterpolationMode.BICUBIC),
                    T.ToTensor()])


def get_screen():
    """Render the env and return a (1, 1, 84, 84) float tensor in [0, 1]."""
    # Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0)

######################################################################
# Training

# Hyperparameters and network initialization.
BATCH_SIZE = 32        # transitions per optimization step
GAMMA = 0.99           # discount factor
EPS_START = 1.0        # initial epsilon for epsilon-greedy exploration
EPS_END = 0.1          # final epsilon
EPS_DECAY = 10000      # decay time constant (in steps) of epsilon
TARGET_UPDATE = 10     # episodes between target-network syncs

init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape

# Get number of actions from gym action space
n_actions = env.action_space.n

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()  # target net is only evaluated, never trained directly

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(100000)

steps_done = 0  # global action counter; drives the epsilon decay schedule

def select_action(state):
    """Epsilon-greedy action selection.

    With probability 1 - eps return the greedy action from policy_net,
    otherwise a uniformly random action. Epsilon decays exponentially
    from EPS_START to EPS_END with time constant EPS_DECAY steps.
    """
    global steps_done
    sample = random.random()
    # Bug fix: the original dropped the '*' and the line continuation,
    # leaving eps_threshold frozen at EPS_START (epsilon never decayed).
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # max(1)[1] is the argmax over the action dimension.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

episode_durations = []


def plot_durations():
    """Plot episode durations and, once available, the 100-episode mean."""
    plt.figure(1)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated

def optimize_model():
    """Sample a minibatch from replay memory and take one gradient step."""
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Bug fix: the original 'Transition(zip(transitions))' passed a single
    # zip object as the 'state' field. Transpose the batch of Transitions
    # into one Transition of batches.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # V(s_{t+1}) from the target network; zero for terminal states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # NOTE(review): the original comment said "Huber loss" but the code uses
    # MSE; behavior kept as MSE.
    criterion = nn.MSELoss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)  # element-wise gradient clipping
    optimizer.step()

def random_start(skip_steps=30, m=4):
    """Reset the env and take random actions to reach a random start state.

    Fills two deques of m frames each: state_queue holds frames t..t+m-1
    and next_state_queue holds frames t+1..t+m, i.e. the same stack shifted
    by one step.

    Returns (done, state_queue, next_state_queue).
    """
    env.reset()
    state_queue = deque([], maxlen=m)
    next_state_queue = deque([], maxlen=m)
    done = False
    for i in range(skip_steps):
        if (i + 1) <= m:
            state_queue.append(get_screen())
        elif m < (i + 1) <= 2 * m:
            next_state_queue.append(get_screen())
        else:
            # Keep the two stacks exactly one frame apart.
            state_queue.append(next_state_queue[0])
            next_state_queue.append(get_screen())

        action = env.action_space.sample()
        _, _, done, _ = env.step(action)
        if done:
            break
    return done, state_queue, next_state_queue

######################################################################
# Start Training

num_episodes = 10000
m = 4  # frames stacked per state; env steps per selected action
for i_episode in range(num_episodes):
    # Initialize the environment and state
    done, state_queue, next_state_queue = random_start()
    if done:
        continue

    # Stack m (1, 1, 84, 84) frames along the channel dim -> (1, m, 84, 84).
    state = torch.cat(tuple(state_queue), dim=1)
    for t in count():
        reward = 0
        m_reward = 0
        # One action is chosen and repeated for m frames.
        action = select_action(state)

        for i in range(m):
            _, reward, done, _ = env.step(action.item())
            if not done:
                next_state_queue.append(get_screen())
            else:
                break
            m_reward += reward

        if not done:
            next_state = torch.cat(tuple(next_state_queue), dim=1)
        else:
            next_state = None
            m_reward = -150  # fixed penalty for ending the episode
        m_reward = torch.tensor([m_reward], device=device)

        memory.push(state, action, next_state, m_reward)

        state = next_state
        optimize_model()

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), 'weights/policy_net_weights_{0}.pth'.format(i_episode))

print('Complete')
env.close()
torch.save(policy_net.state_dict(), 'weights/policy_net_weights.pth')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262

2. 逐个函数的解析

2.1 定义Replay Memory

该代码中使用具名元组namedtuple()定义一个Transition ,用于存储agent与环境交互的(s,a,r,s_)

Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))

 
 
  • 1
  • 1

这个具名元组很简单
举个例子:

Student = namedtuple('Student', ('name', 'gender'))
s = Student('小花', '女')#给属性赋值
  • 1
  • 2

# 属性访问,有多种方法访问属性
# 第一种方法
print(s.name)
print(s.gender)
'''
小花
女
'''

# 第二种方法
print(s[0])
print(s[1])
'''
小花
女
'''

# 还可以迭代
for i in s:
    print(i)
'''
小花
女
'''

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

2.2 ReplayMemory

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)  # deque is a double-ended list with fast inserts/removals at both ends, suited for queues and stacks
    def push(self, *args):
        self.memory.append(Transition(*args))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)  # random.sample draws batch_size transitions uniformly from memory
    def __len__(self):
        return len(self.memory)

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • def init(self, capacity)没啥好说的,就是定义一个双向列表。
  • def push(self, *args)就是向memory中添加Transition,这个memory是一个双向队列,后面会详解。
  • def sample(self, batch_size)是随机采样。random.sample()其中的第一个参数是即将被采样的列表,第二个参数采样的批次。这个大家应该都懂。后面我也有例子。

2.3 DQN algorithm

class DQN(nn.Module):
    """Convolutional Q-network: maps a stack of 4 grayscale 84x84 frames
    to one Q-value per action (output shape: [batch, outputs])."""

    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)   # first conv layer
        self.bn1 = nn.BatchNorm2d(32)                            # batch norm for conv1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # second conv layer
        self.bn2 = nn.BatchNorm2d(64)                            # batch norm for conv2
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)  # third conv layer
        self.bn3 = nn.BatchNorm2d(64)                            # batch norm for conv3

        def conv2d_size_out(size, kernel_size, stride):
            # Standard conv output-size formula (no padding).
            return (size - (kernel_size - 1) - 1) // stride + 1

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)  # 84 -> 7
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)  # 84 -> 7
        linear_input_size = convw * convh * 64  # 7*7*64 = 3136 for an 84x84 input
        self.l1 = nn.Linear(linear_input_size, 512)  # fully connected: 3136 -> 512
        self.l2 = nn.Linear(512, outputs)            # final layer: one Q-value per action

    def forward(self, x):
        # NOTE: relies on the module-level `device` defined at the top of the file.
        x = x.to(device)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.l1(x.view(x.size(0), -1)))  # flatten conv features per sample
        return self.l2(x.view(-1, 512))             # -1: batch size inferred, 512 features
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

这是一个常规的使用pytorch搭建网络模型的框架,相信大家都懂。而且我在里面也注释了。
需要注意的一点是:

  • def conv2d_size_out(size, kernel_size, stride):这个其实就是求最后一个卷积层的feature map的尺寸。这个DQN输入的是84×84的图像,按照上面的代码,最后一层的feature map的尺寸就是7×7,一共64个。这样做只是为了和第一个全连接层衔接一下。其实吧,这样做感觉有点多余,正常的代码用flatten()就可以了。关于如何拉平feature map,大家可以看看其他方法。
  • 运行下面代码查看,当只有两个动作时,这个网络的输出。我一开始以为网络的输出应该也是按照批次来的,也就是说当模型输出32个批次的两个动作的q值时应该是这个样的:[32,1,2],也就是说应该是32个1行两列的。但是实际上,是[32,2],即32行两列。这样就能解释代码的结构了。但是当我把模型拆开了之后才发现
class DQN(nn.Module):
    """Same network as above, reproduced for the standalone output-shape demo;
    here forward() does NOT move the input to `device`."""

    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)   # first conv layer
        self.bn1 = nn.BatchNorm2d(32)                            # batch norm for conv1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # second conv layer
        self.bn2 = nn.BatchNorm2d(64)                            # batch norm for conv2
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)  # third conv layer
        self.bn3 = nn.BatchNorm2d(64)                            # batch norm for conv3

        def conv2d_size_out(size, kernel_size, stride):
            # Standard conv output-size formula (no padding).
            return (size - (kernel_size - 1) - 1) // stride + 1

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)  # 84 -> 7
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)  # 84 -> 7
        linear_input_size = convw * convh * 64  # 7*7*64 = 3136
        self.l1 = nn.Linear(linear_input_size, 512)  # fully connected: 3136 -> 512
        self.l2 = nn.Linear(512, outputs)            # one Q-value per action

    def forward(self, x):
        #x = x.to(device)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.l1(x.view(x.size(0), -1)))  # flatten conv features per sample
        return self.l2(x.view(-1, 512))             # -1: batch size inferred, 512 features
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

policy_net = DQN(84, 84, 2)#Q
x = torch.rand(32,4,84, 84)

xout = policy_net(x)

print(xout.size())
#[32,2]
print(xout)

tensor([[ 3.4981e-02, 3.1048e-02],
[ 1.4112e-01, -5.2676e-02],
[-3.3868e-01, 3.9583e-02],
[ 7.5908e-02, -1.2230e-01],
[ 1.4027e-01, -1.7528e-02],
[-1.0966e-02, 6.2111e-02],
[-2.2511e-02, -6.1829e-02],
[ 3.2599e-02, -8.9155e-02],
[ 9.7833e-02, -5.0325e-02],
[-6.4633e-02, -8.8093e-02],
[-4.3771e-02, 1.5452e-01],
[-1.7478e-01, -1.3224e-01],
[ 1.9658e-02, 8.1575e-03],
[-1.6989e-01, -6.6487e-03],
[-1.6566e-01, -1.0833e-01],
[-9.5961e-02, 1.1235e-02],
[ 1.0005e-01, -1.1150e-02],
[ 1.8165e-02, 9.9491e-03],
[-2.3947e-01, 9.7802e-02],
[-5.2116e-02, 4.8583e-02],
[ 2.2504e-02, 3.8262e-04],
[-1.1822e-01, -2.0696e-01],
[-1.4129e-01, -1.9254e-01],
[-2.2170e-01, -1.2232e-01],
[ 3.3542e-02, 3.3005e-03],
[ 1.5150e-01, 1.5330e-01],
[-2.3675e-01, -2.4939e-01],
[-1.0502e-01, 7.2696e-02],
[-1.3213e-01, 1.5113e-01],
[ 6.1988e-02, 2.5367e-02],
[-4.2924e-01, -4.0167e-02],
[ 5.1474e-02, 2.6885e-01]], grad_fn=<AddmmBackward0>)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68

2.4 图像预处理

# Preprocessing pipeline applied to every captured frame:
# tensor -> PIL image -> single-channel grayscale -> 84x84 (bicubic) -> float tensor in [0, 1].
resize = T.Compose([T.ToPILImage(),
                    T.Grayscale(num_output_channels=1),
                    T.Resize((84, 84), interpolation=InterpolationMode.BICUBIC),
                    T.ToTensor()])

 
 
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

#Compose法是将多种变换组合在一起。在这个步骤中,有Resize,灰度处理,
#ToTensor将PILImage转变为torch.FloatTensor的数据形式
#ToPILImage将shape为(C,H,W)的Tensor或shape为(H,W,C)的numpy.ndarray转换成PIL.Image,值不变

2.5 截屏函数

def get_screen():
    """Capture the current game frame and return it as a (1, 1, 84, 84) float tensor."""
    # Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    # env.render is the rendering engine for the current environment state;
    # transpose moves the channel axis to the front (HWC -> CHW).
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    # ascontiguousarray makes the array contiguous in memory, which is faster to process.
    screen = torch.from_numpy(screen)  # build a tensor from the numpy.ndarray
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0)  # unsqueeze(0): CHW -> BCHW with batch size 1

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2.6 超参数

# Hyperparameters and network initialization
BATCH_SIZE = 32#minibatch size sampled from the replay memory
GAMMA = 0.99#discount factor
EPS_START = 1.0#initial epsilon (exploration rate)
EPS_END = 0.1#final epsilon
EPS_DECAY = 10000#decay time constant for epsilon
TARGET_UPDATE = 10#episodes between target-network syncs
init_screen = get_screen()#one captured frame; get_screen returns (1, 1, 84, 84) — the original comment's [32,4,84,84] was wrong
_, _, screen_height, screen_width = init_screen.shape#extract frame height and width
n_actions = env.action_space.n#size of the environment's action space
# Build the two networks (same architecture):
policy_net = DQN(screen_height, screen_width, n_actions).to(device)#Q-network being trained
target_net = DQN(screen_height, screen_width, n_actions).to(device)#target network
target_net.load_state_dict(policy_net.state_dict())#start with identical weights
target_net.eval()#target net is never trained directly, only evaluated
optimizer = optim.RMSprop(policy_net.parameters())#RMSprop optimizer on the policy net
memory = ReplayMemory(100000)#replay buffer capacity
steps_done = 0

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

这边没什么可说的,大家都看得懂。

policy_net = DQN(screen_height, screen_width, n_actions).to(device)#Q
target_net = DQN(screen_height, screen_width, n_actions).to(device)#T

 
 
  • 1
  • 2
  • 1
  • 2

这两句我师妹问过我是什么意思
这个其实就是初始化模型。只是作者在写这个代码的时候还有其他参数,因此需要带参初始化。
正常情况,我们写一个模型时,初始化没这么麻烦。

2.7 选择动作的函数

#动作选择函数,首先看的就是探索和开发的阈值系数 eps[0,1]
def select_action(state):
    """Epsilon-greedy action selection: greedy w.p. 1 - eps, random otherwise."""
    global steps_done
    sample = random.random()## uniform random float in [0, 1)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1.*steps_done / EPS_DECAY)#decays from EPS_START toward EPS_END as steps_done grows
    steps_done += 1
    if sample > eps_threshold:#choose between the greedy and a random action
        #eps_threshold shrinks over time, so exploitation becomes more likely later in training
        with torch.no_grad():#inference only — no gradients needed for action selection
            return policy_net(state).max(1)[1].view(1, 1)#argmax over Q-values (greedy action)
    else:
        #otherwise explore: pick an action uniformly at random
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)#random action index in [0, n_actions)
#random.randrange(N)在0-N之间随机生成一个数,N是动作空间数

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 这边主要解释一下这个控制eps变量的eps_threshold
    其实这是一个单调递减函数,我把这个函数的曲线画出来了。按照作者的意思,这个eps_threshold的最小值时0.427.看下图
    在这里插入图片描述
    大家可以按照下面的函数自己运行一下:
    需要提醒的是,我们可以在这个函数里的i后面乘以一个数来控制eps_threshold的最小值。
    比如我把在i后面乘以2,那么eps_threshold数值会下降2倍。
# Standalone demo: plot the epsilon-threshold decay curve used by select_action.
plt.figure(1)
ax = plt.subplot(111)
x = np.linspace(0, 1000, 1000)  # 1000 evenly spaced points between 0 and 1000
print(x)
r1 = []
for i in range(1000):
    # Same formula as eps_threshold, with EPS_START=0.99, EPS_END=0.1, decay 1000.
    r = 0.1 + (0.99 - 0.1) * \
        math.exp(-1.*(i / 1000))
    r1.append(r)
print(r1)
ax.plot(x, r1)
plt.show()

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

2.8 画图函数

episode_durations = []  # per-episode duration, appended after each episode ends

def plot_durations():
    """Live-plot episode durations and a trailing 100-episode average."""
    plt.figure(1)
    plt.clf()  # clear the figure but keep the window, so reruns reuse it
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
  • 1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

这个没啥说的

2.9 优化器

def optimize_model():
    """One gradient step on policy_net using a random minibatch from replay memory."""
    if len(memory) < BATCH_SIZE:  # wait until enough experience is collected
        return
    transitions = memory.sample(BATCH_SIZE)  # random minibatch of Transitions
    # Re-zip a list of Transitions into one Transition of batched fields.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended;
    # it is stored as None).
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)    # [BATCH_SIZE, ...] stacked states
    action_batch = torch.cat(batch.action)  # [BATCH_SIZE, 1] action indices
    reward_batch = torch.cat(batch.reward)  # [BATCH_SIZE] rewards

    # Q(s, a): evaluate all actions with the policy net, then gather the
    # column corresponding to the action actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s') from the target network; terminal states keep value 0.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # TD target: r + gamma * max_a' Q_target(s', a')
    expected_state_action_values = reward_batch + (next_state_values * GAMMA)

    # Compute the loss between predicted and expected Q-values.
    criterion = nn.MSELoss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)  # clip gradients to [-1, 1] for stability
    optimizer.step()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

来了,来了。我之前跟学生讲课的时候经常说,想看懂一个代码、一个算法,一定要搞清楚它们数据的流向,以及数据尺寸的变换流程。
然后第一块需要详细了解的代码就是下面这两行,它们和后面要讲的部分是连在一起的。我先举个例子。

    transitions = memory.sample(BATCH_SIZE)#从记忆池中随即采集BATCH_SIZE个样本
    batch = Transition(*zip(*transitions))#zip表示交叉元素,*号代表拆分

 
 
  • 1
  • 2
  • 1
  • 2

首先第一行是从memory中随机抽取一批样本,我们默认是32.。
然后就是下面的batch了。我们具体举个例子,一看便知。

import torch
import random
from collections import namedtuple, deque
# Create a double-ended queue with maximum length 100 (same as in the main code).
memory = deque([], maxlen=100)
# Define the same Transition namedtuple as the main code.
Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))
# Instantiate a few Transitions with dummy values.
s1 = Transition(2,3,4,5)
s2 = Transition(1,2,3,4)
s3 = Transition(1,4,5,2)
s4 = Transition(2,5,7,3)
# Append them to memory.
memory.append(s1)
memory.append(s2)
memory.append(s3)
memory.append(s4)
print(memory)
# The raw memory looks like this:
#deque([Transition(state=2, action=3, next_state=4, reward=5), Transition(state=1, action=2, next_state=3, reward=4), Transition(state=1, action=4, next_state=5, reward=2), Transition(state=2, action=5, next_state=7, reward=3)], maxlen=100)
# Randomly sample a batch of 2.
m2 = random.sample(memory, 2)
# After sampling:
#[Transition(state=1, action=4, next_state=5, reward=2), Transition(state=2, action=3, next_state=4, reward=5)]
# Here is the key step:
batch = Transition(*zip(*m2))
print(batch)
#Transition(state=(1, 2), action=(4, 3), next_state=(5, 4), reward=(2, 5))
# Transition(*zip(*transitions)) regroups individual (s, a, s_, r) records into batched fields.
# Continuing with the same data, watch the format change line by line:
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
print(non_final_mask)
# Output: tensor([True, True])
# i.e. non_final_mask is a bool tensor marking whether each next_state is non-terminal.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

下面大家按照这个转换格式,就能知道最后数据是如何处理的了。
下面看一下这个语句

    state_action_values = policy_net(state_batch).gather(1, action_batch)#列号标动,因为是2列

 
 
  • 1
  • 1

这个gather其实不是理解的聚集。
而类似与Qtable中的查表。计算的是Q值

  • policy_net(state_batch)这部分输入的是4×84×84的图像,输出的是一个32×2的张量,表示动作Q值。32是批次
  • .gather(1, action_batch),参考这个博客
  • 主要是gather中的这个action_batch,这个变量是动作标号。
    这个模块目前直接解释还是有点困难,因为它是建立在很多前处理之上的。
    我们先向下看:

2.10 随机开始

def random_start(skip_steps=30, m=4):
    """Reset the env and take random steps to fill the first two frame stacks.

    Returns (done, state_queue, next_state_queue); each queue holds up to `m`
    screen tensors (the current state and the following state).
    """
    env.reset()  # every episode starts from a freshly reset environment
    state_queue = deque([], maxlen=m)      # frames forming the current state
    next_state_queue = deque([], maxlen=m) # frames forming the next state
    done = False
    for i in range(skip_steps):
        if (i+1) <= m:   # first m frames fill the current state
            state_queue.append(get_screen())
        elif m < (i + 1) <= 2*m:  # frames m+1 .. 2m fill the next state
            next_state_queue.append(get_screen())
        else:
            # Slide the window: the oldest next-state frame joins the current
            # state, the freshly captured frame joins the next state. Since
            # both containers are bounded deques, old frames fall off the left.
            state_queue.append(next_state_queue[0])
            next_state_queue.append(get_screen())
        action = env.action_space.sample()  # random action
        _, _, done, _ = env.step(action)  # step returns (obs, reward, done, info)
        if done:
            break
    return done, state_queue, next_state_queue
  • 1
  • 2
  • 3
  • 4
  • 5
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

2.11 开始训练

# Start Training
  • 1

num_episodes = 10000
m = 4  # frames per state stack

for i_episode in range(num_episodes):
    # Initialize the environment and the first state/next-state frame stacks.
    done, state_queue, next_state_queue = random_start()
    if done:
        continue
    state = torch.cat(tuple(state_queue), dim=1)  # stack frames along channels: (1, m, H, W)
    for t in count():
        reward = 0
        m_reward = 0
        # One chosen action is repeated for m consecutive frames.
        action = select_action(state)
        for i in range(m):
            _, reward, done, _ = env.step(action.item())  # interact with the env
            if not done:
                next_state_queue.append(get_screen())  # capture the new frame
            else:
                break  # episode over (agent died) — stop stepping
            m_reward += reward  # accumulate the reward over the m frames

        if not done:
            next_state = torch.cat(tuple(next_state_queue), dim=1)
        else:
            next_state = None  # terminal state has no successor
            m_reward = -150  # death penalty
        m_reward = torch.tensor([m_reward], device=device)
        memory.push(state, action, next_state, m_reward)  # store the transition
        state = next_state  # advance: next state becomes the current state
        optimize_model()  # one optimization step on the policy net

        if done:
            episode_durations.append(t + 1)  # record how long the episode lasted
            plot_durations()
            break

    # Update the target network, copying all weights and biases in DQN,
    # and checkpoint the policy weights every TARGET_UPDATE episodes.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), 'weights/policy_net_weights_{0}.pth'.format(i_episode))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

# Shutdown: the original paste used curly quotes, which are invalid Python.
print('Complete')
env.close()  # release the environment
torch.save(policy_net.state_dict(), 'weights/policy_net_weights.pth')  # final checkpoint

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

在这里插入图片描述
详细细节大家直接运行代码可能会很麻烦
我自己写了个小demo来验证数据的流程

import random
import torch
from collections import namedtuple, deque

# Standalone demo mirroring the data flow of the DQN training code:
#   frame queue -> Transition -> replay memory -> random minibatch -> torch.cat batches.
# Fixes vs. the original post: typographic quotes replaced with ASCII quotes
# (the original was not valid Python), and the batching line restored to the
# standard idiom `Transition(*zip(*m2))` — the markdown rendering had eaten
# the asterisks, and `Transition(zip(m2))` raises TypeError (namedtuple needs
# 4 positional fields).

# Holds the most recent 4 frames; appending a 5th evicts the oldest.
state_que = deque([], maxlen=4)

# Replay memory: a bounded FIFO of Transition tuples.
memory = deque([], maxlen=100)
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# Dummy current states (stand-ins for preprocessed screens).
st1 = torch.rand(2, 2)
st2 = torch.rand(2, 2)
st3 = torch.rand(2, 2)
st4 = torch.rand(2, 2)

# Dummy actions, shape (1,) each so torch.cat can batch them later.
a1 = torch.ones(1)
a2 = torch.ones(1)
a3 = torch.ones(1)
a4 = torch.ones(1)

# Simulate the get_screen() capture step, which returns a (1, 1, 84, 84)
# tensor in the real code; here the image is 2x2, so each next-state frame
# is unsqueezed twice into shape (1, 1, 2, 2).
nst1 = torch.rand(2, 2)
nst1 = nst1.unsqueeze(0)
nst1 = nst1.unsqueeze(0)
nst2 = torch.rand(2, 2)
nst2 = nst2.unsqueeze(0)
nst2 = nst2.unsqueeze(0)
nst3 = torch.rand(2, 2)
nst3 = nst3.unsqueeze(0)
nst3 = nst3.unsqueeze(0)
nst4 = torch.rand(2, 2)
nst4 = nst4.unsqueeze(0)
nst4 = nst4.unsqueeze(0)

# Pack each (state, action, next_state, reward) into a Transition.
s1 = Transition(st1, a1, nst1, 5)
s2 = Transition(st2, a2, nst2, 4)
s3 = Transition(st3, a3, nst3, 2)
s4 = Transition(st4, a4, nst4, 3)

# Push the four frames into the frame queue.
state_que.append(nst1)
state_que.append(nst2)
state_que.append(nst3)
state_que.append(nst4)
print('state_que', state_que)

# Convert the deque to a tuple and concatenate along the channel dimension:
# four (1, 1, 2, 2) frames -> one (1, 4, 2, 2) stacked state.
print('转换成元组和拼接')
state = torch.cat(tuple(state_que), dim=1)
print('state', state)
print('statesize', state.size())

memory.append(s1)
memory.append(s2)
memory.append(s3)
memory.append(s4)

# Sample a minibatch of 2 transitions from the replay memory.
m2 = random.sample(memory, 2)
print('m2', m2)
print()
# Transpose the batch: a list of Transitions becomes one Transition whose
# fields are tuples, e.g. batch.state == (st_i, st_j).
batch = Transition(*zip(*m2))
print('zip*-----------------------')
print('batch:000', batch.state)
# Mask of which sampled states are non-terminal (None marks terminal states
# in the real training code; here everything is non-None).
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.state)), dtype=torch.bool)
print(non_final_mask)
# Concatenate the sampled next-state frames along dim 0 -> (2, 1, 2, 2).
state_batch = torch.cat(batch.next_state)
print('next_state_batch', state_batch)
print('state_batch_size = ', state_batch.size())
# Concatenate the sampled actions -> shape (2,).
action_batch = torch.cat(batch.action)
print('action_batch', action_batch)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号