当前位置:   article > 正文

PyTorch基础——torch.nn.BatchNorm2d

nn.batchnorm2d

torch.nn.BatchNorm2d:对输入batch的每一个特征通道进行normalize

【sample】
以input为2 x 3 x 4 x 5输入为例,其中
     2:batch中样本数量
     3:每个样本通道数
     4:每个样本行数
     5:每个样本列数
调用代码如下:

  1. bn = torch.nn.BatchNorm2d(3) # 参数3表示特征通道数
  2. out = bn(input)

input:

out:

【说明】
torch.nn.BatchNorm2d采用公式如下:

y=xE[x]Var[x]+ϵγ+β,其中

y表示输出,x为输入,E[x]表示输入的均值,Var[x]x的方差,ϵ默认值为1e5γ为1,β为0

 

【计算过程】
torch.nn.BatchNorm2d对每一特征通道进行normalize,因此会计算出所有样本每一通道的均值和方差,以第一通道为例

  1. mean_channel1 = torch.mean(input[:, 0, :, :]) # = 142.8500
  2. normed_channel2 = torch.var(input[:,0, :, :], unbiased=False) # = 218.5775
  3. # unbiased参数需要设置为False,否则
  4. # 计算出方差为无偏估计,与当前结果不同
  5. # 第二、第三通道相同设置
  6. normed_result_channel1 = (input[:, 0, :, :]-mean_channel1)/((var_channel1+1e-5)**0.5)

输出结果如下,与torch.nn.BatchNorm2d结果一致
[[[-0.1251, -0.3280,  0.0101,  0.0778,  0.1454],
  [-0.3280, -0.2604,  0.0101,  0.0101,  0.1454],
  [-0.5310, -0.3957, -0.3280, -0.3957,  0.2131],
  [ 6.0977, -0.3957, -0.8692, -0.5310,  0.0101]],

 [[-0.1251, -0.3280, -0.1928, -0.3280, -0.2604],
  [-0.1251, -0.1928, -0.1251, -0.0575, -0.1251],
  [-0.1928, -0.1251, -0.1251,  0.0778, -0.0575],
  [-0.1251,  0.0101,  0.0101,  0.0778,  0.0778]]]

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/606376
推荐阅读
相关标签
  

闽ICP备14008679号