当前位置:   article > 正文

PyTorch的nn.Linear()详解_pytorch nn.linear

pytorch nn.linear

参考链接PyTorch的nn.Linear()详解 - douzujun - 博客园 (cnblogs.com)

这里演示了二维张量的全连接 :

 其实还可以输入三维张量,演示如下:

  1. from torch import nn
  2. import torch
  3. # in_features由输入张量的形状决定,out_features则决定了输出张量的形状
  4. linear = nn.Linear(in_features=64 * 3, out_features=5)
  5. # 10个 大小为7*64*3, 3个channel 的张量
  6. a = torch.rand(10, 3, 7, 64 * 3)
  7. print(a.shape) # torch.Size([10, 3, 7, 192])
  8. print(linear.weight.shape) # torch.Size([5, 192])
  9. b = linear(a)
  10. print(b.shape) # torch.Size([10, 3, 7, 5])

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/290994
推荐阅读
相关标签
  

闽ICP备14008679号