During training, nn.Dropout zeroes each element of its input independently with probability p, and scales the surviving elements by 1/(1-p) so that the expected value of the output matches the input.
In eval() mode no dropout is applied; the input passes through unchanged.
Usage is as follows:
```
In [1]: import torch

In [2]: p = 0.5

In [3]: module = torch.nn.Dropout(p)

In [4]: module.training
Out[4]: True

In [5]: inp = torch.ones(3,5)

In [6]: print(inp)
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [7]: module(inp)
Out[7]:
tensor([[0., 0., 0., 2., 0.],
        [2., 2., 0., 0., 2.],
        [2., 0., 2., 2., 2.]])

In [10]: 1/(1-p)
Out[10]: 2.0

In [11]: module.eval()
Out[11]: Dropout(p=0.5, inplace=False)

In [12]: module.training
Out[12]: False

In [13]: module(inp)
Out[13]:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
```
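Because each element survives with probability 1-p and is scaled by 1/(1-p), the expected value of the output equals the input. A minimal sketch (not part of the original session; the seed is arbitrary) checking this:

```python
import torch

torch.manual_seed(0)            # arbitrary seed, for reproducibility only
p = 0.5
module = torch.nn.Dropout(p)
inp = torch.ones(10000)

out = module(inp)
# Surviving elements become 1/(1-p) = 2.0 and roughly half are zeroed,
# so the mean stays close to the original mean of 1.0.
print(out.mean())               # ~1.0, up to sampling noise
```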
nn.Linear applies its affine map to the last dimension only; every leading dimension is treated as a batch dimension, as the following session shows:

```
In [1]: import torch

In [2]: module = torch.nn.Linear(10,20)

In [3]: module
Out[3]: Linear(in_features=10, out_features=20, bias=True)

In [4]: n_samples = 40

In [5]: inp_2d = torch.rand(n_samples,10)

In [6]: module(inp_2d).shape
Out[6]: torch.Size([40, 20])

In [7]: inp_3d = torch.rand(n_samples,33,10)

In [8]: module(inp_3d).shape
Out[8]: torch.Size([40, 33, 20])

In [10]: input_7d = torch.rand(n_samples,2,3,4,5,6,10)

In [12]: module(input_7d).shape
Out[12]: torch.Size([40, 2, 3, 4, 5, 6, 20])
```
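Equivalently, an N-D input can be flattened into one batch dimension first. A minimal sketch (shapes chosen arbitrarily, not from the original post) confirming the two paths agree:

```python
import torch

module = torch.nn.Linear(10, 20)
inp_3d = torch.rand(40, 33, 10)

out_direct = module(inp_3d)
# Flatten the leading dims, apply the layer, then restore the shape.
out_flat = module(inp_3d.reshape(-1, 10)).reshape(40, 33, 20)

print(torch.allclose(out_direct, out_flat))  # True
```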
With elementwise_affine=False, nn.LayerNorm has 0 learnable parameters; it only normalizes each sample over the last dimension to zero mean and unit standard deviation:
```
In [1]: import torch

In [2]: inp = torch.tensor([[0,4.],[-1,7],[3,5]])

In [3]: n_samples,n_features = inp.shape

In [4]: print(n_samples)
3

In [5]: module = torch.nn.LayerNorm(n_features, elementwise_affine=False)

In [6]: sum(p.numel() for p in module.parameters() if p.requires_grad)
Out[6]: 0

In [7]: inp.mean(-1),inp.std(-1,unbiased=False)
Out[7]: (tensor([2., 3., 4.]), tensor([2., 4., 1.]))

In [8]: module(inp).mean(-1),module(inp).std(-1,unbiased=False)
Out[8]:
(tensor([ 0.0000e+00, -2.9802e-08,  1.1921e-07]),
 tensor([1.0000, 1.0000, 1.0000]))
```
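What LayerNorm computes here can be reproduced by hand. A minimal sketch (not in the original session) of y = (x - mean) / sqrt(var + eps), with the mean and biased variance taken over the last dimension:

```python
import torch

inp = torch.tensor([[0, 4.], [-1, 7], [3, 5]])
module = torch.nn.LayerNorm(inp.shape[-1], elementwise_affine=False)

mean = inp.mean(-1, keepdim=True)
var = inp.var(-1, unbiased=False, keepdim=True)
manual = (inp - mean) / torch.sqrt(var + module.eps)  # eps defaults to 1e-5

print(torch.allclose(manual, module(inp)))  # True
```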
With elementwise_affine=True there are 4 learnable parameters: a weight and a bias, each of size n_features = 2:
```
In [9]: module = torch.nn.LayerNorm(n_features, elementwise_affine=True)

In [10]: sum(p.numel() for p in module.parameters() if p.requires_grad)
Out[10]: 4

In [11]: (module.bias,module.weight)
Out[11]:
(Parameter containing:
 tensor([0., 0.], requires_grad=True),
 Parameter containing:
 tensor([1., 1.], requires_grad=True))

In [12]: module(inp).mean(-1),module(inp).std(-1,unbiased=False)
Out[12]:
(tensor([ 0.0000e+00, -2.9802e-08,  1.1921e-07], grad_fn=<MeanBackward1>),
 tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward1>))

In [13]: module.bias.data += 1

In [14]: module.weight.data *= 4

In [15]: module(inp).mean(-1),module(inp).std(-1,unbiased=False)
Out[15]:
(tensor([1.0000, 1.0000, 1.0000], grad_fn=<MeanBackward1>),
 tensor([4.0000, 4.0000, 4.0000], grad_fn=<StdBackward1>))
```
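The affine step is simply y = normalized * weight + bias, which is why the per-sample mean becomes the bias (1) and the per-sample std becomes the weight (4). A minimal sketch (not in the original session) reproducing Out[15]:

```python
import torch

inp = torch.tensor([[0, 4.], [-1, 7], [3, 5]])
norm = torch.nn.LayerNorm(2, elementwise_affine=False)

# weight = 4 and bias = 1, as set in In [13] and In [14] above
manual = norm(inp) * 4 + 1
print(manual.mean(-1))                 # tensor([1., 1., 1.])
print(manual.std(-1, unbiased=False))  # tensor([4., 4., 4.])
```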
Normalization acts only over the last n_features dimension; each sample is handled independently, so the result does not depend on n_samples or any other leading dimension:
```
In [16]: module(torch.rand(n_samples,2,3,4,5,6,n_features)).shape
Out[16]: torch.Size([3, 2, 3, 4, 5, 6, 2])

In [17]: module(torch.rand(n_samples,2,3,4,5,6,n_features)).mean(-1)
Out[17]:
tensor([[[[[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],

           [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]]]],
       grad_fn=<MeanBackward1>)

In [18]: module(torch.rand(n_samples,2,3,4,5,6,n_features)).std(-1,unbiased=False)
Out[18]:
tensor([[[[[[3.9998, 3.9962, 3.9993, 3.9966, 3.9657, 3.9954],
            [3.9998, 3.9997, 3.9996, 3.9981, 3.9981, 3.9983],
            [3.9799, 3.9996, 3.9657, 3.9998, 3.9998, 3.9986],
            [3.9990, 3.9950, 3.9885, 3.9996, 3.9582, 3.9996],
            [3.9987, 3.9989, 3.9900, 3.9992, 3.9992, 3.9994]],

           [[3.9996, 3.9996, 3.9972, 3.9931, 3.9998, 3.5468],
            [3.9998, 3.9808, 3.9974, 3.9985, 3.9992, 3.9986],
            [1.1207, 3.9993, 3.9998, 3.6870, 3.9997, 3.9981],
            [3.9998, 3.9998, 3.9986, 3.9676, 3.9999, 3.9998],
            [3.9993, 3.9999, 3.9998, 3.9852, 3.9993, 3.9891]]]]]],
       grad_fn=<StdBackward1>)
```
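Some stds in Out[18] fall well below 4 (e.g. 1.1207). With only n_features = 2, a sample whose two feature values happen to be nearly equal has a tiny variance, so the eps = 1e-5 term in the denominator sqrt(var + eps) dominates and the normalized std drops below 1 before the weight of 4 is applied. A minimal sketch (made-up values) of the effect:

```python
import torch

norm = torch.nn.LayerNorm(2, elementwise_affine=False)

close = torch.tensor([[0.5000, 0.5020]])  # nearly equal features, tiny variance
far   = torch.tensor([[0.1000, 0.9000]])  # well-separated features

print(norm(close).std(-1, unbiased=False))  # well below 1.0 (eps dominates)
print(norm(far).std(-1, unbiased=False))    # ~1.0
```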