当前位置:   article > 正文

DI-engine强化学习入门(十)如何使用RNN——模型构建和包装

DI-engine强化学习入门(十)如何使用RNN——模型构建和包装

一、RNN简介
循环神经网络(Recurrent Neural Network, RNN)是一类用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN引入了“内部状态”(或称为“隐藏状态”),使得网络能够存储过去的信息,并利用这些信息影响后续的输出。这个内部状态的更新过程使得RNN能够处理不同长度的输入序列,比如文字或语音数据。

RNN的特点是在不同时间步的单元之间存在连接,形成一个沿时间维度展开的有向图。这种结构允许RNN捕捉序列中随时间变化的动态特征,这使得它非常适合时序数据相关的任务,如自然语言处理、语音识别、股票预测等。

RNN在深度强化学习中的应用
在深度强化学习(Deep Reinforcement Learning, DRL)中,RNN被用于解决具有时间依赖性的决策问题。例如,DRQN(Deep Recurrent Q-Learning Network)算法结合了RNN和Q-Learning,以处理在Atari游戏等环境中可能遇到的不完全信息问题。

RNN的变体
随着研究的深入,研究者们发现传统的RNN容易出现梯度消失或梯度爆炸的问题,这限制了模型处理长序列的能力。为了解决这一问题,人们提出了RNN的一些变体,最著名的包括长短期记忆网络(Long Short-Term Memory, LSTM)和门控循环单元(Gated Recurrent Unit, GRU)。这些变体通过引入门控机制来更有效地控制信息的流动,从而更好地学习长距离依赖。

RNN在MDP中的作用
在马尔可夫决策过程(Markov Decision Process, MDP)中,智能体在每个时间步需要根据当前的观测状态以及之前的历史状态来做出决策。RNN通过其内部状态的持续更新,使得智能体能够结合历史信息来进行当前的行为选择。

DI-engine对RNN的支持

DI-engine是一套深度强化学习框架,它支持RNN网络,并提供用户友好的API,使得研究者和开发者能够更容易地实现RNN及其变体。通过这些API,用户可以将RNN集成到他们的强化学习模型中,以解决需要处理序列数据的复杂任务。

DI-engine中的相关组件

这里我们简要的分析一下ding/torch_utils/network/rnn.pyrnn.py
主要功能是实现了不同类型的LSTM单元:

1.定义了一些工具函数:

  • is_sequence: 判断输入是否是列表或元组
  • sequence_mask: 根据序列长度生成掩码
  • LSTMForwardWrapper: 封装LSTM的前后处理逻辑

2.实现了三种LSTM单元:

  • LSTM: 自定义的LSTM单元,使用了LayerNorm
  • PytorchLSTM: 封装PyTorch中的nn.LSTM,格式化输入输出
  • GRU: 封装了nn.GRUCell,也格式化输入输出

3.get_lstm: 根据输入参数返回不同的LSTM单元实现

  • 支持’normal’,’pytorch’,’hpc’,’gru’四种类型
  • hpc类型需要调用HPC平台的实现,其它为普通PyTorch实现

4.每种LSTM单元都实现了forward函数,区别在于:

  • 输入输出格式化的不同
  • 是否使用了LayerNorm
  • 对于隐状态,可以返回Tensor或List两种格式

5.forward函数中会调用LSTMForwardWrapper的钩子函数进行输入输出封装处理
这样设计使得不同的LSTM实现可以通过统一的接口进行调用,隔离了输入输出格式的处理逻辑。该程序实现了灵活可配置的LSTM单元,通过组合PyTorch基础模块,提供了清晰和统一的接口。

LSTM类中forward函数的实现的非常优雅:

  1. 调用钩子函数进行输入状态的预处理,提高复用性
  2. 逐层、逐时间步执行LSTM计算流程,代码结构清晰
  3. 使用列表保存每一时间步的输出,最后stack起来
  4. 添加可配置的dropout操作
  5. 封装next_state的输出格式,提高灵活性

这样的实现既考虑了计算流程的清晰性,也提高了接口的灵活性,使得LSTM单元更易于复用和扩展。

def forward(self,            inputs: torch.Tensor,             prev_state: torch.Tensor,            list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:    # 调用钩子函数进行输入状态的预处理    prev_state = self._before_forward(inputs, prev_state)      H, C = prev_state    x = inputs    next_state = []    for l in range(self.num_layers):        h, c = H[l], C[l]        new_x = []        for s in range(seq_len):            # 计算不同门的值            gate = ...             i, f, o, u = gate            # LSTM计算公式            c = f * c + i * u            h = o * torch.tanh(c)            new_x.append(h)        next_state.append((h, c))        x = torch.stack(new_x, dim=0)        # 添加dropout        if self.use_dropout and l != self.num_layers - 1:            x = self.dropout(x)    # 封装next_state的格式      next_state = self._after_forward(next_state, list_next_state)    return x, next_state

 DI-engine 中哪些策略支持RNN结构

点击DI-engine强化学习入门(十)如何使用RNN——模型构建和包装 - 古月居 可查看全文

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/576654
推荐阅读
  

闽ICP备14008679号