当前位置:   article > 正文

Pytorch学习1-GRU使用和参数说明_gru参数

gru参数
import torch.nn as nn
import torch

# gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
# embed = nn.Embedding(3, 50)
# x = torch.LongTensor([[0, 1, 2]])
# x_embed = embed(x)
# out, hidden = gru(x_embed)


gru = nn.GRU(input_size=5, hidden_size=6,
             num_layers=2,  # gru层数
             batch_first=False,  # 默认参数 True:(batch, seq, feature) False:True:( seq,batch, feature),
             bidirectional=False,  # 默认参数
             )

# N=batch size
# L=sequence length
# D=2 if bidirectional=True else 1
# Hin=input size
# Hout=outout size


input_ = torch.randn(1, 3, 5)  # (L,N,hin)(序列长度,batch size大小,输入维度大小)
h0 = torch.randn(2 * 1, 3, 6)  # (D∗num_layers,N,Hout)(是否双向乘以层数,batch size大小,输出维度大小)

output, hn = gru(input_, h0)
# output:[1, 3, 6] (L,N,D*Hout)=(1,3,1*6)
# hn:[2, 3, 6] (D*num_layers,N,Hout)(1*2,3,6)

print(output.shape, hn.shape)
# torch.Size([1, 3, 6]) torch.Size([2, 3, 6])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/327561
推荐阅读
相关标签
  

闽ICP备14008679号