赞
踩
BatchNorm, 批规范化,主要用于解决协方差偏移问题,主要分三部分:
算法内容如下:
需要说明几点:
以BatchNorm2d为例,分析其中变量和参数的意义:
affine: 仿射的开关,决定是否使用仿射这个过程。
training:模型为训练状态和测试状态时的运行逻辑是不同的。
track_running_stats: 决定是否跟踪整个训练过程中的batch的统计特性,而不仅仅是当前batch的特性。
num_batches_tracked:如果设置track_running_stats为真,这个就会起作用,代表跟踪的batch个数,即统计了多少个batch的特性。
momentum: 滑动平均计算running_mean和running_var
x ^ new = ( 1 − \hat{x}_{\text {new }}=(1- x^new =(1− momentum ) × x ^ + ) \times \hat{x}+ )×x^+ momentum × x t \times x_{t} ×xt
class _NormBase(Module): """Common base of _InstanceNorm and _BatchNorm""" _version = 2 __constants__ = ['track_running_stats', 'momentum', 'eps', 'num_features', 'affine'] def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_NormBase, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) else: self.register_parameter('running_mean', None) self.register_parameter('running_var', None) self.register_parameter('num_batches_tracked', None) self.reset_parameters()
training和tracking_running_stats有四种组合:
更新过程:
参考文献:
https://blog.csdn.net/LoseInVain/article/details/86476010
https://blog.csdn.net/yangwangnndd/article/details/94901175
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。