当前位置:   article > 正文

torch中的BatchNorm LayerNorm InstanceNorm使用方法_torch.batchnorm

torch.batchnorm

1. torch中标准normalization函数与自定义函数的对比,以说明标准库函数的使用方法。

同时,代码中对4维数据、3维数据以及2维数据进行了对比。

注意:在2维数据 [batchsize,d_model] 的情况下不能使用 nn.InstanceNorm1d(它需要带通道维的输入,例如 [batchsize,channel,L]),也就是说不存在针对这种2维数据的 InstanceNorm。

  1. #-*- coding:utf-8 -*-
  2. #Author LJB Create on 2021/8/27
  3. import torch
  4. import torch.nn as nn
  5. import numpy as np
  6. #对于4维数据[batchsize,channel,H,W],其中[H,W]表示一个实例的二维特征维度
  7. #LayerNorm: axis=(1,2,3)
  8. #InstanceNorm: axis=(2,3)
  9. #BatchNorm: axis=(0,2,3)
  def MyNorm_4d(x, axis, gamma=1.0, beta=0.0, eps=1e-5):
      """Normalize a 4-D NumPy array over the given axes.

      The input is laid out as [B, C, H, W]. The choice of ``axis`` selects
      which normalization is reproduced:

      * LayerNorm:    ``axis=(1, 2, 3)``
      * InstanceNorm: ``axis=(2, 3)``
      * BatchNorm:    ``axis=(0, 2, 3)``

      Args:
          x: NumPy array of shape [B, C, H, W].
          axis: Axis or tuple of axes over which mean and variance are taken.
          gamma: Scale applied to the normalized values.
          beta: Shift added after scaling.
          eps: Small constant added to the variance before the square root
              to avoid division by zero.

      Returns:
          Array with the same shape as ``x``.
      """
      # keepdims=True keeps the reduced axes as size 1 so the statistics
      # broadcast back against x.
      x_mean = np.mean(x, axis=axis, keepdims=True)
      # np.var defaults to the population (biased) variance.
      x_var = np.var(x, axis=axis, keepdims=True)
      x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
      return gamma * x_normalized + beta
  19. #对于3维数据[batchsize,channel,d_model],d_model表示一个实例的一维特征维度
  20. #LayerNorm: axis=(1,2)
  21. #InstanceNorm: axis=(2)
  22. #BatchNorm: axis=(0,2)
  def MyNorm_3d(x, axis, gamma=1.0, beta=0.0, eps=1e-5):
      """Normalize a 3-D NumPy array over the given axes.

      The input is laid out as [B, C, D], where D is d_model. The choice of
      ``axis`` selects which normalization is reproduced:

      * LayerNorm:    ``axis=(1, 2)``
      * InstanceNorm: ``axis=2``
      * BatchNorm:    ``axis=(0, 2)``

      Args:
          x: NumPy array of shape [B, C, D].
          axis: Axis or tuple of axes over which mean and variance are taken.
          gamma: Scale applied to the normalized values.
          beta: Shift added after scaling.
          eps: Small constant added to the variance before the square root
              to avoid division by zero.

      Returns:
          Array with the same shape as ``x``.
      """
      # keepdims=True keeps the reduced axes as size 1 so the statistics
      # broadcast back against x.
      x_mean = np.mean(x, axis=axis, keepdims=True)
      # np.var defaults to the population (biased) variance.
      x_var = np.var(x, axis=axis, keepdims=True)
      x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
      return gamma * x_normalized + beta
  32. #对于2维数据[batchsize,d_model],d_model表示一个实例的一维特征维度
  33. #LayerNorm: axis=(1)
  34. #InstanceNorm: axis=(1)
  35. #BatchNorm: axis=(0)
  def MyNorm_2d(x, axis, gamma=1.0, beta=0.0, eps=1e-5):
      """Normalize a 2-D NumPy array over the given axis.

      The input is laid out as [B, D], where D is d_model. The choice of
      ``axis`` selects which normalization is reproduced:

      * LayerNorm (per-sample): ``axis=1``
      * BatchNorm:              ``axis=0``

      Args:
          x: NumPy array of shape [B, D].
          axis: Axis or tuple of axes over which mean and variance are taken.
          gamma: Scale applied to the normalized values.
          beta: Shift added after scaling.
          eps: Small constant added to the variance before the square root
              to avoid division by zero.

      Returns:
          Array with the same shape as ``x``.
      """
      # keepdims=True keeps the reduced axis as size 1 so the statistics
      # broadcast back against x.
      x_mean = np.mean(x, axis=axis, keepdims=True)
      # np.var defaults to the population (biased) variance.
      x_var = np.var(x, axis=axis, keepdims=True)
      x_normalized = (x - x_mean) / np.sqrt(x_var + eps)
      return gamma * x_normalized + beta
  if __name__ == '__main__':
      # ------------------------------------------------------------------
      # 4-D demo data [batchsize, channel, H, W]; each [H, W] slice is the
      # 2-D feature map of one instance.
      # ------------------------------------------------------------------
      batchsize, channel, H, W = 3, 2, 5, 7
      data = torch.randint(0, 9, (batchsize, channel, H, W)).float()
      data_np = data.numpy()

      # LayerNorm: the torch layer is configured with the trailing *shape*
      # to normalize, the NumPy helper with the matching *axes*.
      full_layer_norm = nn.LayerNorm(normalized_shape=[channel, H, W])
      torch_layer = full_layer_norm(data)
      numpy_layer = MyNorm_4d(data_np, axis=(1, 2, 3))
      # print('+++ Normal Layernorm:', torch_layer)
      # print('+++ My Layernorm:', numpy_layer)

      # With 4-D data an instance is an [H, W] matrix, so nn.LayerNorm over
      # [H, W] can also serve as an InstanceNorm.
      spatial_layer_norm = nn.LayerNorm(normalized_shape=[H, W])
      torch_layer_as_instance = spatial_layer_norm(data)
      numpy_instance = MyNorm_4d(data_np, axis=(2, 3))
      # Standard InstanceNorm, configured with the number of channels.
      instance_norm = nn.InstanceNorm2d(channel)
      torch_instance = instance_norm(data)
      # print('+++ Normal Instance(Layernorm):')
      # print(torch_layer_as_instance)
      # print('+++ My Instance norm:')
      # print(numpy_instance)
      # print('+++ Normal InstanceNorm:')
      # print(torch_instance)

      # BatchNorm: statistics per channel over the batch and spatial axes.
      batch_norm = nn.BatchNorm2d(channel)
      torch_batch = batch_norm(data)
      numpy_batch = MyNorm_4d(data_np, axis=(0, 2, 3))
      # print('+++ Normal BatchNorm2d:')
      # print(torch_batch)
      # print('+++ My batch norm:')
      # print(numpy_batch)

      # ------------------------------------------------------------------
      # 3-D data of shape [batchsize, seqlen, d_model]; seqlen plays the
      # role of the channel dimension. Uncomment to compare the outputs.
      # ------------------------------------------------------------------
      batchsize_3d, seqlen, d_model = 5, 4, 3
      data_3d = torch.randint(0, 9, (batchsize_3d, seqlen, d_model)).float()
      # print('+++ Normal LayerNorm:')
      # print(nn.LayerNorm([seqlen, d_model])(data_3d))
      # print('+++ My LayerNorm:')
      # print(MyNorm_3d(data_3d.numpy(), axis=(1, 2)))
      #
      # print('+++ Normal InstanceNorm1d:')
      # print(nn.InstanceNorm1d(seqlen)(data_3d))
      # print('+++ Normal Instance(LayerNorm):')
      # print(nn.LayerNorm(d_model)(data_3d))
      # print('+++ My Instance Norm:')
      # print(MyNorm_3d(data_3d.numpy(), axis=2))
      #
      # print('+++ Normal BatchNorm1d:')
      # print(nn.BatchNorm1d(seqlen)(data_3d))
      # print('+++ My Batch Norm:')
      # print(MyNorm_3d(data_3d.numpy(), axis=(0, 2)))

      # ------------------------------------------------------------------
      # 2-D data of shape [batchsize, d_model].
      # ------------------------------------------------------------------
      print('#' * 80)
      batchsize_2d = 5
      d_model = 3
      data_2d = torch.randint(0, 9, (batchsize_2d, d_model)).float()
      data_2d_np = data_2d.numpy()

      print('+++ Normal LayerNorm:')
      layer_norm_2d = nn.LayerNorm([d_model])
      print(layer_norm_2d(data_2d))
      print('+++ My LayerNorm:')
      print(MyNorm_2d(data_2d_np, axis=1))

      print('+++ Normal InstanceNorm1d: Error')
      # print(nn.InstanceNorm1d(1)(data_2d))  # raises an error for 2-D data

      print('+++ Normal Instance(LayerNorm):')
      per_sample_norm = nn.LayerNorm(d_model)
      print(per_sample_norm(data_2d))
      print('+++ My Instance Norm:')
      print(MyNorm_2d(data_2d_np, axis=1))

      print('+++ Normal BatchNorm1d:')
      # For 2-D input BatchNorm1d is configured with d_model; see its
      # definition for details.
      batch_norm_2d = nn.BatchNorm1d(d_model)
      print(batch_norm_2d(data_2d))
      print('+++ My Batch Norm:')
      print(MyNorm_2d(data_2d_np, axis=0))

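2. 说明在eval模式下,只有BatchNorm的行为会发生变化(改用训练时累计的running_mean/running_var,而不再使用当前batch的统计量);LayerNorm和默认设置(track_running_stats=False)的InstanceNorm不受eval影响。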

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  class NormModel(nn.Module):
      """Container holding one LayerNorm, one InstanceNorm2d and one BatchNorm2d.

      It exists to compare how the three layers behave after ``eval()`` is
      called. No ``forward`` is defined; the layers are called individually.

      Args:
          size: ``normalized_shape`` passed to ``nn.LayerNorm``.
          num_channels: Channel count passed to ``nn.InstanceNorm2d`` and
              ``nn.BatchNorm2d``. Defaults to 3.
      """

      def __init__(self, size, num_channels=3):
          super().__init__()
          # Not affected by eval(): statistics come from each input sample.
          self.layer_norm = nn.LayerNorm(size)
          # Not affected by eval() with the default settings used here.
          self.instance_norm = nn.InstanceNorm2d(num_channels)
          # In eval mode BatchNorm stops using the current batch statistics
          # and normalizes with its running mean/variance instead.
          self.batch_norm = nn.BatchNorm2d(num_channels)
  # Build the model for [H, W] = (5, 7) feature maps and a random
  # [batch=2, channel=3, H=5, W=7] input.
  norm = NormModel((5, 7))
  data = torch.randint(0, 9, (2, 3, 5, 7)).float()
  # print(list(norm.named_parameters()))

  # Switch to evaluation mode before calling the layers.
  norm.eval()

  print(data)
  separator = '*' * 50
  # Print each layer's output, preceded by a separator line.
  for layer in (norm.layer_norm, norm.instance_norm, norm.batch_norm):
      print(separator)
      print(layer(data))

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/606258
推荐阅读
相关标签
  

闽ICP备14008679号