赞
踩
咱就是说,需要源码请造访:Jane的GitHub
:在这里
上午没写完的,下午继续,是一个小尾巴。其实上午把训练的关键部分和数据的关键部分都写完了,现在就是写一下推理部分
在推理过程中,为了提高效率、让速度更快,会先对模型做层融合(fuse):
# Well hidden: at first I did not expect the fusion entry point to live here.
def fuse(self):
    """Fuse Conv2d() + BatchNorm2d() layers of the model for inference.

    Walks every sub-module of ``self.model`` and switches it to its fused
    form in place, then prints the model summary via ``self.info()``.

    Returns:
        ``self``, so the call can be chained.
    """
    print('Fusing layers... ')
    for layer in self.model.modules():
        if isinstance(layer, RepConv):
            layer.fuse_repvgg_block()
            continue
        if isinstance(layer, RepConv_OREPA):
            layer.switch_to_deploy()
            continue
        # Exact type match on purpose: subclasses of Conv are not fused here.
        if type(layer) is Conv and hasattr(layer, 'bn'):
            layer.conv = fuse_conv_and_bn(layer.conv, layer.bn)  # replace conv with the fused conv
            del layer.bn  # batchnorm is now folded into the conv
            layer.forward = layer.fuseforward  # use the forward pass without bn
            continue
        if isinstance(layer, (IDetect, IAuxDetect)):
            layer.fuse()
            layer.forward = layer.fuseforward
    self.info()
    return self
这里的 conv 后面一定跟着一个 BN,所以可以把两者融合成一个卷积:
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
【bn.weight 就是以下公式中的 γ(gamma),σ² 是方差 bn.running_var,ε 是 bn.eps:
$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta$$
】
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
【bn.bias 是上面公式中的 β,μ 为均值 bn.running_mean】
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
def fuse_conv_and_bn(conv, bn):
    """Fuse a convolution and the batchnorm that follows it into one Conv2d.

    Reference: https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    In inference mode batchnorm is a per-channel affine map
    ``y = gamma * (x - mean) / sqrt(var + eps) + beta`` with
    ``gamma = bn.weight``, ``beta = bn.bias``, ``mean = bn.running_mean`` and
    ``var = bn.running_var`` (the running estimates stored on the layer, not
    statistics of the current batch). Folding it into the preceding conv
    gives ``W' = scale * W`` and ``b' = scale * b + (beta - scale * mean)``
    where ``scale = gamma / sqrt(var + eps)``.

    Args:
        conv: the ``nn.Conv2d`` layer (its bias may be ``None``).
        bn: the ``nn.BatchNorm2d`` layer applied right after ``conv``.

    Returns:
        A new ``nn.Conv2d`` with ``requires_grad`` disabled, on the same
        device as ``conv.weight``, computing ``bn(conv(x))`` for ``bn`` in
        eval mode.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,  # must be carried over, otherwise dilated convs break
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # no_grad keeps the in-place copies below from attaching the fused
    # parameters to the autograd graph of conv / bn.
    with torch.no_grad():
        # Per-output-channel scale gamma / sqrt(var + eps).
        scale = bn.weight.div(torch.sqrt(bn.running_var + bn.eps))

        # Prepare filters: scale every output channel's kernel. Broadcasting
        # replaces the dense diagonal-matrix product with the same result.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        # Prepare spatial bias: a missing conv bias is treated as zeros.
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        b_bn = bn.bias - scale * bn.running_mean
        fusedconv.bias.copy_(scale * b_conv + b_bn)

    return fusedconv
把Repvgg中的卷积和BN合在一起
padding后
所有的都改变以后:model长这样——>
OK,这次真没啦,886~~~~
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。