当前位置:   article > 正文

PyTorch权重初始化的几种方法_torch 权重初始化方案

torch 权重初始化方案

PyTorch在自定义变量及其初始化方法:

  1. self.fuse_weight_1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
  2. self.fuse_weight_1.data.fill_(0.25)

如上是定义一个可学习的标量。也可以定义一个可学习的矩阵:

self.fuse_weight_1 = torch.nn.Parameter(torch.FloatTensor(torch.rand(3,3)), requires_grad=True)

PyTorch自定义卷积层初始化方法:

1.

  1. class Net(nn.Module):
  2. def __init__(self):
  3. super(Net, self).__init__()
  4. self.conv = nn.Sequential(
  5. nn.Conv2d(self.input_dim, 64, 4, 2, 1),
  6. nn.ReLU(),
  7. )
  8. self.fc = nn.Sequential(
  9. nn.Linear(32, 64 * (self.input_height // 2) * (self.input_width // 2)),
  10. nn.BatchNorm1d(64 * (self.input_height // 2) * (self.input_width // 2)),
  11. nn.ReLU(),
  12. )
  13. self.deconv = nn.Sequential(
  14. nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
  15. nn.Sigmoid(),
  16. )
  17. utils.initialize_weights(self)
  18. def forward(self, input):
  19. ...
  20. def initialize_weights(net):
  21. for m in net.modules():
  22. if isinstance(m, nn.Conv2d):
  23. m.weight.data.normal_(0, 0.02)
  24. m.bias.data.zero_()
  25. elif isinstance(m, nn.ConvTranspose2d):
  26. m.weight.data.normal_(0, 0.02)
  27. m.bias.data.zero_()
  28. elif isinstance(m, nn.Linear):
  29. m.weight.data.normal_(0, 0.02)
  30. m.bias.data.zero_()

2. 

  1. def init_weights(m):
  2. print(m)
  3. if type(m) == nn.Linear:
  4. m.weight.data.fill_(1.0)
  5. print(m.weight)
  6. net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
  7. net.apply(init_weights)

 3.

  1. def weights_init(m):
  2. classname = m.__class__.__name__
  3. if classname.find('Conv') != -1:
  4. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  5. m.weight.data.normal_(0, math.sqrt(2. / n))
  6. if m.bias is not None:
  7. m.bias.data.zero_()
  8. elif classname.find('BatchNorm') != -1:
  9. m.weight.data.fill_(1)
  10. m.bias.data.zero_()
  11. elif classname.find('Linear') != -1:
  12. m.weight.data.normal_(0, 0.01)
  13. m.bias.data = torch.ones(m.bias.data.size())
  14. net.apply(init_weights)

4.

  1. self.fuse_weight_1 = nn.Conv2d(1, 1, kernel_size=1, stride=1, bias=False)
  2. self.fuse_weight_1.weight.data.fill_(0.2)

 

 

参考:

1. torch.nn.init

2. Pytorch 细节记录

3. PyTorch参数初始化和Finetune

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/489277
推荐阅读
相关标签
  

闽ICP备14008679号