赞
踩
torch.nn.BatchNorm2d:对输入batch的每一个特征通道进行normalize
【sample】
以input为2 x 3 x 4 x 5输入为例,其中
2:batch中样本数量
3:每个样本通道数
4:每个样本行数
5:每个样本列数
调用代码如下:
- bn = torch.nn.BatchNorm2d(3) # 参数3表示特征通道数
- out = bn(input)
input:
out:
【说明】
torch.nn.BatchNorm2d采用公式如下:
【计算过程】
torch.nn.BatchNorm2d对每一特征通道进行normalize,因此会计算出所有样本每一通道的均值和方差,以第一通道为例
- mean_channel1 = torch.mean(input[:, 0, :, :]) # = 142.8500
- normed_channel2 = torch.var(input[:,0, :, :], unbiased=False) # = 218.5775
- # unbiased参数需要设置为False,否则
- # 计算出方差为无偏估计,与当前结果不同
- # 第二、第三通道相同设置
- normed_result_channel1 = (input[:, 0, :, :]-mean_channel1)/((var_channel1+1e-5)**0.5)
输出结果如下,与torch.nn.BatchNorm2d结果一致
[[[-0.1251, -0.3280, 0.0101, 0.0778, 0.1454],
[-0.3280, -0.2604, 0.0101, 0.0101, 0.1454],
[-0.5310, -0.3957, -0.3280, -0.3957, 0.2131],
[ 6.0977, -0.3957, -0.8692, -0.5310, 0.0101]],
[[-0.1251, -0.3280, -0.1928, -0.3280, -0.2604],
[-0.1251, -0.1928, -0.1251, -0.0575, -0.1251],
[-0.1928, -0.1251, -0.1251, 0.0778, -0.0575],
[-0.1251, 0.0101, 0.0101, 0.0778, 0.0778]]]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。