
Model Compression (2): yolov5 Pruning

yolov5 pruning

I. yolov5s

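In the yolov5s.yaml file:

depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

These set the network depth (number of residual blocks) and width (number of channels) as ratios of the standard model.

In the standard backbone, the C3 number values are 3, 6, 9 and 3.

In the yolov5s backbone, they become 1, 2, 3 and 1 (depth_multiple * number, rounded).

Likewise, the network width is width_multiple * args[0].

The head is scaled in the same way.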
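The scaling described above can be sketched in a few lines of plain Python. This is a minimal sketch, assuming the rounding rules used when yolov5 parses the yaml (repeat counts are rounded with a floor of 1, channel counts are rounded up to a multiple of 8). The helper names `scale_depth` and `scale_width` are made up for this illustration.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def scale_depth(number, depth_multiple):
    """Scale a module's repeat count; a count of 1 is left unchanged."""
    return max(round(number * depth_multiple), 1) if number > 1 else number


def scale_width(channels, width_multiple):
    """Scale an output channel count and align it to a multiple of 8."""
    return make_divisible(channels * width_multiple, 8)


depth_multiple, width_multiple = 0.33, 0.50
standard_c3_numbers = [3, 6, 9, 3]                # C3 repeats in the standard backbone
standard_channels = [64, 128, 256, 512, 1024]     # args[0] values in the yaml

s_depths = [scale_depth(n, depth_multiple) for n in standard_c3_numbers]
s_widths = [scale_width(c, width_multiple) for c in standard_channels]
print(s_depths)  # [1, 2, 3, 1]
print(s_widths)  # [32, 64, 128, 256, 512]
```

The resulting widths match the shapes in the parameter listing that follows, for example 32 output channels for model.0 and 64 for model.1.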

    -------------------------------------0-P1/2----------------------------------------------
    model.0.conv.weight --------- torch.Size([32, 3, 6, 6])
    model.0.bn.weight --------- torch.Size([32])
    model.0.bn.bias --------- torch.Size([32])
    -------------------------------------1-P2/4----------------------------------------------
    model.1.conv.weight --------- torch.Size([64, 32, 3, 3])
    model.1.bn.weight --------- torch.Size([64])
    model.1.bn.bias --------- torch.Size([64])
    -------------------------------------C3----------------------------------------------
    **cv1**
    model.2.cv1.conv.weight --------- torch.Size([32, 64, 1, 1])
    model.2.cv1.bn.weight --------- torch.Size([32]) ***
    model.2.cv1.bn.bias --------- torch.Size([32]) ***
    **cv2**
    model.2.cv2.conv.weight --------- torch.Size([32, 64, 1, 1])
    model.2.cv2.bn.weight --------- torch.Size([32])
    model.2.cv2.bn.bias --------- torch.Size([32])
    **cv3**
    model.2.cv3.conv.weight --------- torch.Size([64, 64, 1, 1])
    model.2.cv3.bn.weight --------- torch.Size([64])
    model.2.cv3.bn.bias --------- torch.Size([64])
    bneck: *1
    model.2.m.0.cv1.conv.weight --------- torch.Size([32, 32, 1, 1])
    model.2.m.0.cv1.bn.weight --------- torch.Size([32]) ***
    model.2.m.0.cv1.bn.bias --------- torch.Size([32]) ***
    model.2.m.0.cv2.conv.weight --------- torch.Size([32, 32, 3, 3])
    model.2.m.0.cv2.bn.weight --------- torch.Size([32]) ***
    model.2.m.0.cv2.bn.bias --------- torch.Size([32]) ***
    -------------------------------------3-P3/8----------------------------------------------
    model.3.conv.weight --------- torch.Size([128, 64, 3, 3])
    model.3.bn.weight --------- torch.Size([128])
    model.3.bn.bias --------- torch.Size([128])
    -------------------------------------C3----------------------------------------------
    **cv1**
    model.4.cv1.conv.weight --------- torch.Size([64, 128, 1, 1])
    model.4.cv1.bn.weight --------- torch.Size([64]) ***
    model.4.cv1.bn.bias --------- torch.Size([64]) ***
    **cv2**
    model.4.cv2.conv.weight --------- torch.Size([64, 128, 1, 1])
    model.4.cv2.bn.weight --------- torch.Size([64])
    model.4.cv2.bn.bias --------- torch.Size([64])
    **cv3**
    model.4.cv3.conv.weight --------- torch.Size([128, 128, 1, 1])
    model.4.cv3.bn.weight --------- torch.Size([128])
    model.4.cv3.bn.bias --------- torch.Size([128])
    **bneck1**
    model.4.m.0.cv1.conv.weight --------- torch.Size([64, 64, 1, 1])
    model.4.m.0.cv1.bn.weight --------- torch.Size([64])
    model.4.m.0.cv1.bn.bias --------- torch.Size([64])
    model.4.m.0.cv2.conv.weight --------- torch.Size([64, 64, 3, 3])
    model.4.m.0.cv2.bn.weight --------- torch.Size([64])
    model.4.m.0.cv2.bn.bias --------- torch.Size([64])
    **bneck2**
    model.4.m.1.cv1.conv.weight --------- torch.Size([64, 64, 1, 1])
    model.4.m.1.cv1.bn.weight --------- torch.Size([64])
    model.4.m.1.cv1.bn.bias --------- torch.Size([64])
    model.4.m.1.cv2.conv.weight --------- torch.Size([64, 64, 3, 3])
    model.4.m.1.cv2.bn.weight --------- torch.Size([64])
    model.4.m.1.cv2.bn.bias --------- torch.Size([64])
    -------------------------------------5-P4/16----------------------------------------------
    model.5.conv.weight --------- torch.Size([256, 128, 3, 3])
    model.5.bn.weight --------- torch.Size([256])
    model.5.bn.bias --------- torch.Size([256])
    ......

II. The C3 module

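This article applies channel pruning to yolov5s, again using sparsity in the BN layers to decide what to prune. The yolov5s structure contains both shortcut and cat operations, where a main path and a branch are merged. A shortcut adds the features of an earlier layer to those of a later layer, while cat concatenates along the channel dimension. A shortcut only works if both layers have the same number of channels. If the layers on either side of a shortcut were pruned, their channel counts could no longer be guaranteed to match, so the convolution layers involved in a shortcut must be excluded from pruning. cat imposes no such constraint.

The Bottleneck structure inside the yolov5s C3 module contains a shortcut. To avoid channel mismatches after BN sparsification, none of the residual structures are pruned.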
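The difference between the two merge operations can be illustrated without any tensors. The following is a toy sketch, assuming each layer's pruning result is a 0/1 mask over its output channels. The helper names are invented for this example.

```python
def kept_channels(mask):
    """Indices of the channels that survive pruning."""
    return [i for i, keep in enumerate(mask) if keep]


def can_add(mask_a, mask_b):
    """Element-wise add needs both inputs to keep exactly the same channels."""
    return kept_channels(mask_a) == kept_channels(mask_b)


def cat_channels(mask_a, mask_b):
    """Concatenation simply stacks whatever channels each branch kept."""
    return len(kept_channels(mask_a)) + len(kept_channels(mask_b))


shortcut_input = [1, 1, 1, 1]   # layer feeding the shortcut, left unpruned
pruned_branch = [1, 0, 1, 1]    # residual branch after pruning one channel

print(can_add(shortcut_input, pruned_branch))       # False: 4 channels vs 3
print(cat_channels(shortcut_input, pruned_branch))  # 7: cat still works
```

This is why layers around a shortcut are kept at full width, while layers that only feed a cat can be pruned independently.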

C3

    import torch
    import torch.nn as nn

    # Conv (convolution + BN + activation) is defined next to these classes in yolov5's models/common.py.

    class Bottleneck(nn.Module):
        # Standard bottleneck
        def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
            super().__init__()
            c_ = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)
            self.cv2 = Conv(c_, c2, 3, 1, g=g)
            self.add = shortcut and c1 == c2  # add directly only when the channel counts match

        def forward(self, x):
            return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


    class C3(nn.Module):
        # CSP Bottleneck with 3 convolutions
        def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
            super().__init__()
            c_ = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)  # main path, feeds the Bottlenecks
            self.cv2 = Conv(c1, c_, 1, 1)  # branch
            self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
            self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
            # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

        def forward(self, x):
            return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

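Figure: C3 structure

In the C3 structure, therefore, cv2 and cv3 take part in pruning. cv1 feeds the Bottleneck shortcut, so when the shortcut is enabled the code below excludes it together with the two convolutions inside each Bottleneck.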

III. Pruning procedure

1. Sparsity training

Exclude the convolution layers in the C3 structure that do not take part in pruning.

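    # ------------------------------- sparsity training -------------------------------
    # Runs after loss.backward(); torch, torch.nn as nn and Bottleneck must be imported.
    srtmp = opt.sr * (1 - 0.9 * epoch / epochs)  # penalty coefficient decays with the epoch
    if opt.st:
        ignore_bn_list = []
        # First pass: record the BN layers tied to a shortcut (add), three per Bottleneck:
        # the first conv of the enclosing C3 and both convs inside the Bottleneck.
        for k, m in model.named_modules():
            if isinstance(m, Bottleneck) and m.add:
                ignore_bn_list.append(k.rsplit('.', 2)[0] + '.cv1.bn')  # cv1 of the enclosing C3
                ignore_bn_list.append(k + '.cv1.bn')
                ignore_bn_list.append(k + '.cv2.bn')
        # Second pass: add the L1 subgradient to every remaining BN layer.
        # named_modules() yields a C3's cv1.bn before its Bottleneck, hence the two passes.
        for k, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d) and k not in ignore_bn_list:
                m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))
                m.bias.grad.data.add_(opt.sr * 10 * torch.sign(m.bias.data))
        print(ignore_bn_list)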
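To make the penalty term concrete, here is a small stand-alone sketch in plain Python (no torch). It mirrors the coefficient formula from the listing and shows how the L1 subgradient, coeff * sign(gamma), is added to an existing gradient. The function names and the numeric values are illustrative assumptions, not values from the original training runs.

```python
def sign(x):
    """Sign of x as -1, 0 or 1 (same convention as torch.sign)."""
    return (x > 0) - (x < 0)


def sparsity_coeff(sr, epoch, epochs):
    """Penalty coefficient that decays linearly from sr towards 0.1 * sr."""
    return sr * (1 - 0.9 * epoch / epochs)


def add_l1_subgradient(grads, weights, coeff):
    """Return grads + coeff * sign(weights), the subgradient of coeff * |w|."""
    return [g + coeff * sign(w) for g, w in zip(grads, weights)]


sr, epochs = 0.001, 100          # illustrative settings
gammas = [0.8, -0.2, 0.0]        # BN scale factors of one layer
grads = [0.05, 0.05, 0.05]       # gradients from the detection loss

first = add_l1_subgradient(grads, gammas, sparsity_coeff(sr, 0, epochs))
last = add_l1_subgradient(grads, gammas, sparsity_coeff(sr, epochs, epochs))
print(first)
print(last)
```

Positive gammas are pushed down, negative gammas are pushed up, and a gamma of exactly zero receives no extra gradient, so the penalty drives scale factors towards zero.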

2. Pruning

There are two variants: regular pruning and normal pruning.

Normal pruning

BN layers to prune:

    bn_layers = {}
    ignore_bn_layers = []
    for layer_name, layer_model in model.named_modules():
        if isinstance(layer_model, Bottleneck):
            if layer_model.add:
                ignore_bn_layers.append(layer_name.rsplit('.', 2)[0] + '.cv1.bn')  # first conv of the C3
                ignore_bn_layers.append(layer_name + '.cv1.bn')  # first conv of the Bottleneck
                ignore_bn_layers.append(layer_name + '.cv2.bn')  # second conv of the Bottleneck
        if isinstance(layer_model, nn.BatchNorm2d) and (layer_name not in ignore_bn_layers):
            # Not everything is excluded yet: the traversal reaches a C3's cv1 before
            # its Bottleneck, so cv1 is only added to the ignore list afterwards.
            bn_layers[layer_name] = layer_model
    # print(ignore_bn_layers)
    # print(len(ignore_bn_layers))
    # print(bn_layers)
    # print(len(bn_layers))
    # Filter again to drop the first conv of the four C3 modules
    bn_layers = {k: v for k, v in bn_layers.items() if k not in ignore_bn_layers}
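The need for the second filter comes from the traversal order. The sketch below replays the loop on a hand-written, simplified stand-in for the order in which `named_modules()` would visit one C3 block (module kinds are given as strings instead of real modules, and every Bottleneck is assumed to have its shortcut enabled).

```python
# Simplified visit order for one C3 block: its own convs come before its Bottleneck.
module_order = [
    ("model.2", "C3"),
    ("model.2.cv1.bn", "BatchNorm2d"),
    ("model.2.cv2.bn", "BatchNorm2d"),
    ("model.2.cv3.bn", "BatchNorm2d"),
    ("model.2.m.0", "Bottleneck"),
    ("model.2.m.0.cv1.bn", "BatchNorm2d"),
    ("model.2.m.0.cv2.bn", "BatchNorm2d"),
]

ignore, selected = [], {}
for name, kind in module_order:
    if kind == "Bottleneck":
        ignore.append(name.rsplit(".", 2)[0] + ".cv1.bn")  # "model.2.cv1.bn"
        ignore.append(name + ".cv1.bn")
        ignore.append(name + ".cv2.bn")
    if kind == "BatchNorm2d" and name not in ignore:
        selected[name] = kind

single_pass = sorted(selected)
filtered = sorted(k for k in selected if k not in ignore)
print(single_pass)  # ['model.2.cv1.bn', 'model.2.cv2.bn', 'model.2.cv3.bn']
print(filtered)     # ['model.2.cv2.bn', 'model.2.cv3.bn']

# rsplit strips the trailing ".m.0"; split from the left would keep only "model".
print("model.2.m.0".rsplit(".", 2)[0])  # model.2
print("model.2.m.0".split(".", 2)[0])   # model
```

After a single pass the C3's cv1.bn is still selected, because it was visited before the Bottleneck added it to the ignore list. The second filter removes it.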

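Count the channels of all BN layers to prune and collect the weight of each channel, sort the weights, and read the threshold at the index given by the pruning ratio.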

    bn_size = [da.weight.data.shape[0] for da in bn_layers.values()]
    total_size = sum(bn_size)
    print(total_size)
    # Gather the absolute gamma values of all prunable BN layers into one vector
    bn_weights = torch.zeros(total_size)
    start = 0
    for i, w in enumerate(bn_layers.values()):
        size = w.weight.data.shape[0]
        bn_weights[start:(start + size)] = w.weight.data.abs().clone()
        start += bn_size[i]
    print(bn_weights, bn_weights.shape)
    # Sort in ascending order; the value at the percent position is the threshold
    bn_data, sorted_idx = torch.sort(bn_weights)
    thresh_index = int(percent * total_size)
    thresh_weight = bn_data[thresh_index]
    print(thresh_index, thresh_weight)
    print(f'Gamma values less than {thresh_weight:.4f} are set to zero!')
    print("=" * 94)
    print(f"|\t{'layer name':<25}{'|':<10}{'origin channels':<20}{'|':<10}{'remaining channels':<20}|")
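The same threshold selection can be followed on a tiny example in plain Python. This is a sketch with made-up gamma values and a hypothetical helper `global_threshold`. The rule that a channel is kept when |gamma| >= threshold is an assumption here, since the masking function is not shown in this article.

```python
def global_threshold(layer_gammas, percent):
    """Sort all |gamma| values together and read the value at the prune ratio."""
    all_gammas = sorted(abs(g) for gammas in layer_gammas.values() for g in gammas)
    index = int(percent * len(all_gammas))
    return all_gammas[index]


layer_gammas = {
    "model.0.bn": [0.9, 0.05, 0.4, 0.7],
    "model.1.bn": [0.02, 0.6, 0.3, 0.01],
}

thresh = global_threshold(layer_gammas, percent=0.5)
# Assumed keep rule: a channel survives when |gamma| >= thresh.
remaining = {
    name: sum(1 for g in gammas if abs(g) >= thresh)
    for name, gammas in layer_gammas.items()
}
print(thresh)     # 0.4
print(remaining)  # {'model.0.bn': 3, 'model.1.bn': 1}
```

Because the threshold is global, layers are pruned unevenly: here the first layer keeps three of four channels while the second keeps only one.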

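Problem:

Splitting by a single threshold can leave a BN layer whose channels all fall below it. Removing every channel would disconnect that layer from the next, so a check is needed to guarantee that each layer keeps at least one channel.

Solution: take the maximum weight of each BN layer, then compare the smallest of these maxima with the chosen threshold. If it is smaller than the threshold, prompt the user to change the setting.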
 

    # Highest threshold that avoids pruning every channel of some layer:
    # the minimum over layers of each BN layer's maximum |gamma| is the upper limit.
    highest_thre = []
    for bnlayer in bn_layers.values():
        highest_thre.append(bnlayer.weight.data.abs().max().item())
    # print("highest_thre:", highest_thre)
    highest_thre = min(highest_thre)
    # Find the position of highest_thre in the sorted weights and convert it to a ratio
    percent_limit = (bn_data == highest_thre).nonzero()[0, 0].item() / len(bn_weights)
    print(f'Suggested Gamma threshold should be less than {highest_thre:.4f}.')
    print(f'The corresponding prune ratio is {percent_limit:.3f}, but you can set higher.')
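A worked example of this upper limit, again in plain Python with invented gamma values and hypothetical helper names:

```python
def threshold_limit(layer_gammas):
    """Smallest per-layer maximum |gamma|: above it, some layer loses every channel."""
    return min(max(abs(g) for g in gammas) for gammas in layer_gammas.values())


def percent_limit(layer_gammas):
    """Position of the limit in the sorted gammas, as a fraction of all channels."""
    all_gammas = sorted(abs(g) for gammas in layer_gammas.values() for g in gammas)
    return all_gammas.index(threshold_limit(layer_gammas)) / len(all_gammas)


layer_gammas = {
    "model.0.bn": [0.9, 0.05, 0.4, 0.7],   # max 0.9
    "model.1.bn": [0.02, 0.2, 0.3, 0.01],  # max 0.3
}

print(threshold_limit(layer_gammas))  # 0.3
print(percent_limit(layer_gammas))    # 0.5
```

With these values, any threshold above 0.3 would remove all four channels of model.1.bn, which corresponds to a prune ratio of 0.5.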

Rebuild the model configuration:

    pruned_num = 0
    pruned_yaml = {}
    nc = model.model[-1].nc
    with open(cfg, encoding='ascii', errors='ignore') as f:
        model_yamls = yaml.safe_load(f)  # model dict
    # # Define model
    pruned_yaml["nc"] = model.model[-1].nc
    pruned_yaml["depth_multiple"] = model_yamls["depth_multiple"]
    pruned_yaml["width_multiple"] = model_yamls["width_multiple"]
    pruned_yaml["anchors"] = model_yamls["anchors"]
    anchors = model_yamls["anchors"]
    pruned_yaml["backbone"] = [
        [-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
        [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
        [-1, 3, C3Pruned, [128]],
        [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
        [-1, 6, C3Pruned, [256]],
        [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
        [-1, 9, C3Pruned, [512]],
        [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
        [-1, 3, C3Pruned, [1024]],
        [-1, 1, SPPFPruned, [1024, 5]],  # 9
    ]
    pruned_yaml["head"] = [
        [-1, 1, Conv, [512, 1, 1]],
        [-1, 1, nn.Upsample, [None, 2, 'nearest']],
        [[-1, 6], 1, Concat, [1]],  # cat backbone P4
        [-1, 3, C3Pruned, [512, False]],  # 13
        [-1, 1, Conv, [256, 1, 1]],
        [-1, 1, nn.Upsample, [None, 2, 'nearest']],
        [[-1, 4], 1, Concat, [1]],  # cat backbone P3
        [-1, 3, C3Pruned, [256, False]],  # 17 (P3/8-small)
        [-1, 1, Conv, [256, 3, 2]],
        [[-1, 14], 1, Concat, [1]],  # cat head P4
        [-1, 3, C3Pruned, [512, False]],  # 20 (P4/16-medium)
        [-1, 1, Conv, [512, 3, 2]],
        [[-1, 10], 1, Concat, [1]],  # cat head P5
        [-1, 3, C3Pruned, [1024, False]],  # 23 (P5/32-large)
        [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
    ]

Model reconstruction:

    maskbndict = {}
    remain_num = 0
    for name, layer in model.named_modules():
        if isinstance(layer, nn.BatchNorm2d):
            bn_model = layer
            mask = obtain_bn_mask(bn_model, thresh_weight)
            if name in ignore_bn_layers:
                # Excluded layers keep all of their channels
                mask = torch.ones(layer.weight.data.size()).cuda()
            maskbndict[name] = mask
            remain_num += int(mask.sum())
            # Zero out the gamma and beta of the pruned channels
            bn_model.weight.data.mul_(mask)
            bn_model.bias.data.mul_(mask)
            print(f"|\t{name:<25}{'|':<10}{bn_model.weight.data.size()[0]:<20}{'|':<10}{int(mask.sum()):<20}|")
            assert int(mask.sum()) > 0, \
                "Each layer must keep at least one channel. Lower the prune percent, or retrain a sparser model."
    print("=" * 94)
    pruned_model = ModelPruned(maskbndict=maskbndict, cfg=pruned_yaml, ch=3).cuda()
    for m in pruned_model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
    from_to_map = pruned_model.from_to_map
    pruned_model_state = pruned_model.state_dict()
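The masking step can be sketched in plain Python as follows. `obtain_bn_mask` is not defined in this article, so the version below is a hypothetical stand-in that keeps a channel when |gamma| >= threshold. The layer names and gamma values are invented.

```python
def obtain_bn_mask(gammas, thresh):
    """Hypothetical stand-in: 1.0 keeps a channel, 0.0 prunes it."""
    return [1.0 if abs(g) >= thresh else 0.0 for g in gammas]


def build_masks(layer_gammas, thresh, ignore):
    """Build one mask per BN layer; ignored layers keep every channel."""
    masks = {}
    for name, gammas in layer_gammas.items():
        if name in ignore:
            masks[name] = [1.0] * len(gammas)
        else:
            masks[name] = obtain_bn_mask(gammas, thresh)
        if sum(masks[name]) == 0:
            raise ValueError(f"{name} would lose every channel; lower the prune ratio")
    return masks


layer_gammas = {
    "model.1.bn": [0.9, 0.05, 0.4],
    "model.2.cv1.bn": [0.01, 0.02, 0.03],  # tied to a shortcut, must not be pruned
}
masks = build_masks(layer_gammas, thresh=0.1, ignore=["model.2.cv1.bn"])
print(masks["model.1.bn"])      # [1.0, 0.0, 1.0]
print(masks["model.2.cv1.bn"])  # [1.0, 1.0, 1.0]
```

Without the ignore list, model.2.cv1.bn would fail the at-least-one-channel check in this example, since all of its gammas are below the threshold.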

Parameter copying:

    # ----------------------------- parameter copying -----------------------------
    import numpy as np

    modelstate = model.state_dict()
    changed_state = []
    for ((layername, layermodel), (pruned_layername, pruned_layermodel)) in zip(model.named_modules(), pruned_model.named_modules()):
        if isinstance(layermodel, nn.Conv2d) and not layername.startswith("model.24"):
            convname = layername[:-4] + "bn"  # BN layer that follows this conv
            if convname in from_to_map.keys():
                former = from_to_map[convname]  # BN layer(s) feeding this conv
                if isinstance(former, str):
                    # Single input: slice input channels by the former mask, output channels by its own mask
                    out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                    in_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[former].cpu().numpy())))
                    w = layermodel.weight.data[:, in_idx, :, :].clone()
                    if len(w.shape) == 3:  # only 1 channel remains
                        w = w.unsqueeze(1)
                    w = w[out_idx, :, :, :].clone()
                    pruned_layermodel.weight.data = w.clone()
                    changed_state.append(layername + ".weight")
                if isinstance(former, list):
                    # Several inputs joined by cat: offset each branch's kept indices
                    # by the original widths of the branches before it
                    orignin = [modelstate[i + ".weight"].shape[0] for i in former]
                    formerin = []
                    for it in range(len(former)):
                        name = former[it]
                        tmp = [i for i in range(maskbndict[name].shape[0]) if maskbndict[name][i] == 1]
                        if it > 0:
                            tmp = [k + sum(orignin[:it]) for k in tmp]
                        formerin.extend(tmp)
                    out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                    w = layermodel.weight.data[out_idx, :, :, :].clone()
                    pruned_layermodel.weight.data = w[:, formerin, :, :].clone()
                    changed_state.append(layername + ".weight")
            else:
                # No recorded input layer: only the output channels are sliced
                out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername[:-4] + "bn"].cpu().numpy())))
                w = layermodel.weight.data[out_idx, :, :, :].clone()
                assert len(w.shape) == 4
                pruned_layermodel.weight.data = w.clone()
                changed_state.append(layername + ".weight")
        if isinstance(layermodel, nn.BatchNorm2d):
            out_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[layername].cpu().numpy())))
            pruned_layermodel.weight.data = layermodel.weight.data[out_idx].clone()
            pruned_layermodel.bias.data = layermodel.bias.data[out_idx].clone()
            pruned_layermodel.running_mean = layermodel.running_mean[out_idx].clone()
            pruned_layermodel.running_var = layermodel.running_var[out_idx].clone()
            changed_state.append(layername + ".weight")
            changed_state.append(layername + ".bias")
            changed_state.append(layername + ".running_mean")
            changed_state.append(layername + ".running_var")
            changed_state.append(layername + ".num_batches_tracked")
        if isinstance(layermodel, nn.Conv2d) and layername.startswith("model.24"):
            # Detect head (model.24): only the input channels change
            former = from_to_map[layername]
            in_idx = np.squeeze(np.argwhere(np.asarray(maskbndict[former].cpu().numpy())))
            pruned_layermodel.weight.data = layermodel.weight.data[:, in_idx, :, :]
            pruned_layermodel.bias.data = layermodel.bias.data
            changed_state.append(layername + ".weight")
            changed_state.append(layername + ".bias")
    missing = [i for i in pruned_model_state.keys() if i not in changed_state]
    pruned_model.eval()
    pruned_model.names = model.names
    # =============================================================================================== #
    torch.save({"model": model}, "weights/pruned_model/orign_model.pt")
    model = pruned_model
    torch.save({"model": model}, "weights/pruned_model/pruned_model.pt")
    model.cuda().eval()
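The index bookkeeping in the copy step is easiest to check on a toy case. Below is a plain Python sketch, assuming 0/1 masks and a weight reduced to a 2-D list (output channels by input channels, kernel dimensions omitted). The helper names are invented for this illustration.

```python
def kept_indices(mask):
    """Indices of the channels whose mask value is 1."""
    return [i for i, keep in enumerate(mask) if keep == 1]


def concat_input_indices(masks):
    """Kept input channels of a conv fed by a cat of several layers."""
    indices, offset = [], 0
    for mask in masks:
        indices.extend(offset + i for i in kept_indices(mask))
        offset += len(mask)  # shift by the unpruned width of this branch
    return indices


def slice_weight(weight, out_idx, in_idx):
    """Select output channels (rows) and input channels (columns)."""
    return [[weight[o][i] for i in in_idx] for o in out_idx]


branch_a = [1, 0, 1]   # mask of the first layer feeding the cat
branch_b = [0, 1, 1]   # mask of the second layer feeding the cat
own_mask = [0, 1]      # mask of the conv's own BN layer

# Toy weight with 2 output and 6 input channels; entry = 10 * out + in.
weight = [[10 * o + i for i in range(6)] for o in range(2)]

in_idx = concat_input_indices([branch_a, branch_b])
pruned = slice_weight(weight, kept_indices(own_mask), in_idx)
print(in_idx)  # [0, 2, 4, 5]
print(pruned)  # [[10, 12, 14, 15]]
```

The offset for the second branch is the original width of the first branch (3), not its pruned width, because the indices refer to channels of the unpruned weight tensor.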

References:

YOLOv5 model pruning and compression (2): introduction to the YOLOv5 model and choice of layers to prune, MidasKing's blog, CSDN

yolov5 model compression: model pruning, 小小小绿叶's blog, CSDN

GitHub - midasklr/yolov5prune
