
Implementing BN/hswish/SiLU Operators in TensorRT with the Python API


TRT Implementation of BatchNorm2d

TensorRT does not provide a batch normalization (BatchNorm) layer, but it does provide the more general Scale layer, which can be used to implement BN.

BatchNorm Definition

PyTorch's BN layer is torch.nn.BatchNorm2d; the formula is given in its docstring, or you can consult the official BatchNorm2d documentation directly:

$$y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \cdot \gamma + \beta$$

Simply put, $E[x]$ is the batch mean, $Var[x]$ is the batch variance, $\epsilon$ prevents division by zero, $\gamma$ is the per-channel weight learned during training, and $\beta$ is the bias.

The per-channel BN definition is:

$$BN[i,:] = \frac{in[i,:] - mean[i]}{\sqrt{var[i] + \epsilon}} \cdot \gamma[i] + \beta[i]$$

Correspondingly, any BN layer in PyTorch stores the following parameters:

import torch

weights  = torch.load(your_model_dict_state_path)
bn_gamma = weights['bn.weight'].numpy()
bn_beta  = weights['bn.bias'].numpy()
bn_mean  = weights['bn.running_mean'].numpy()
bn_var   = weights['bn.running_var'].numpy()
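Before moving to TRT, the per-channel formula can be sanity-checked against nn.BatchNorm2d in eval mode. The following is a minimal sketch, with a freshly constructed layer and random statistics standing in for your trained weights:

import torch
import torch.nn as nn

# A fresh layer with random statistics stands in for the trained BN layer.
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
x = torch.randn(2, 8, 4, 4)

gamma, beta = bn.weight, bn.bias
mean, var   = bn.running_mean, bn.running_var

# per-channel formula: (x - mean) / sqrt(var + eps) * gamma + beta
y_manual = ((x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + bn.eps)
            * gamma[None, :, None, None] + beta[None, :, None, None])

assert torch.allclose(bn(x), y_manual, atol=1e-5)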

Definition of the TRT Scale Layer

The multiplication in a BN layer is a per-channel operation on a 4D tensor. Section #16.1 of the official guide mentions that batch normalization can be built with IElementWiseLayer, but that approach is unnecessarily complicated and not recommended. This article instead uses the IScaleLayer provided by the TRT API:

$$Scale = (in \cdot scale + shift)^{power}$$

Implementing BN with Scale

Let:

$$scale = \frac{\gamma}{\sqrt{var + \epsilon}}, \qquad shift = -\frac{mean}{\sqrt{var + \epsilon}} \cdot \gamma + \beta, \qquad power = 1$$

#--------------- BatchNorm layer ---------------
# Fetch the trained BN parameters
gamma = weights['bn.weight'].numpy()
beta  = weights['bn.bias'].numpy()
mean  = weights['bn.running_mean'].numpy()
var   = weights['bn.running_var'].numpy()

# Note: eps must be set to 1e-3 here to match TensorFlow; with 1e-4 or smaller the
# output will not exactly match TF's (a TF quirk). For a PyTorch model, use the eps
# the layer was trained with (default 1e-5).
scale = trt.Weights(gamma / np.sqrt(var + 1e-3))
shift = trt.Weights(beta - mean / np.sqrt(var + 1e-3) * gamma)
power = trt.Weights(np.ones(len(var), dtype=np.float32))

# Add the BN layer as a Scale layer
bn = network.add_scale(layer.get_output(0), trt.ScaleMode.CHANNEL, shift, scale, power)
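A quick numpy check (a minimal sketch with random per-channel statistics) confirms that the scale/shift passed to add_scale above is equivalent to the original BN formula:

import numpy as np

c     = 8
x     = np.random.randn(c, 16).astype(np.float32)   # one row per channel
gamma = np.random.rand(c).astype(np.float32)
beta  = np.random.rand(c).astype(np.float32)
mean  = np.random.randn(c).astype(np.float32)
var   = np.random.rand(c).astype(np.float32) + 0.5
eps   = 1e-3

# original per-channel BN formula
bn_ref = (x - mean[:, None]) / np.sqrt(var[:, None] + eps) * gamma[:, None] + beta[:, None]

# equivalent scale/shift form used by the IScaleLayer
scale = gamma / np.sqrt(var + eps)
shift = beta - mean / np.sqrt(var + eps) * gamma
assert np.allclose(x * scale[:, None] + shift[:, None], bn_ref, atol=1e-5)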

Fused Batch Normalization

Going a step further, the convolution and BN layers can actually be fused at inference time. Briefly, the convolution computes:

$$z = w \cdot x + b$$

Substituting this $z$ for the $x$ in the BN formula gives:

$$y = \frac{w \cdot \gamma}{\sqrt{Var[x] + \epsilon}} \cdot x + \frac{b - E[x]}{\sqrt{Var[x] + \epsilon}} \cdot \gamma + \beta$$

These are of course matrix operations: $\frac{w \cdot \gamma}{\sqrt{Var[x]+\epsilon}}$ becomes the new $w$, and $\frac{b - E[x]}{\sqrt{Var[x]+\epsilon}} \cdot \gamma + \beta$ becomes the new $b$.

The code is as follows:

weights  = torch.load(your_model_dict_state_path)
conv_w   = weights['conv.weight'].numpy()
conv_b   = weights['conv.bias'].numpy()
bn_gamma = weights['bn.weight'].numpy()
bn_beta  = weights['bn.bias'].numpy()
bn_mean  = weights['bn.running_mean'].numpy()
bn_var   = weights['bn.running_var'].numpy()
eps      = 1e-05
bn_var   = np.sqrt(bn_var + eps)

fused_conv_w = conv_w * (bn_gamma / bn_var).reshape([conv_w.shape[0], 1, 1, 1])
fused_conv_b = (conv_b - bn_mean) / bn_var * bn_gamma + bn_beta
fused_conv   = network.add_convolution(input=last_layer.get_output(0),
                                       num_output_maps=your_conv_out,
                                       kernel_shape=(your_conv_kernel, your_conv_kernel),
                                       kernel=fused_conv_w,
                                       bias=fused_conv_b)
fused_conv.padding = (your_conv_pad, your_conv_pad)
fused_conv.stride  = (your_conv_stride, your_conv_stride)

Here conv is the convolution layer to be fused and fused_conv is the convolution fused with the BN layer; fused_conv must be configured with the same parameters as conv (padding, stride, kernel_shape, num_output_maps).
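As a sanity check for the fusion (a minimal PyTorch sketch; the freshly constructed layers and random statistics below stand in for your trained conv/bn pair), the fused weights reproduce conv followed by BN in eval mode:

import numpy as np
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, padding=1)
bn   = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
x    = torch.randn(1, 3, 16, 16)

w, b        = conv.weight.detach().numpy(), conv.bias.detach().numpy()
gamma, beta = bn.weight.detach().numpy(), bn.bias.detach().numpy()
mean        = bn.running_mean.numpy()
std         = np.sqrt(bn.running_var.numpy() + bn.eps)

# fold BN into the convolution weights, exactly as in the TRT code above
fused_w = w * (gamma / std).reshape(-1, 1, 1, 1)
fused_b = (b - mean) / std * gamma + beta

fused = nn.Conv2d(3, 8, 3, padding=1)
fused.weight.data = torch.from_numpy(fused_w)
fused.bias.data   = torch.from_numpy(fused_b)

with torch.no_grad():
    assert torch.allclose(fused(x), bn(conv(x)), atol=1e-4)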

TRT Implementation of BatchNorm1d

Likewise, following the BatchNorm2d approach, an IShuffleLayer is needed here to reshape the 1D tensor to 2D, apply BN there, and finally reshape back to 1D. You need to specify the input tensor size, because TRT must know it when shuffling. The rough implementation looks like this:

weights  = torch.load(your_model_dict_state_path)
bn_gamma = weights['bn.weight'].numpy()
bn_beta  = weights['bn.bias'].numpy()
bn_mean  = weights['bn.running_mean'].numpy()
bn_var   = weights['bn.running_var'].numpy()
eps      = 1e-05
bn_var   = np.sqrt(bn_var + eps)

bn_scale = bn_gamma / bn_var
bn_shift = - bn_mean / bn_var * bn_gamma + bn_beta

# reshape 1D -> (C, 1, 1) so the per-channel scale can be applied
shuffle  = network.add_shuffle(last_layer.get_output(0))
shuffle.reshape_dims = (your_input_shape, 1, 1)

# do bn1d as a per-channel scale
bn       = network.add_scale(input=shuffle.get_output(0), mode=trt.ScaleMode.CHANNEL, shift=bn_shift, scale=bn_scale)

# reshape back to 1D
shuffle  = network.add_shuffle(bn.get_output(0))
shuffle.reshape_dims = (your_input_shape,)
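The reshape trick can be sanity-checked in PyTorch first (a minimal sketch; the fresh BatchNorm1d with random running statistics stands in for your trained layer):

import torch
import torch.nn as nn

bn1d = nn.BatchNorm1d(16).eval()
bn1d.running_mean.uniform_(-1, 1)
bn1d.running_var.uniform_(0.5, 1.5)
x = torch.randn(4, 16)                      # (batch, features)

scale = bn1d.weight / torch.sqrt(bn1d.running_var + bn1d.eps)
shift = -bn1d.running_mean / torch.sqrt(bn1d.running_var + bn1d.eps) * bn1d.weight + bn1d.bias

# same trick as the TRT code: view the features as channels of a (C, 1, 1) tensor,
# apply the per-channel scale/shift, then view back
y = (x.view(4, 16, 1, 1) * scale.view(1, 16, 1, 1) + shift.view(1, 16, 1, 1)).view(4, 16)

with torch.no_grad():
    assert torch.allclose(bn1d(x), y, atol=1e-5)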

Implementing hswish in TRT

Refer to the PyTorch implementation of hswish:

class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out

The formula for relu6 is:

$$ReLU6(x) = \min(\max(0, x), 6)$$

From this we arrive at the following TRT implementation:

# x + 3
shape  = (1, ) * len(your_input_shape)
tensor = np.full(shape, 3.0, dtype=np.float32)
trt_3  = network.add_constant(shape, tensor)
tmp    = network.add_elementwise(last_layer.get_output(0), trt_3.get_output(0), trt.ElementWiseOperation.SUM)

# relu6(x + 3) = min(relu(x + 3), 6)
relu   = network.add_activation(input=tmp.get_output(0), type=trt.ActivationType.RELU)
tensor = np.full(shape, 6.0, dtype=np.float32)
trt_6  = network.add_constant(shape, tensor)
relu_6 = network.add_elementwise(relu.get_output(0), trt_6.get_output(0), trt.ElementWiseOperation.MIN)

# x * relu6(x + 3)
tmp    = network.add_elementwise(last_layer.get_output(0), relu_6.get_output(0), trt.ElementWiseOperation.PROD)

# x * relu6(x + 3) / 6
out    = network.add_elementwise(tmp.get_output(0), trt_6.get_output(0), trt.ElementWiseOperation.DIV)
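The decomposition above can be checked against PyTorch directly. This minimal sketch assumes a PyTorch version that provides F.hardswish, which computes the same x * relu6(x + 3) / 6:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
assert torch.allclose(F.hardswish(x), x * F.relu6(x + 3) / 6, atol=1e-6)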

Implementing the SiLU Operator in TRT

Method 1: Replace the SiLU operator in the source when converting to ONNX

When converting the model to ONNX, the following error appears:

RuntimeError: Exporting the operator silu to ONNX opset version 11 is not supported. Please open a bug to request ONNX export support for the missing operator.

The reason is that onnx 1.8.0 does not yet support the silu operator; a workaround is to modify the torch source. Source location: {your_python_path}/lib/python3.7/site-packages/torch/nn/modules/activation.py
The original code:

class SiLU(Module):

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(SiLU, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

F.silu替换掉,修改为:

class SiLU(Module):

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super(SiLU, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        # return F.silu(input, inplace=self.inplace)
        return input * torch.sigmoid(input)
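If you would rather not edit the installed torch source, the same substitution can be done inside your own export script. The following is a minimal sketch; SiLUImpl and replace_silu are hypothetical helper names, not part of torch:

import torch
import torch.nn as nn

class SiLUImpl(nn.Module):
    # Same math as nn.SiLU, expressed with ops that ONNX opset 11 can export.
    def forward(self, x):
        return x * torch.sigmoid(x)

def replace_silu(module: nn.Module) -> nn.Module:
    # Recursively swap every nn.SiLU in the model for SiLUImpl before torch.onnx.export.
    for name, child in module.named_children():
        if isinstance(child, nn.SiLU):
            setattr(module, name, SiLUImpl())
        else:
            replace_silu(child)
    return module

Call replace_silu(model) right before torch.onnx.export so the exported graph contains only Mul and Sigmoid nodes.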

Method 2: Build SiLU indirectly with the TRT API

The SiLU activation function

$$f(x) = x \cdot \sigma(x), \qquad f'(x) = f(x) + \sigma(x)\,(1 - f(x))$$
[Figure 3: SiLU vs. ReLU]

As the formula shows, SiLU is simply the sigmoid activation weighted by the input itself.

Likewise, TensorRT does not directly provide a SiLU API, but SiLU is easy to build with add_activation combined with the product operation of add_elementwise:

sig  = network.add_activation(bn1.get_output(0), trt.ActivationType.SIGMOID)
silu = network.add_elementwise(bn1.get_output(0), sig.get_output(0), trt.ElementWiseOperation.PROD)

BN and track_running_stats

For background on this topic, see "3. BN与track_running_stats" by 麦克斯韦恶魔 on CSDN.

