当前位置:   article > 正文

pytorch 中的 torch.nn.RNN 的参数_pytorch rnn参数

pytorch rnn参数

1、定义RNN的网络结构的参数(类似于CNN中定义 in_channel,out_channel,kernel_size等等)

         input_size   输入x的特征大小(以mnist图像为例,特征大小为28*28 = 784)
         hidden_size   隐藏层h的特征大小
         num_layers    循环层的数量(RNN中重复的部分)
         nonlinearity   激活函数 默认为tanh,可以设置为relu
         bias   是否设置偏置,默认为True
         batch_first   默认为false, 设置为True之后,输入输出为(batch_size, seq_len, input_size)
         dropout   默认为0
         bidirectional   默认为False,True设置为RNN为双向

【注】下图红色的部分为num_layers的个数(num_layers = 2 )

 

 

2、输入RNN网络与输出的参数

(1)输入:input:(seq_len,batch_size,input_size)    #(序列长度,batch_size,特征大小(数量))

                    h0:(num_layers*directions,batch_size,hidden_size)

(2)输出:hn:(num_layers*directions,batch_size,hidden_size)

                    output:(seq_len,batch_size,hidden_size*directions)

【注】bidirectional为Ture,则 directions=2,否则 directions=1 。

 

RNN的一个解析:https://www.jianshu.com/p/298116084ec7

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/926242
推荐阅读
相关标签
  

闽ICP备14008679号