当前位置:   article > 正文

​超细节的BatchNorm/BN/LayerNorm/LN/SyncLN/ShuffleBN/RMSNorm知识点_mhsa中layernorm的输入参数

mhsa中layernorm的输入参数

Norm,也即 Normalization,已经是深度神经网络模型中非常常规的操作了,但它背后的实现,原理和作用等,其实我们可以理解的更细致,本文会以最常用的 BatchNorm 和 LayerNorm 为例(其他 Norm 方法大同小异),通过 Q&A 的形式,去深入理解关于 Norm 的细节知识点。

  1. BN 在训练和测试时的差异

  2. BN 中的移动平均 Moving Average 是怎么做的?

  3. 移动平均中 Momentum 参数的影响

  4. Norm 中的标准化、平移和缩放的作用

  5. 不同 Norm 方法中都有哪些参数要保存?

  6. BN 和 LN 有哪些差异?

  7. 为什么 BERT 使用 LN,而不使用 BN?

  8. 如何去理解在哪一个维度做 Norm?

所有文字不如代码准确,决定先上一个简化版的MyBN1d和MyLN镇楼:

MyBN1d:

  1. import torch.nn as nn
  2. import torch
  3. class MyBN1d(nn.Module):
  4. def __init__(self, momentum=0.1, eps=1e-5, feat_dim=2):
  5. super(MyBN1d, self).__init__()
  6. # 更新self._running_xxx时的动量
  7. self._momentum = momentum
  8. # 防止分母计算为0
  9. self._eps = eps
  10. # running_mean和running_var都是要存在模型weights里,但是不需要更新参数,所以self.register_buffer
  11. self.register_buffer('_running_mean', torch.zeros(1,feat_dim,1))
  12. self.register_buffer('_running_var', torch.ones(1,feat_dim,1))
  13. # weight和bias都是需要训练时候更新参数的
  14. self._weight = nn.Parameter(torch.ones(1,feat_dim,1))
  15. self._bias = nn.Parameter(torch.zeros(1,feat_dim,1))
  16. def forward(self, x):
  17. if self.training: #self.training是nn.Module自带参数,net.train()和net.eval()会改变这个值
  18. # x_mean = x.mean([0,2])
  19. # x_var = x.var([0,2], correction=0) #correction=0表示分母是不是归一化
  20. x_mean = x.mean(dim=(0,2), keepdims=True)
  21. x_var = x.var(dim=(0,2), keepdims=True, correction=0)
  22. # 对应running_mean的更新公式,下面注释是备用写法
  23. # self._running_mean = (1-self._momentum)*self._running_mean + self._momentum*x_mean
  24. # self._running_var = (1-self._momentum)*self._running_var + self._momentum*x_var
  25. self._running_mean -= self._momentum*(x_mean-self._running_mean)
  26. self._running_var -= self._momentum * (x_mean - self._running_var)
  27. # [None,:,None]不常看有点恶心,相当于x_mean.unsqueeze(0).unsqueeze(2),也就是把一个shape是(feat_dim)的Tensor变成了(1,feat_dim,1),以下是备用写法
  28. # x_hat = (x-x_mean[None,:,None])/torch.sqrt(x_var[None,:,None]+self._eps)
  29. x_hat = (x-x_mean)/torch.sqrt(x_var+self._eps)
  30. else:
  31. # 注意上面训练的时候不要用running_mean做差输出
  32. # x_hat = (x-self._running_mean[None,:,None])/torch.sqrt(self._running_var[None,:,None]+self._eps)
  33. x_hat = (x-self._running_mean)/torch.sqrt(self._running_var+self._eps)
  34. return self._weight*x_hat + self._bias
  35. #CV中的feat_num, num_features, hidden_size都是指中间那个维度
  36. feat_dim=3
  37. x = torch.randn(2,feat_dim,5)
  38. bn1d = nn.BatchNorm1d(feat_dim)
  39. out_bn1d = bn1d(x)
  40. print(out_bn1d)
  41. mybn1d = MyBN1d(feat_dim=feat_dim)
  42. out_mybn1d = mybn1d(x)
  43. print(out_mybn1d)
  44. '''
  45. 两个输出都是一样的:
  46. tensor([[[-0.9176, 1.4579, 0.2473, 0.7218, 1.0444],
  47. [-0.7802, 0.3168, 0.8793, -0.4985, 2.3281],
  48. [ 1.6315, 0.4317, 0.1170, -0.3623, -1.5179]],
  49. [[ 0.8724, 0.0181, -0.9045, -0.7437, -1.7963],
  50. [ 0.0180, -0.1456, -0.6095, -1.5661, 0.0578],
  51. [-1.6694, 0.4632, -0.6507, 0.9601, 0.5968]]],
  52. grad_fn=<NativeBatchNormBackward0>)
  53. tensor([[[-0.9176, 1.4579, 0.2473, 0.7218, 1.0444],
  54. [-0.7802, 0.3168, 0.8793, -0.4985, 2.3281],
  55. [ 1.6315, 0.4317, 0.1170, -0.3623, -1.5179]],
  56. [[ 0.8724, 0.0181, -0.9045, -0.7437, -1.7963],
  57. [ 0.0180, -0.1456, -0.6095, -1.5661, 0.0578],
  58. [-1.6694, 0.4632, -0.6507, 0.9601, 0.5968]]],
  59. grad_fn=<AddBackward0>)
  60. '''

