
BN Layer Study Notes: Fusing Convolution and BN Layers


Contents

Detailed explanation of convolution-layer and BN-layer parameters in CNNs

BN layer parameters

Why no bias b when a convolution layer is followed by a batch normalization layer

Fusing convolution and BN layers

Merge code 1:

Merge code 2:


Detailed explanation of convolution-layer and BN-layer parameters in CNNs

Source: "Detailed explanation of convolution-layer and BN-layer parameters in CNNs" (CSDN blog)

BN layer parameters

In PyTorch, batch normalization (BN) layers are implemented by the classes torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, etc., for inputs of different dimensionality. These BN layers expose several parameters and hyperparameters that control their behavior and performance during training. The main ones are:

num_features: the number of input features. For BatchNorm1d this is the feature dimension; for BatchNorm2d it is the number of feature maps, i.e. the channel count of the input tensor.

eps: a small value added to the variance (inside the square root) to avoid division by zero. The default is typically 1e-5.

momentum: the value used when updating the running (moving) mean and variance; it balances the weight given to the new batch against the accumulated history. The default is 0.1.

affine: a boolean specifying whether a learnable affine transform is applied to the normalized output (i.e. multiply by "gamma" (weight) and add "beta" (bias)). Defaults to True.

track_running_stats: a boolean specifying whether running mean and variance are tracked over the training set. In training mode these statistics are updated; in evaluation mode they are used for normalization. Defaults to True.

Beyond these constructor arguments, a BatchNorm layer automatically learns two important parameters during training:

weight (gamma): the scaling parameter for the normalized values, learned only when affine=True.
bias (beta): the offset parameter for the normalized values, learned only when affine=True.
Tuning these parameters and hyperparameters affects the model's learning ability and final performance. For example, in PyTorch's convention the update is running = (1 - momentum) * running + momentum * batch, so a larger momentum makes the running mean and variance more sensitive to each new batch, while a smaller value weights the history more heavily, giving statistics that are more stable but slower to adapt to new data. Adjusting eps helps avoid numerical-stability problems, especially in very deep networks or with small batch sizes. The affine and track_running_stats options let you control a batch-normalization layer's behavior to suit specific training or evaluation needs.
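
A minimal sketch of these knobs in action (the layer and input sizes here are arbitrary):

```python
import torch
import torch.nn as nn

# BatchNorm2d over 16 channels, with the default hyperparameters spelled out
bn = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1,
                    affine=True, track_running_stats=True)

x = torch.randn(4, 16, 8, 8)

bn.train()
y = bn(x)                   # normalizes with batch stats, updates running stats
print(bn.running_mean[:3])  # nudged toward the batch mean by `momentum`

bn.eval()
y = bn(x)                   # normalizes with the stored running stats instead

# the learned affine parameters (present because affine=True)
print(bn.weight.shape, bn.bias.shape)  # gamma: [16], beta: [16]
```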

Why no bias b when a convolution layer is followed by a batch normalization layer

Source: "Why no bias b when a convolution layer is followed by a batch normalization layer" (CSDN blog)
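
The short version: BN subtracts the per-channel mean, so any constant bias added by the preceding convolution is cancelled immediately, and its role is taken over by BN's own learnable beta. A quick numerical check (a sketch; sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 16, 16)

conv = nn.Conv2d(3, 4, 3, bias=True)
bn = nn.BatchNorm2d(4).train()

y1 = bn(conv(x))  # BN output with the conv bias in place

# zero the bias: the BN output is unchanged, because the per-channel
# mean subtraction removes any constant offset the bias adds
with torch.no_grad():
    conv.bias.zero_()
y2 = bn(conv(x))

print(torch.allclose(y1, y2, atol=1e-6))  # True
```

So a bias before BN only wastes parameters and computation.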

Fusing convolution and BN layers

Another good explanation:

"Fusing BN at deep-learning inference time for an easy ~5% speedup" - osc_s7aj86hu - OSCHINA (Chinese open-source community)

I asked a PhD colleague about edge cases: grouped convolutions can be merged, but for depthwise (per-channel) convolution, where BN is also per-channel, can they still be merged? (In fact they can: BN's scale factor is per output channel, and scaling dimension 0 of the conv weight is valid for any groups setting, including depthwise.)

1.  Why merge the BN layer

When training a deep network, a BN (Batch Normalization) layer speeds up convergence and helps control overfitting; it is usually placed right after a convolution layer. By normalizing the data, BN effectively mitigates the vanishing- and exploding-gradient problems. But while BN helps during training, it adds extra per-layer computation during forward inference, hurting performance and occupying extra memory (or GPU memory). Many state-of-the-art models (ResNet, MobileNet, Xception, ShuffleNet, etc.) use BN, so folding the BN parameters into the convolution layer is a worthwhile way to speed up forward inference.
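
Recent PyTorch versions ship a helper that performs exactly this fold on an eval-mode pair. A minimal sketch, assuming your build exposes torch.nn.utils.fusion.fuse_conv_bn_eval (if it does not, the manual formulas in the next section do the same thing):

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()

fused = fuse_conv_bn_eval(conv, bn)  # a single Conv2d with BN folded in

x = torch.randn(1, 3, 32, 32)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```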

2.  The math behind merging BN into a convolution

For each output channel, let the convolution compute $y = Wx + b$, and let the BN layer have learned scale $\gamma$ and shift $\beta$ with running mean $\mu$ and running variance $\sigma^2$. In inference mode BN computes

$$\hat{y} = \gamma \cdot \frac{y - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta$$

Substituting the convolution, we have:

$$\hat{y} = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,(Wx + b - \mu) + \beta$$

After merging, this is a single convolution $\hat{y} = W'x + b'$ with:

$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,W, \qquad b' = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,(b - \mu) + \beta$$
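
A quick numerical check of these formulas on one eval-mode pair (a sketch; sizes are arbitrary, and the BN statistics are randomized so the test is non-trivial):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, 3, bias=True).eval()
bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)  # non-trivial statistics
    bn.running_var.uniform_(0.5, 2.0)

# W' = gamma/sqrt(var+eps) * W,  b' = gamma/sqrt(var+eps) * (b - mean) + beta
t = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, 3, bias=True).eval()
with torch.no_grad():
    fused.weight.copy_(conv.weight * t.view(-1, 1, 1, 1))
    fused.bias.copy_(t * (conv.bias - bn.running_mean) + bn.bias)

x = torch.randn(1, 3, 32, 32)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```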

3.  Experimental results

Machine: GTX 1080Ti GPU, i7 CPU

This experiment compares the ResNet-50 model before and after merging the BN layers: classification accuracy is unchanged, and speed improves markedly.

| Model | CPU forward time | GPU forward time |
|---|---|---|
| ResNet-50 (before merge) | 176.17 ms | 11.03 ms |
| ResNet-50 (after merge) | 161.69 ms | 7.3 ms |
| Improvement | 8.96% | 33.27% |
————————————————
Copyright notice: the quoted material above is an original article by CSDN blogger 小麦草, licensed under the CC 4.0 BY-SA license; reproduction must include the original source link and this notice.
Original link: https://blog.csdn.net/kangdi7547/article/details/81348254

A packaged implementation of the same idea: https://github.com/vietnamican/conv-bn-merge/blob/main/convbnmerge/convbnmerge.py

Merge code 1:

Source: "Fusing BN and CONV layers in PyTorch (merge_bn)" (CSDN blog)

```python
import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms

""" Parameters and variables """
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
LABEL = '/home/zym/ImageNet/synset.txt'
TEST_ITER = 10
SAVE = False
TEST_AFTER_MERGE = True

""" Functions """
def merge(params, name, layer):
    # global state shared across calls while walking the state dict
    global weights, bias
    global bn_param
    if layer == 'Convolution':
        # save weights and bias when we meet a conv layer
        if 'weight' in name:
            weights = params.data
            bias = torch.zeros(weights.size()[0])
        elif 'bias' in name:
            bias = params.data
        bn_param = {}
    elif layer == 'BatchNorm':
        # collect bn params keyed by their short name
        # (weight / bias / running_mean / running_var)
        bn_param[name.split('.')[-1]] = params.data
        # running_var is the last bn param in a pytorch state dict
        if 'running_var' in name:
            # merge bn into the preceding conv:
            # W' = gamma/sqrt(var+eps) * W,  b' = gamma/sqrt(var+eps) * (b - mean) + beta
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']
            return weights, bias
    return None, None

""" Main """
# import pytorch model (a ShuffleNetV2 variant with the BN layers already removed)
import models.shufflenetv2.shufflenetv2_merge as shufflenetv2
pytorch_net = shufflenetv2.ShuffleNetV2().eval()
model_path = shufflenetv2.weight_file

# load weights
print('Finding trained model weights...')
try:
    for file in os.listdir(model_path):
        if 'pth' in file:
            print('Loading weights from %s ...' % file)
            trained_weights = torch.load(os.path.join(model_path, file))
            # pytorch_net.load_state_dict(trained_weights)
            print('Weights load success')
            break
except Exception:
    raise ValueError('No trained model found or loading error occurs')

# go through pytorch net
print('Going through pytorch net weights...')
new_weights = OrderedDict()
inner_product_flag = False
for name, params in trained_weights.items():
    if 'num_batches_tracked' in name:
        # BN bookkeeping tensor (present in checkpoints from PyTorch >= 0.4.1);
        # skip it, or it would be misread as an inner-product param below
        continue
    if len(params.size()) == 4:
        _, _ = merge(params, name, 'Convolution')
        prev_layer = name
    elif len(params.size()) == 1 and not inner_product_flag:
        w, b = merge(params, name, 'BatchNorm')
        if w is not None:
            new_weights[prev_layer] = w
            new_weights[prev_layer.replace('weight', 'bias')] = b
    else:
        # inner product layer: once we meet one, later 1-D bias tensors
        # must not be misclassified as BatchNorm params
        new_weights[name] = params
        inner_product_flag = True

# align names in new_weights with the pytorch model:
# after removing the BatchNorm layers from the model definition,
# layer names in the old and new state dicts no longer match
print('Aligning weight names...')
pytorch_net_key_list = list(pytorch_net.state_dict().keys())
new_weights_key_list = list(new_weights.keys())
assert len(pytorch_net_key_list) == len(new_weights_key_list)
for index in range(len(pytorch_net_key_list)):
    new_weights[pytorch_net_key_list[index]] = new_weights.pop(new_weights_key_list[index])

# save new weights
if SAVE:
    torch.save(new_weights, model_path + '/' + file.replace('.pth', '_merged.pth'))

# test merged pytorch model
if TEST_AFTER_MERGE:
    try:
        pytorch_net.load_state_dict(new_weights)
        print('Pytorch net load weights success~')
    except Exception:
        raise ValueError('Load new weights error')
    print('-' * 50)
    with open(LABEL) as f:
        labels = f.read().splitlines()
    with open(IMAGENET) as f:
        images = f.read().splitlines()
    for _ in range(TEST_ITER):
        image_path, label = images[np.random.randint(0, len(images))].split(' ')
        # note: cv2 loads images as BGR, while the ImageNet mean/std below are in
        # RGB order; depending on how the model was trained you may want
        # cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) here
        input_image = cv2.imread(image_path)
        input_image = cv2.resize(input_image, (224, 224))
        input_image = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])(input_image)
        input_image = input_image.view(1, 3, 224, 224)
        output_logits = pytorch_net(input_image)
        _, index = output_logits.max(dim=1)
        print('true label: \t%s' % labels[int(label)])
        print('predict label:\t%s' % labels[int(index)])
        print('-' * 50)
```
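
Note: this script assumes every 1-D tensor that follows a conv weight belongs to a BN layer, and it handles neither transposed convolutions nor a final conv that has no BN after it. The second script below covers both cases.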

Merge code 2:

https://github.com/owphoo/pytorch_merge_bn/blob/master/pytorch_merge_bn.py

```python
import torch
import os
import sys
from collections import OrderedDict
import numpy as np

merged = True

def merge(params, name, layer, deconv_layer_names=['deconv']):
    # global state shared across calls while walking the state dict
    global weights, bias
    global bn_param
    global merged
    is_deconv = False
    for deconv_name in deconv_layer_names:
        if deconv_name in name:
            is_deconv = True
            break
    if layer == 'Convolution':
        # save weights and bias when we meet a conv/deconv layer
        if 'weight' in name:
            weights = params.data
            if is_deconv:
                # for ConvTranspose2d the output channels live on dimension 1
                bias = torch.zeros(weights.size()[1], device=weights.device)
            else:
                bias = torch.zeros(weights.size()[0], device=weights.device)
            merged = False
        elif 'bias' in name:
            bias = params.data
        bn_param = {}
    elif layer == 'BatchNorm':
        # collect bn params keyed by their short name
        bn_param[name.split('.')[-1]] = params.data
        # running_var is the last bn param in a pytorch state dict
        if 'running_var' in name:
            # merge bn
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            if is_deconv:
                # scale along the output-channel dimension (dim 1 for deconv)
                weights = (tmp.view(tmp.size()[0], 1, 1, 1) * weights.permute(1, 0, 2, 3)).permute(1, 0, 2, 3)
            else:
                weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']
            return weights, bias
    return None, None

if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage: python pytorch_merge_bn.py YOUR_MODEL')
        sys.exit(-1)
    model_path = sys.argv[1]
    print('input model: ', model_path)
    checkpoint = torch.load(model_path)
    trained_weights = checkpoint['net_state_dict']
    '''
    ## conv_bn_relu module
    # NAME | SIZE
    # conv4.0.weight          torch.Size([128, 256, 3, 3])
    # conv4.1.weight          torch.Size([256])
    # conv4.1.bias            torch.Size([256])
    # conv4.1.running_mean    torch.Size([256])
    # conv4.1.running_var     torch.Size([256])
    ## deconv_bn_relu module
    # NAME | SIZE
    # deconv4.0.weight        torch.Size([256, 128, 4, 4])
    # deconv4.1.weight        torch.Size([128])
    # deconv4.1.bias          torch.Size([128])
    # deconv4.1.running_mean  torch.Size([128])
    # deconv4.1.running_var   torch.Size([128])
    '''
    # check these names against the modules in your own net
    deconv_layer_names = ['deconv4', 'deconv3', 'deconv2', 'deconv1']
    temp = []
    for deconv_name in deconv_layer_names:
        temp.append(deconv_name + '.0')
        temp.append(deconv_name + '.1')
    deconv_layer_names = temp
    # go through pytorch net
    new_weights = OrderedDict()
    inner_product_flag = False
    for name, params in trained_weights.items():
        print('name: ', name, params.size())
        if len(params.size()) == 4:
            _, _ = merge(params, name, 'Convolution', deconv_layer_names=deconv_layer_names)
            prev_layer = name
        elif len(params.size()) == 1 and not inner_product_flag:
            w, b = merge(params, name, 'BatchNorm', deconv_layer_names=deconv_layer_names)
            if w is not None:
                new_weights[prev_layer] = w
                new_weights[prev_layer.replace('weight', 'bias')] = b
                merged = True
        else:
            # inner product layer (TODO: an inner product layer may have its own bn module)
            if name.find('num_batches_tracked') == -1:
                new_weights[name] = params
                inner_product_flag = True
    # handle the last conv/deconv if it has no bn module after it
    if merged is False:
        new_weights[prev_layer] = weights
        new_weights[prev_layer.replace('weight', 'bias')] = bias
    checkpoint['net_state_dict'] = new_weights
    # save new weights next to the input model
    model_dir, model_name = os.path.split(model_path)
    if model_dir == '':
        model_dir = '.'
    torch.save(checkpoint, os.path.join(model_dir, 'merge_bn_' + model_name))
```
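
Usage: python pytorch_merge_bn.py your_model.pth. The script reads checkpoint['net_state_dict'], folds each BN layer into the preceding conv/deconv, and writes the fused checkpoint next to the input file as merge_bn_<model name>.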
