赞
踩
参考链接PyTorch的nn.Linear()详解 - douzujun - 博客园 (cnblogs.com)
这里演示了二维张量的全连接 :
其实还可以输入三维张量,演示如下:
- from torch import nn
- import torch
-
- # in_features由输入张量的形状决定,out_features则决定了输出张量的形状
- linear = nn.Linear(in_features=64 * 3, out_features=5)
-
- # 10个 大小为7*64*3, 3个channel 的张量
- a = torch.rand(10, 3, 7, 64 * 3)
-
- print(a.shape) # torch.Size([10, 3, 7, 192])
-
- print(linear.weight.shape) # torch.Size([5, 192])
-
- b = linear(a)
-
- print(b.shape) # torch.Size([10, 3, 7, 5])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。