
BN in CNNs (an annotated-code walkthrough)

Reference (a very good write-up): https://www.cnblogs.com/adong7639/p/9145911.html

'''
This post walks through batch normalization (BN) as it is used in CNNs.
'''
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, dim, pretrained=False):
        super(Net, self).__init__()
        # Defaults: eps=1e-5, momentum=0.1. The experiments discussed below
        # were re-run with momentum=1, momentum=0 and the default 0.1.
        self.bn = nn.BatchNorm2d(dim)
        if pretrained:
            self.init_pretrained()

    def forward(self, input):
        return self.bn(input)

    def init_pretrained(self):
        # gamma (bn.weight) initialized to 1, beta (bn.bias) to 0
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)


def train():
    dim = 3
    model = Net(dim)
    # Count and inspect the learnable parameters of the BN layer.
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
    for p in model.parameters():
        print(p, p.requires_grad)
    '''
    For a BatchNorm2d layer over a 3-channel feature map, there are 6
    learnable parameters: one gamma and one beta per channel (3 + 3).
    During training, gamma and beta are the parameters updated by backprop.
    Sample output (bn.weight was randomly initialized in the older PyTorch
    version used here; recent versions initialize it to all ones):
    6
    Parameter containing:
    tensor([0.2322, 0.9405, 0.9887], requires_grad=True) True
    Parameter containing:
    tensor([0., 0., 0.], requires_grad=True) True
    '''
    # model.eval()
    feature_map = torch.randn((2, 3, 2, 2))
    output1 = model(feature_map)
    state_dict = model.state_dict()
    for k, v in state_dict.items():
        print(k, v)
    '''
    bn.weight tensor([0.2860, 0.5986, 0.0594])
    bn.bias tensor([0., 0., 0.])
    bn.running_mean tensor([-0.2098,  0.1876, -0.3045])
    bn.running_var tensor([0.8099, 1.5140, 0.5880])
    bn.num_batches_tracked tensor(1)
    Printing the state dict shows that the BN layer holds five entries.
    bn.weight corresponds to gamma in the paper and bn.bias to beta.
    bn.running_mean and bn.running_var are the running estimates of the
    per-channel mean and variance (this run was made with momentum=1, so
    they equal the statistics of the current batch; see below), and
    bn.num_batches_tracked counts the batches seen so far.
    '''
    print('bn.running_mean', state_dict['bn.running_mean'])
    print('bn.running_var', state_dict['bn.running_var'])
    # Recompute the per-channel statistics by hand for comparison:
    print(torch.mean(feature_map.permute(1, 0, 2, 3).contiguous().view(dim, -1), 1))
    print(torch.var(feature_map.permute(1, 0, 2, 3).contiguous().view(dim, -1), 1))
    '''
    bn.running_mean tensor([-0.2098,  0.1876, -0.3045])
    bn.running_var tensor([0.8099, 1.5140, 0.5880])
    tensor([-0.2098,  0.1876, -0.3045])
    tensor([0.8099, 1.5140, 0.5880])
    They match exactly, but only because the BN layer was built with
    momentum=1, i.e. the current running statistics (running_mean and
    running_var) are determined entirely by the current batch. The update
    rule is:
        stat_t = (1 - momentum) * stat_{t-1} + momentum * batch_stat_t
    momentum controls how bn.running_mean and bn.running_var evolve:
    (1) With momentum=1, the values come entirely from the statistics
        computed on the current batch.
    (2) The previous statistics stat_{t-1} are whatever the buffers hold;
        since no parameter updates or training iterations happen here, they
        are the initial values
        bn.running_mean tensor([0., 0., 0.])
        bn.running_var tensor([1., 1., 1.])
        (in a trained or pretrained model they would generally not be
        0 0 0 / 1 1 1). Hence with momentum=0 the model keeps
        bn.running_mean tensor([0., 0., 0.])
        bn.running_var tensor([1., 1., 1.]) forever.
    (3) With the default momentum=0.1:
        bn.running_mean tensor([0.0233, 0.0166, 0.0469])
        bn.running_var tensor([0.9961, 1.0899, 0.9974])
        tensor([0.2329, 0.1663, 0.4691])  <- batch statistics computed by hand
        tensor([0.9615, 1.8986, 0.9738])
        bn.running_mean is exactly 0.1x the batch mean (0.9 * 0 + 0.1 * mean),
        and bn.running_var is 0.9 * 1 + 0.1 * the batch variance.
    To recap how BN is computed: for CNN inputs (i.e. the BN input is
    4-dimensional), normalization runs over the batch, H and W dimensions;
    this is also called spatial batch normalization. A numerical check of the
    momentum=0.1 update follows below.
    '''
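    # A small numerical sketch (my addition): verify the momentum=0.1 update
    # rule starting from the initial buffers (mean 0, var 1). Note that
    # feature_map.mean(dim=(0, 2, 3)) is equivalent to the permute/view
    # trick used above.
    momentum = 0.1
    batch_mean = feature_map.mean(dim=(0, 2, 3))
    batch_var = feature_map.permute(1, 0, 2, 3).contiguous().view(dim, -1).var(1)
    print((1 - momentum) * torch.zeros(dim) + momentum * batch_mean)
    print((1 - momentum) * torch.ones(dim) + momentum * batch_var)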


if __name__ == '__main__':
    '''
    In a BN layer, bn.bias is initialized to all zeros; bn.weight was
    randomly initialized in older PyTorch versions (recent versions
    initialize it to all ones). Given the values of an input feature map and
    the corresponding batch normalization parameters, we now compute the BN
    output by hand. With the default momentum=0.1:
        stat_t = 0.9 * stat_{t-1} + 0.1 * batch_stat_t
    '''
    dim = 3
    momentum = 0.1
    model = Net(dim, True)
    input = torch.randn((2, 3, 2, 2))
    output1 = model(input)
    def bn_simple_train(input, model):
        '''
        Re-implements a training-mode BN forward pass by hand.
        :param input: CNN feature map, shape [batch_size, C, H, W]
        :return: the normalized, scaled and shifted feature map
        '''
        flat = input.permute(1, 0, 2, 3).contiguous().view(dim, -1)
        mean = torch.mean(flat, 1)  # per-channel mean, shape [dim]
        var = torch.var(flat, 1)    # per-channel unbiased variance, shape [dim]
        init_mean = torch.zeros(dim)
        init_var = torch.ones(dim)
        # Moving averages, computed during training and stored for test time.
        run_mean = (1 - momentum) * init_mean + momentum * mean
        run_var = (1 - momentum) * init_var + momentum * var
        print('hand-computed running stats:', run_mean, run_var)
        # NOTE: in training mode BN normalizes with the *current batch*
        # statistics, using the biased variance; the running buffers above
        # are only consulted at test time.
        std = torch.sqrt(torch.var(flat, 1, unbiased=False) + 1e-5)
        mean_exp = mean.view(1, input.shape[1], 1, 1).expand(input.shape)
        std_exp = std.view(1, input.shape[1], 1, 1).expand(input.shape)
        '''
        Broadcasting the per-channel tensors to the feature-map shape also
        took me a while. Given tensor1 = torch.tensor([1, 2, 3]), we want a
        2x3x2x2 tensor2 such that
        tensor2[:, 0, :, :] = 1
        tensor2[:, 1, :, :] = 2
        tensor2[:, 2, :, :] = 3
        Besides a for loop, built-in ops can do this (see the sketch after
        this function): first unsqueeze/view to shape (1, 3, 1, 1), then
        expand(2, 3, 2, 2). expand can only replicate along existing size-1
        dimensions and cannot add new ones, so the tensor must first be
        brought to 4 dimensions; dimensions whose size already matches the
        target are left unchanged, and the size-1 dimensions are replicated.
        '''
        # Equivalent for-loop version:
        # mean_exp = torch.zeros((2, 3, 2, 2))
        # for i in range(3):
        #     mean_exp[:, i, :, :] = mean[i]
        # std_exp = torch.zeros((2, 3, 2, 2))
        # for i in range(3):
        #     std_exp[:, i, :, :] = std[i]
        output2 = (input - mean_exp) / std_exp
        init_weights = model.state_dict()['bn.weight']  # gamma
        init_bias = model.state_dict()['bn.bias']       # beta
        init_weights_exp = init_weights.view(1, input.shape[1], 1, 1).expand(input.shape)
        init_bias_exp = init_bias.view(1, input.shape[1], 1, 1).expand(input.shape)
        # gamma and beta are the learnable parameters that keep being updated
        # (via backpropagation) throughout training.
        output2 = output2 * init_weights_exp + init_bias_exp
        return output2
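    # Sketch referenced in the docstring above (my addition): building the
    # 2x3x2x2 tensor from tensor1 without a for loop.
    tensor1 = torch.tensor([1., 2., 3.])
    tensor2 = tensor1.view(1, 3, 1, 1).expand(2, 3, 2, 2)
    assert torch.all(tensor2[:, 1, :, :] == 2.)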
    def bn_for_test(input, model):
        '''
        At test time, the BN layer's running mean and running var are fixed
        values rather than statistics of the new validation data; under
        model.eval() these two buffers are frozen, and gamma and beta do not
        change either.
        :param input: CNN feature map, shape [batch_size, C, H, W]
        :param model: model whose BN parameters and buffers are used
        :return: the normalized, scaled and shifted feature map
        '''
        state_dict = model.state_dict()
        init_weights = state_dict['bn.weight']
        init_bias = state_dict['bn.bias']
        running_mean = state_dict['bn.running_mean']
        running_var = state_dict['bn.running_var']
        mean = running_mean.view(1, input.shape[1], 1, 1).expand(input.shape)
        var = running_var.view(1, input.shape[1], 1, 1).expand(input.shape)
        weights = init_weights.view(1, input.shape[1], 1, 1).expand(input.shape)
        bias = init_bias.view(1, input.shape[1], 1, 1).expand(input.shape)
        output = (input - mean) / torch.sqrt(var + 1e-5)
        output = output * weights + bias
        return output
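    # Quick checks (a sketch; the atol tolerance is my choice).
    # Training mode: the hand-rolled version should reproduce the layer output.
    output2 = bn_simple_train(input, model)
    print(torch.allclose(output1, output2, atol=1e-5))  # expected: True
    # Test mode: model.eval() freezes the running buffers; on fresh data the
    # hand-rolled test-time version should match the layer as well.
    model.eval()
    new_input = torch.randn((2, 3, 2, 2))
    output3 = model(new_input)
    output4 = bn_for_test(new_input, model)
    print(torch.allclose(output3, output4, atol=1e-5))  # expected: True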