当前位置:   article > 正文

pytorch之BatchNorm_pytorch batchnorm

pytorch batchnorm

为了解决 Internal Covariate Shift问题引入,该问题具体表现为:

  • 中间层输入分布总是变化,增加了模型拟合的难度。
  • 中间层输入分布会使输出逐渐靠近激活函数梯度较小的地方,导致梯度消失

BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的,在输入到激活函数之前,对每个mini-batch输入,做如下处理:
在这里插入图片描述

整个的过程如下图所示,训练过程为第1-6行,这里需要注意BN层引入了可训练参数的 β \beta β γ \gamma γ

假设有K个激活函数,每个mini-batch在传入每个激活函数之前,需要经过一个BN层

inference的过程为第7-11行,inference过程输出应该完全由输入决定,使用固定的均值和方差(对于只有一个测试集,也无法统计均值和方差)。

这里固定的均值和方差,通过使用对应位置的(某个激活函数之前)所有mini-batch的均值的期望作为均值,所有mini-batch的方差的无偏估计作为方差(第10行)

在这里插入图片描述

图片来自论文Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

pytorch中包含多种计算BatchNorm方法,下面将一一介绍

BatchNorm1d

y = x − E [ x ] Var ⁡ [ x ] + ϵ ∗ γ + β y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta y=Var[x]+ϵ xE[x]γ+β

β \beta β γ \gamma γ的初始值为0,1。

传入BatchNorm1d的参数如下:

  • num_features 需要进行Normalization的那个维度大小 C C C,输入的维度可以是 ( N , C ) (N,C) (N,C) ( N , C , L ) (N,C,L) (N,C,L)
  • eps 上述公式中的 ϵ \epsilon ϵ,防止分母为0,默认为1e-5
  • momentum 用于计算inference过程的均值和方差(滑动平均),默认为0.1
  • affine 是否使用可训练参数的 β \beta β γ \gamma γ,默认为True
  • track_running_stats 是否记录上一个mini-batch的均值running_mean和方差running_var 。如果是,通过滑动平均的方式得到inference时的均值和方差;如果不是,则需要在inference时重新计算。

track_running_stats=True 时,momentum 使用如下:

running_mean = momentum * running_mean + (1 - momentum) * x_mean
running_var = momentum * running_var + (1 - momentum) * x_var
  • 1
  • 2

inference直接使用最后的running_meanrunning_var 作为固定的均值和方差

import torch
import torch.nn as nn
import numpy as np
import math


def validation(x):
    """
    验证函数
    :param x:
    :return:
    """
    x = np.array(x)
    avg = np.mean(x, axis=0)
    std2 = np.var(x, axis=0)

    x_avg = [[item for item in avg] for _ in range(x.shape[0])]
    x_std = [[math.pow(item, 1 / 2) for item in std2] for _ in range(x.shape[0])]
    x_ = (x - np.array(x_avg)) / np.array(x_std)
    return x_


# With Learnable Parameters
# m = nn.BatchNorm1d(4)
# Without Learnable Parameters
m = nn.BatchNorm1d(4, affine=False)
x = [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]]
input = torch.tensor(x, dtype=torch.float)
output = m(input)
print(output)

res = validation(x)
print("valiadation:\n", res)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

输出:

tensor([[-1.2247, -1.2247, -1.2247, -1.2247],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.2247,  1.2247,  1.2247,  1.2247]])
valiadation:
 [[-1.22474487 -1.22474487 -1.22474487 -1.22474487]
 [ 0.          0.          0.          0.        ]
 [ 1.22474487  1.22474487  1.22474487  1.22474487]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

也可以输入三维张量:

m = nn.BatchNorm1d(4, affine=False)
x = [[[1, 1], [2, 2], [3, 3], [4, 4]], [[2, 2], [3, 3], [4, 4], [5, 5]], [[3, 3], [4, 4], [5, 5], [6, 6]]]
input = torch.tensor(x, dtype=torch.float)
print(input.size())
output = m(input)
print(output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

输出:

torch.Size([3, 4, 2])
tensor([[[-1.2247, -1.2247],
         [-1.2247, -1.2247],
         [-1.2247, -1.2247],
         [-1.2247, -1.2247]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 1.2247,  1.2247],
         [ 1.2247,  1.2247],
         [ 1.2247,  1.2247],
         [ 1.2247,  1.2247]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

BatchNorm2d

同上BatchNorm1d
y = x − E [ x ] Var ⁡ [ x ] + ϵ ∗ γ + β y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta y=Var[x]+ϵ xE[x]γ+β

β \beta β γ \gamma γ的初始值为0,1。

  • num_features 需要进行Normalization的那个维度大小 C C C,输入的维度可以是 ( N , C , H , W ) (N, C, H, W) (N,C,H,W)

BatchNorm3d

同上BatchNorm1d
y = x − E [ x ] Var ⁡ [ x ] + ϵ ∗ γ + β y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta y=Var[x]+ϵ xE[x]γ+β

β \beta β γ \gamma γ的初始值为0,1。

  • num_features 需要进行Normalization的那个维度大小 C C C,输入的维度可以是 ( N , C , D , H , W ) (N, C, D, H, W) (N,C,D,H,W)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/335120
推荐阅读
相关标签
  

闽ICP备14008679号