当前位置:   article > 正文

第0-(2)章-DRL的细碎笔记-长时序问题中时序信息的获取_长时序是什么意思

长时序是什么意思

文章目录

第0-(2)章-DRL的细碎笔记-长时序问题中时序信息的获取

作者:想要飞的猪

工作地点:北京科技大学

–长时序问题中时序信息的获取–

  个人看来长时序问题可以分为两类:
  (1)状态相关的长时序问题
  (2)时间相关的长时序问题
其中,(1)可以理解成结束状态(done)为某个状态,例如在游戏中战胜对手或者被对手战胜则游戏结束;(2)可以理解为有固定的时间,到达某个时间游戏结束。
  深度强化学习算法训练时使用了replay buffer的方法。replay buffer可以有效利用以往的经验,同时随机采样能够打破前后数据的相关性,能够使网络训练更稳定。但是这样就会出现一个问题,网络如何追踪数据中的时序信息,时序数据之间的相关性会不会因为随机采样而学习不到?
  时序信息不会因为随机采样而消失的,先看DDPG与PPO中buffer是如何采样与训练的,然后再分析时序信息不会消失的原因,最后再抛出问题。
  DDPG算法训练(learn)的时候,从buffer中采样是随机采样,在代码中DDPG从buffer采样的语句是:

# 代码仅仅为了说明DDPG中buffer的采样规则,很多内容没有展现
replay_buffer = ReplayBuffer(state_dim, action_dim)
...
batch_s, batch_a, batch_r, batch_s_, batch_dw = relay_buffer.sample(self.batch_size)  # Sample a batch
  • 1
  • 2
  • 3
  • 4

其中,buffer的定义如下:

class ReplayBuffer(object):
    def __init__(self, state_dim, action_dim):
        ...

    def store(self, s, a, r, s_, dw):
        ...

    def sample(self, batch_size):
        index = np.random.choice(self.size, size=batch_size) 
        batch_s = torch.tensor(self.s[index], dtype=torch.float)
        batch_a = torch.tensor(self.a[index], dtype=torch.float)
        batch_r = torch.tensor(self.r[index], dtype=torch.float)
        batch_s_ = torch.tensor(self.s_[index], dtype=torch.float)
        batch_dw = torch.tensor(self.dw[index], dtype=torch.float)

        return batch_s, batch_a, batch_r, batch_s_, batch_dw
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

可以看出这里用到了np.random.choice(self.size, size=batch_size) ,即对buffer中数据的选取是随机的。当然DDPG也能够体现数据的时序性。因为深度强化学习被定义为一个马尔可夫决策过程,即下一个状态仅跟当前的状态有关系。所以buffer中一条memory保存当前的状态s以及后一个状态s_,即便无序的buffer采样仍然能够保证从中获得时序信息(马尔可夫链的性质)。
  在PPO算法中buffer是序贯导入样本的,代码如下:

# 代码仅仅为了说明DDPG中buffer的采样规则,很多内容没有展现
replay_buffer = ReplayBuffer(args)
...
s, a, a_logprob, r, s_, dw, done = replay_buffer.numpy_to_tensor()  
...
class ReplayBuffer:
    def __init__(self, args):
        ...

    def store(self, s, a, a_logprob, r, s_, dw, done):
        # 按顺序存储s, a, a_logprob, r, s_, dw, done
        ...

    def numpy_to_tensor(self):
        s = torch.tensor(self.s, dtype=torch.float)
        a = torch.tensor(self.a, dtype=torch.long) 
        a_logprob = torch.tensor(self.a_logprob, dtype=torch.float)
        r = torch.tensor(self.r, dtype=torch.float)
        s_ = torch.tensor(self.s_, dtype=torch.float)
        dw = torch.tensor(self.dw, dtype=torch.float)
        done = torch.tensor(self.done, dtype=torch.float)

        return s, a, a_logprob, r, s_, dw, done
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

PPO序贯导入数据的原因是需要计算Advertage函数,下面给出一种Advertage函数的定义
A ^ t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l V \hat{A}_t^{G A E(\gamma, \lambda)}=\sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_{t+l}^V A^tGAE(γ,λ)=l=0(γλ)lδt+lV其中, δ t V = r t + γ V ω ( s t + 1 ) − V ω ( s t ) \delta_t^V=r_t+\gamma V_\omega\left(s_{t+1}\right)-V_\omega\left(s_t\right) δtV=rt+γVω(st+1)Vω(st)。这里的Advertage函数是GAE(generalized advantage estimation)函数。但是在训练时仍然在在当前batch数据中采样mini_batch_size训练。在采样mini_batch_size时,仍然采用随机采样的方法。代码如下:

 # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Random sampling and no repetition. 'False' indicates that training will continue even if the number of samples in the last time is less than mini_batch_size
            for index in BatchSampler(SubsetRandomSampler(range(self.batch_size)), self.mini_batch_size, False):
            	...
  • 1
  • 2
  • 3
  • 4
  • 5

其中,SubsetRandomSampler是提供batch_size大小数据集中数据的随机索引,而BatchSampler是从这些索引按照mini_batch_size大小分成若干组,分组进行训练,其中False表示最后采样的数据如果小于mini_batch_size也不抛弃这组数据。所以DDPG与PPO训练的时候还使用了无序采样的方法进行训练,这样训练的原因大家可以参考一下这篇文章。
  问题:深度强化学习被描述为马尔可夫决策过程,也就是说当前状态仅与上一个状态有关,但是现存的研究工作中有设置网络的特征提取层为RNN类网络,这样的设置会让网络在决策时考虑多个先前的状态,这样的设定是否还符合马尔可夫决策过程的假设?
  感恩所有指导帮助过我的人!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/741259
推荐阅读
相关标签
  

闽ICP备14008679号