当前位置:   article > 正文

Pytorch学习-GRU使用

pytorch使用gru训练
  1. import torch.nn as nn
  2. import torch
  3. # gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
  4. # embed = nn.Embedding(3, 50)
  5. # x = torch.LongTensor([[0, 1, 2]])
  6. # x_embed = embed(x)
  7. # out, hidden = gru(x_embed)
  8. gru = nn.GRU(input_size=5, hidden_size=6,
  9. num_layers=2, # gru层数
  10. batch_first=False, # 默认参数 True:(batch, seq, feature) FalseTrue:( seq,batch, feature),
  11. bidirectional=False, # 默认参数
  12. )
  13. # N=batch size
  14. # L=sequence length
  15. # D=2 if bidirectional=True else 1
  16. # Hin=input size
  17. # Hout=outout size
  18. input_ = torch.randn(1, 3, 5) # (L,N,hin)(序列长度,batch size大小,输入维度大小)
  19. h0 = torch.randn(2 * 1, 3, 6) # (D∗num_layers,N,Hout)(是否双向乘以层数,batch size大小,输出维度大小)
  20. output, hn = gru(input_, h0)
  21. # output:[1, 3, 6] (L,N,D*Hout)=(1,3,1*6)
  22. # hn:[2, 3, 6] (D*num_layers,N,Hout)(1*2,3,6)
  23. print(output.shape, hn.shape)
  24. # torch.Size([1, 3, 6]) torch.Size([2, 3, 6])
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/349360
推荐阅读
相关标签
  

闽ICP备14008679号