当前位置:   article > 正文

Pytorch系列:(七)模型初始化_pytorch lstm初始化

pytorch lstm初始化

为什么要进行初始化

首先假设有一个两层全连接网络,第一层的第一个节点值为 H 11 = ∑ i = 0 n X i ∗ W 1 i H_{11}= \sum_{i=0}^n X_i*W_{1i} H11=i=0nXiW1i,

这个时候,方差为 D ( H 11 ) = ∑ i = 0 n D ( X i ) ∗ D ( W 1 i ) D(H_{11}) = \sum_{i=0}^n D(X_i) * D(W_{1i}) D(H11)=i=0nD(Xi)D(W1i), 这个时候,输入 X i X_i Xi一般会做归一化,那么其方差为1,而权重W如果不进行归一化的话,H的方差就会变得很大,然后多层累计,下一次的输入会越来越大,使得网络不好收敛,如果权重W进行了初始化,使得其方差保持在1/n附近,那么方差H则会收敛在1附近,从而使得网络变得更好优化。 很多初始化都是使用的这个原理,控制每一层的输出,使其保持在一定的范围内。

一些常见初始化方法

Xavier

Xavier初始化也是类似的原理, 假设输入X 以及做了归一化,其方差为1 ,那么Xavier所希望的就是上述公式D(H) 保持在1左右,那么就可以得到公式

H l a y e r 1 = ∑ i = 0 n D ( X i ) ∗ D ( W 1 i ) = n 1 ∗ D ( W ) = 1 H l a y e r 2 = ∑ i = 0 n D ( X i ) ∗ D ( W 1 i ) = n 2 ∗ D ( W ) = 1 H_{layer1} = \sum_{i=0}^n D(X_i) * D(W_{1i})=n_1 *D(W) = 1 \\ H_{layer2} =\sum_{i=0}^n D(X_i) * D(W_{1i}) = n_2 *D(W) = 1 Hlayer1=i=0nD(Xi)D(W1i)=n1D(W)=1Hlayer2=i=0nD(Xi)D(W1i)=n2D(W)=1

其中n1 和 n2 为网络层的输入输出节点数量,一般情况下,输入输出是不一样的,为了均衡考虑,可以做一个平均操作,于是变得到 D ( W ) = 2 n 1 + n 2 D(W) = \frac{2}{n_1+n_2} D(W)=n1+n22

这个时候,我们假设 W服从均匀分布 U [ − a , a ] U[-a, a] U[a,a], 那么在这个条件下,

D ( W ) = ( − a − a ) 2 12 = a 2 3 D(W) = \frac{(-a-a)^2}{12} = \frac{a^2}{3} D(W)=12(aa)2=3a2

推出 a = 6 n 1 + n 2 + 1 a = \frac{\sqrt{6}}{\sqrt{n_1+n_2+1}} a=n1+n2+1 6 ,从而得到:

W ∼ U [ − 6 n 1 + n 2 + 1 , 6 n 1 + n 2 + 1 ] W \sim U[-\frac{\sqrt{6}}{\sqrt{n_1+n_2+1}},\frac{\sqrt{6}}{\sqrt{n_1+n_2+1}}] WU[n1+n2+1 6 ,n1+n2+1 6 ]

这样就可以得到Xavier初始化,在pytorch中使用Xavier初始化方式如下,值得注意的是,Xavier对于sigmoid和tanh比较好,对于其他的可能效果就不是那么好了

nn.init.xavier_uniform_(m.weight.data) 
  • 1

Kaiming

Kaiming 初始化比较适合ReLU激活函数,其原理也跟上述差不多,也是希望将权重的方差保持在一定的范围内,使得正反向传播的值得到有效的控制,在kaiming初始化中,主要将权重的方差设置为 D ( w ) = 2 n i D(w) = \frac{2}{ni} D(w)=ni2,由于考虑到ReLU激活函数,将方差调整为 D ( w ) = 2 ( 1 + a 2 ) ∗ n i D(w)= \frac{2}{(1+a^2)*n_i} D(w)=(1+a2)ni2, 这里的a是ReLU的斜率。

在pytorch中使用Kaiming初始化

nn.init.kaiming_normal_(m.weight.data)
  • 1

LSTM初始化

LSTM中,公式和参数值的设定如下所示

在LSTM中,由于很多门控的权重尺寸是一样的,所以可以使用如下方法进行初始化

def _init_lstm(self, weight):
    for w in weight.chunk(4, 0):
        init.xavier_uniform(w)
        
self._init_lstm(self.lstm.weight_ih_l0)
self._init_lstm(self.lstm.weight_hh_l0)
self.lstm.bias_ih_l0.data.zero_()
self.lstm.bias_hh_l0.data.zero_()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

Embedding进行初始化

self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0)
init.xavier_uniform(self.embedding.weight)
  • 1
  • 2

其他通用初始化方法

遍历初始化

for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)
        
for name, param in net.named_parameters():
    if 'bias' in name:
        init.constant_(param, val=0)
        print(name, param.data)
        
        
## 通过instance 初始化
for m in self.children():
    if isinstance(m, nn.Linear):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, -100)
    # 也可以判断是否为conv2d,使用相应的初始化方式 
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.item(), 1)
        nn.init.constant_(m.bias.item(), 0)   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

直接使用pytorch内置初始化

from torch.nn import init 

init.normal_(net[0].weight, mean=0, std=0.01) 

init.constant_(net[0].bias, val=0)
  • 1
  • 2
  • 3
  • 4
  • 5

自带初始化方法中,会自动消除梯度反向传播,但是手动情况下必须自己设定

def no_grad_uniform(tensor, a, b):

  with torch.no_grad():

    return tensor.uniform_(a, b)
  • 1
  • 2
  • 3
  • 4
  • 5

使用apply进行初始化

批量初始化方法,注意net里面的apply函数,可以作用网络的所有module

def weights_init(m):                                               # 1

  classname = m.__class__.__name__                             # 2

  if classname.find('Conv') != -1:                               # 3

    nn.init.kaiming_normal_(m.weight.data)                  # 4

  elif classname.find('BatchNorm') != -1:                        # 5

    nn.init.normal_(m.weight.data, 1.0, 0.02)                  # 6

    nn.init.constant_(m.bias.data, 0)                          # 7 

net.apply(weights_init)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/347497
推荐阅读
相关标签
  

闽ICP备14008679号