
[Paper Notes] Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks

Paper: Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks

Code: https://github.com/jierunchen/fasternet

The paper proposes PConv and, building on it, FasterNet, a fast-inference model obtained by optimizing FLOPS.

When designing neural network architectures, most of the attention goes into reducing FLOPs (floating-point operations). But lower FLOPs do not necessarily mean faster inference, largely because FLOPS (floating-point operations per second), i.e. how fast those operations actually execute, is ignored. To address this, the authors propose PConv (partial convolution), which speeds up inference by raising FLOPS.

1. Introduction

Many real-time inference models focus on reducing FLOPs, e.g. MobileNet, ShuffleNet, and GhostNet. Although these networks do lower FLOPs, they do not consider FLOPS, so there is still room to optimize inference speed. Inference latency is computed as:

Latency = FLOPs / FLOPS

As the formula shows, inference can be sped up not only by reducing FLOPs but also by raising FLOPS. The authors benchmarked a number of models and found that many of them achieve lower FLOPS than ResNet50. This motivated PConv, which speeds up inference by increasing FLOPS.
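As an aside, the achieved FLOPS of a model can be estimated by timing inference and dividing its known FLOPs by the measured per-image latency. Below is a minimal sketch of such a measurement (not from the paper; flops_per_image must come from a separate profiler such as fvcore or thop, and the timing is CPU-only, so on GPU you would add torch.cuda.synchronize() around the timers):

import time
import torch

@torch.no_grad()
def achieved_flops_per_second(model, flops_per_image,
                              input_shape=(1, 3, 224, 224),
                              warmup=10, iters=50):
    # FLOPS = FLOPs / latency; flops_per_image is obtained separately.
    model.eval()
    x = torch.randn(input_shape)
    for _ in range(warmup):  # warm-up runs to stabilize timings
        model(x)
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    latency = (time.perf_counter() - start) / iters  # seconds per image
    return flops_per_image / latency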

2. PConv

To raise FLOPS, the authors propose PConv, whose structure is shown in the figure below:

Only part of the channels go through the convolution; the remaining channels are left untouched. After taking a few more looks... this really does resemble GhostConv...
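To see why this saves compute, compare FLOPs directly: for an h×w feature map with c channels and a k×k kernel, a regular convolution costs h·w·k²·c² FLOPs, whereas PConv convolves only c_p = c/n_div channels at a cost of h·w·k²·c_p². With the paper's typical ratio c_p/c = 1/4, PConv needs only 1/16 of the FLOPs. A quick check with illustrative numbers:

h, w, c, k = 56, 56, 64, 3  # illustrative feature-map size and kernel
n_div = 4                   # convolve 1/4 of the channels, as in the paper
c_p = c // n_div

regular_flops = h * w * k ** 2 * c * c    # standard 3x3 conv over all channels
pconv_flops = h * w * k ** 2 * c_p * c_p  # PConv: only c_p channels convolved

print(pconv_flops / regular_flops)        # 0.0625, i.e. 1/16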

The overall network structure is shown below:

3. Model Performance

FasterNet's results on ImageNet-1K are shown below:

Its results on the COCO dataset:

4. Code

Here is the code for PConv (shipped together with the full FasterNet model); the PConv part itself is very simple:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import partial
from typing import List
from torch import Tensor
import copy
import os

try:
    from mmdet.models.builder import BACKBONES as det_BACKBONES
    from mmdet.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
    has_mmdet = True
except ImportError:
    print("If for detection, please install mmdetection first")
    has_mmdet = False


class Partial_conv3(nn.Module):
    # PConv: apply a 3x3 conv to the first dim // n_div channels only,
    # leaving the remaining channels untouched.

    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError

    def forward_slicing(self, x: Tensor) -> Tensor:
        # only for inference
        x = x.clone()  # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

    def forward_split_cat(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x


class MLPBlock(nn.Module):
    # FasterNet block: PConv for spatial mixing, then a 1x1-conv MLP,
    # wrapped in a residual connection.

    def __init__(self,
                 dim,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 act_layer,
                 norm_layer,
                 pconv_fw_type
                 ):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
            norm_layer(mlp_hidden_dim),
            act_layer(),
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)
        ]
        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
            pconv_fw_type
        )

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.forward = self.forward_layer_scale
        else:
            self.forward = self.forward  # keep the default forward below

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(self.mlp(x))
        return x

    def forward_layer_scale(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(
            self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
        return x


class BasicStage(nn.Module):
    # A stack of `depth` MLPBlocks at a fixed resolution.

    def __init__(self,
                 dim,
                 depth,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 norm_layer,
                 act_layer,
                 pconv_fw_type
                 ):
        super().__init__()

        blocks_list = [
            MLPBlock(
                dim=dim,
                n_div=n_div,
                mlp_ratio=mlp_ratio,
                drop_path=drop_path[i],
                layer_scale_init_value=layer_scale_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type
            )
            for i in range(depth)
        ]
        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x: Tensor) -> Tensor:
        x = self.blocks(x)
        return x


class PatchEmbed(nn.Module):
    # Split the image into non-overlapping patches with a strided conv.

    def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):
    # Downsampling between stages: strided conv that doubles the channels.

    def __init__(self, patch_size2, patch_stride2, dim, norm_layer):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(2 * dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x


class FasterNet(nn.Module):

    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(1, 2, 8, 2),
                 mlp_ratio=2.,
                 n_div=4,
                 patch_size=4,
                 patch_stride=4,
                 patch_size2=2,  # for subsequent layers
                 patch_stride2=2,
                 patch_norm=True,
                 feature_dim=1280,
                 drop_path_rate=0.1,
                 layer_scale_init_value=0,
                 norm_layer='BN',
                 act_layer='RELU',
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 pconv_fw_type='split_cat',
                 **kwargs):
        super().__init__()

        if norm_layer == 'BN':
            norm_layer = nn.BatchNorm2d
        else:
            raise NotImplementedError

        if act_layer == 'GELU':
            act_layer = nn.GELU
        elif act_layer == 'RELU':
            act_layer = partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        if not fork_feat:
            self.num_classes = num_classes
        self.num_stages = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_stages - 1))
        self.mlp_ratio = mlp_ratio
        self.depths = depths

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            patch_stride=patch_stride,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None
        )

        # stochastic depth decay rule
        dpr = [x.item()
               for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        stages_list = []
        for i_stage in range(self.num_stages):
            stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),
                               n_div=n_div,
                               depth=depths[i_stage],
                               mlp_ratio=self.mlp_ratio,
                               drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                               layer_scale_init_value=layer_scale_init_value,
                               norm_layer=norm_layer,
                               act_layer=act_layer,
                               pconv_fw_type=pconv_fw_type
                               )
            stages_list.append(stage)

            # patch merging layer
            if i_stage < self.num_stages - 1:
                stages_list.append(
                    PatchMerging(patch_size2=patch_size2,
                                 patch_stride2=patch_stride2,
                                 dim=int(embed_dim * 2 ** i_stage),
                                 norm_layer=norm_layer)
                )

        self.stages = nn.Sequential(*stages_list)

        self.fork_feat = fork_feat

        if self.fork_feat:
            self.forward = self.forward_det
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    raise NotImplementedError
                else:
                    layer = norm_layer(int(embed_dim * 2 ** i_emb))
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            self.forward = self.forward_cls
            # Classifier head
            self.avgpool_pre_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.num_features, feature_dim, 1, bias=False),
                act_layer()
            )
            self.head = nn.Linear(feature_dim, num_classes) \
                if num_classes > 0 else nn.Identity()

        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    def cls_init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # init for mmdetection by loading imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            pass
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)

            # show for debug
            print('missing_keys: ', missing_keys)
            print('unexpected_keys: ', unexpected_keys)

    def forward_cls(self, x):
        # output only the features of last layer for image classification
        x = self.patch_embed(x)
        x = self.stages(x)
        x = self.avgpool_pre_head(x)  # B C 1 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def forward_det(self, x: Tensor) -> Tensor:
        # output the features of four stages for dense prediction
        x = self.patch_embed(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return outs
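
A quick smoke test of the blocks above (shapes are illustrative, and the FasterNet arguments are just the constructor defaults):

if __name__ == '__main__':
    # PConv keeps the channel count: only the first dim // n_div channels are convolved
    pconv = Partial_conv3(dim=64, n_div=4, forward='split_cat')
    print(pconv(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 64, 56, 56])

    # classification variant: 224x224 input -> 1000-way logits
    model = FasterNet(num_classes=1000)
    model.eval()
    print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])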