当前位置:   article > 正文

TBPTT算法——Truncated Backpropagation Through Time

truncated backpropagation through time

一、介绍

在训练循环神经网络RNN中,往往使用BPTT(Backpropagation Through Time)更新参数。关于BPTT的详细原理可以参考https://www.cntofu.com/book/85/dl/rnn/bptt.md

然而使用BPTT,对于单个参数的更新可能会带来很大的开销。例如对长度为1000的输入序列进行反向传播,其代价相当于1000层的神经网络进行前向后向传播。有两个方法对付这种庞大的开销问题:

  • 一种简单的解决方法是:将长度为1000的序列分成50个长度为20的序列,再对这20个序列进行单独训练。这种方法虽然可行,但是它忽略了每个拆分序列之间的依赖关系。
  • TBPTT:TBPTT和上面的方法类似。对于一个目标序列,每次处理一个时间步(timestep),当处理到k1个时间步时,使用BPTT后向传播k2个时间步。如果 k2 比较小,那么其计算开销就会降低。这样,它的每一个隐层状态可能经过多次时间步迭代计算产生的,也包含了更多更长的过去信息。在一定程度上,避免了上面方法中无法获取截断时间步之外信息的问题。

二、参数确定

TBPTT需要考虑两个参数:

  • k1:前向传播的时间步。一般来说,这个参数影响模型训练的快慢,即权重更新的频率。
  • k2:使用BPTT反向传播的时间步。一般来说,这个参数需要大一点,这样网络能更好的学习序列信息。但是这个参数太大的话可能会导致梯度消失。

参考Williams and Peng, An Efficient Gradient-Based Algorithm for On-Line Training of Recurrent Network Trajectories, 1990    确定k1,k2的取值(TBPTT(k1,k2))

这里,n代表序列的总长度。

  1. TBPTT(n, n): 在序列处理结束之后更新参数,即传统的BPTT;
  2. TBPTT(1, n): 每向前处理一个时间步,便后向传播所有已看到的时间步。(Williams and Peng提出的经典的TBPTT);
  3. TBPTT(k1,1): 因为每次反向传播一个时间步,所以网络并没有足够的时序上下文来学习,严重的依赖内部状态和输入;
  4. TBPTT(k1,k2):  k1 < k2 < n:  对于每个序列,都进行了多次更新,可以加速训练;
  5. TBPTT(k1,k2),:  k1=k2:对应上面提到的那种简单方法。

参考资料:

https://machinelearningmastery.com/gentle-introduction-backpropagation-time/

https://www.cnblogs.com/shiyublog/p/10542682.html#_label2

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/396503
推荐阅读
相关标签
  

闽ICP备14008679号