当前位置:   article > 正文

Pytorch实现RNN、LSTM和GRU等经典循环网络模型,简直不能再简单。_利用pytorch实现rnn网络、gru网络和lstm网络

利用pytorch实现rnn网络、gru网络和lstm网络


惊呆了,居然只是一行代码的事
注:文中图片均来自台大李宏毅教授的PPT

1.RNN

RNN中文名字叫做循环神经网络,在连续状态、时间序列数据方面具有很大优势,但是由于存在长序列训练过程中的梯度消失和梯度爆炸问题,才会有LSTM、GRU等新的网络结构被提出。
在这里插入图片描述

pytorch中使用RNN方法如下:

nn.RNNCell(input_size, hidden_size, bias=True, nonlinearity=‘tanh’)
  • 1

也可直接定义网络结构

torch.nn.RNN(*args, **kwargs)
  • 1

2.LSTM

LSTM网络中引入了两个新的变量,一个负责遗忘,一个负责记忆。
在这里插入图片描述

LSTM代码实现如下:

torch.nn.LSTMCell(input_size, hidden_size, bias=True)
  • 1

也可以直接定义网络

torch.nn.LSTM(*args, **kwargs)
  • 1

3.GRU

为了进一步简化参数,一种新的GRU(Gate Recurrent Unit)被提出。网络中负责记忆的变量由两个变为一个。
在这里插入图片描述

原文地址如下:
https://arxiv.org/pdf/1412.3555.pdf
GRU代码实现如下:

torch.nn.GRUCell(input_size, hidden_size, bias=True)
  • 1

也可以直接定义网络:

torch.nn.GRU(*args,**kwargs)
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/373262
推荐阅读
相关标签
  

闽ICP备14008679号