MyLN:

  1. import torch
  2. from torch import nn
  3. class MyLN(nn.Module):
  4. def __init__(self, normalized_shape, # 在哪个维度上做LN
  5. eps: float = 1e-5, # 防止分母为0
  6. elementwise_affine: bool = True): # 是否使用可学习的缩放因子和偏移因子
  7. super(MyLN, self).__init__()
  8. # 需要对哪个维度的特征做LN, torch.size查看维度
  9. self.normalized_shape = normalized_shape # [c,w*h]
  10. self.eps = eps
  11. self.elementwise_affine = elementwise_affine
  12. # 构造可训练的缩放因子和偏置
  13. if self.elementwise_affine:
  14. self.weight = nn.Parameter(torch.ones(normalized_shape)) # [c,w*h]
  15. self.bias = nn.Parameter(torch.zeros(normalized_shape)) # [c,w*h]
  16. def forward(self, x: torch.Tensor): # [b,c,w*h]
  17. # 需要做LN的维度和输入特征图对应维度的shape相同
  18. assert self.normalized_shape == x.shape[-len(self.normalized_shape):] # [-2:]
  19. # 需要做LN的维度索引
  20. dims = [-(i + 1) for i in range(len(self.normalized_shape))] # [b,c,w*h]维度上取[-1,-2]维度,即[c,w*h]
  21. # 计算特征图对应维度的均值和方差
  22. mean = x.mean(dim=dims, keepdims=True) # [b,1,1]
  23. mean_x2 = (x ** 2).mean(dim=dims, keepdims=True) # [b,1,1]
  24. var = mean_x2 - mean ** 2 # [b,c,1,1]
  25. x_norm = (x - mean) / torch.sqrt(var + self.eps) # [b,c,w*h]
  26. # 线性变换
  27. if self.elementwise_affine:
  28. x_norm = self.weight * x_norm + self.bias # [b,c,w*h]
  29. return x_norm
  30. if __name__ == '__main__':
  31. x = torch.randn(2,3,5)
  32. my_ln = MyLN(x.shape[1:])
  33. print(my_ln(x))
  34. ln = nn.LayerNorm(x.shape[1:])
  35. print(ln(x))
  36. '''
  37. 两个输出都是一样的:
  38. tensor([[[-0.4581, 1.5668, 0.6686, -0.4423, -0.7992],
  39. [ 0.1808, 0.7245, 0.3380, -1.1207, 1.3641],
  40. [ 0.4380, -2.3911, -0.1461, -0.7776, 0.8543]],
  41. [[ 1.3458, 0.4072, -2.2993, -0.7033, -0.7776],
  42. [ 0.3645, 0.2430, 0.0801, 0.5956, 0.5822],
  43. [ 0.1153, 0.9114, -0.0091, 1.0779, -1.9336]]],
  44. grad_fn=<AddBackward0>)
  45. tensor([[[-0.4581, 1.5668, 0.6686, -0.4423, -0.7992],
  46. [ 0.1808, 0.7245, 0.3380, -1.1207, 1.3641],
  47. [ 0.4380, -2.3911, -0.1461, -0.7776, 0.8543]],
  48. [[ 1.3458, 0.4072, -2.2993, -0.7033, -0.7776],
  49. [ 0.3645, 0.2430, 0.0801, 0.5956, 0.5822],
  50. [ 0.1153, 0.9114, -0.0091, 1.0779, -1.9336]]],
  51. grad_fn=<NativeLayerNormBackward0>)
  52. '''

BN在训练和测试时的差异

对于 BN,在训练时,是对每一个 batch 的训练数据进行归一化,也即用每一批数据的均值和方差。

而在测试时,比如进行一个样本的预测,就并没有 batch 的概念,因此,这个时候用的均值和方差是在训练过程中通过滑动平均得到的均值和方差,这个会和模型权重一起,在训练完成后一并保存下来。

对于 BN,是对每一批数据进行归一化到一个相同的分布,而每一批数据的均值和方差会有一定的差别,而不是用固定的值,这个差别实际上也能够增加模型的鲁棒性,并会在一定程度上减少过拟合。

但是一批数据和全量数据的均值和方差相差太多,又无法较好地代表训练集的分布,因此,BN 一般要求将训练集完全打乱,并用一个较大的 batch 值,去缩小与全量数据的差别。

卷积层使用 Batch Normalization#

卷积层, 例如对于图像的卷积的时候, 我们往往不会考虑每一个像素, 注意, 实际上, 往往每一个像素是作为一个特征, 并且还有其 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/458191
推荐阅读
相关标签