赞
踩
论文链接:[2403.19967] Rewrite the Stars
github仓库:GitHub - ma-xu/Rewrite-the-Stars: [CVPR 2024] Rewrite the Stars
CVPR 2024论文《Rewrite the Stars》揭示了star operation(元素乘法)无需加宽网络即可将输入映射到高维非线性特征空间的能力。基于此提出的StarNet在网络结构紧凑、能耗较低的前提下,展示了令人印象深刻的性能和低延迟。
高维和非线性特征变换 (High-Dimensional and Non-Linear Feature Transformation)
高效网络设计 (Efficient Network Design)
多层次隐式特征扩展 (Multi-Layer Implicit Feature Expansion)
计算复杂度与性能的平衡 (Balance Between Computational Complexity and Performance)
特征表示的丰富性 (Richness of Feature Representation)
简化网络设计 (Simplified Network Design)
1. 在上文提到的仓库中下载imagenet/starnet.py
2. 修改starnet.py中的forward函数,并且添加out_indices参数,使其能够输出不同stage的特征图
3. 使用@MODELS.register_module()注册class StarNet,并在backbones包的__init__.py中导入StarNet、将其加入__all__
4. 修改配置文件,主要是调整YOLOv5 neck和head的输入输出通道数
- """
- Implementation of Proof-of-Concept Network: StarNet.
- We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
- - like NO layer-scale in network design,
- - and NO EMA during training,
- - which would improve the performance further.
- Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
- Modified Date: Mar/29/2024
- """
- import torch
- import torch.nn as nn
- from timm.models.layers import DropPath, trunc_normal_
- from typing import List, Sequence, Union
-
- # from timm.models.registry import register_model
- from mmyolo.registry import MODELS
-
- model_urls = {
- "starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",
- "starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",
- "starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",
- "starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",
- }
-
-
class ConvBN(torch.nn.Sequential):
    """2-D convolution optionally followed by batch normalization.

    The convolution is registered under the name ``conv``. When ``with_bn``
    is true, a ``BatchNorm2d`` over ``out_planes`` channels is registered
    under the name ``bn`` with its weight set to 1 and its bias set to 0.

    Args:
        in_planes: Number of input channels of the convolution.
        out_planes: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
        with_bn: Whether to append the batch-normalization layer.
    """

    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, with_bn=True):
        super().__init__()
        conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation, groups)
        self.add_module('conv', conv)
        if not with_bn:
            return
        norm = torch.nn.BatchNorm2d(out_planes)
        torch.nn.init.constant_(norm.weight, 1)
        torch.nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)
-
-
class Block(nn.Module):
    """Residual block built around an element-wise product of two branches.

    The input goes through a 7x7 depthwise ``ConvBN``, is projected by two
    parallel 1x1 convolutions to ``mlp_ratio * dim`` channels, and the
    ReLU6-activated first projection is multiplied element-wise with the
    second one. The product is projected back to ``dim`` channels, passed
    through a second 7x7 depthwise convolution, and added to the block input
    after the drop-path layer.

    Args:
        dim: Number of input and output channels.
        mlp_ratio: Channel expansion factor of the two 1x1 projections.
        drop_path: Drop-path rate; ``nn.Identity`` is used when it is not
            greater than zero.
    """

    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        kernel_size = 7
        padding = (kernel_size - 1) // 2
        hidden_dim = mlp_ratio * dim
        self.dwconv = ConvBN(dim, dim, kernel_size, 1, padding, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, hidden_dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, hidden_dim, 1, with_bn=False)
        self.g = ConvBN(hidden_dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, kernel_size, 1, padding, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return ``x`` plus the (possibly dropped) output of the star branch."""
        shortcut = x
        features = self.dwconv(x)
        gate = self.f1(features)
        value = self.f2(features)
        # Element-wise product of the activated branch and the linear branch.
        mixed = self.act(gate) * value
        mixed = self.g(mixed)
        mixed = self.dwconv2(mixed)
        return shortcut + self.drop_path(mixed)
-
@MODELS.register_module()
class StarNet(nn.Module):
    """StarNet backbone that returns the feature maps of selected stages.

    The network is a stride-2 stem followed by ``len(depths)`` stages. Each
    stage starts with a stride-2 ``ConvBN`` down-sampler and is followed by
    ``depths[i]`` ``Block`` modules. Stage ``i`` therefore outputs
    ``base_dim * 2 ** i`` channels at ``1 / 2 ** (i + 2)`` of the input
    resolution.

    Args:
        base_dim: Channel count of stage 0; it doubles at every later stage.
        out_indices: Indices of the stages whose outputs ``forward`` returns.
        depths: Number of ``Block`` modules in each stage.
        mlp_ratio: Channel expansion factor passed to every ``Block``.
        drop_path_rate: Maximum drop-path rate; the per-block rate grows
            linearly from 0 to this value over all blocks.
        num_classes: Stored on the instance only; no classification head is
            built in this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, base_dim=32, out_indices: Sequence[int] = (0, 1, 2),
                 depths: Sequence[int] = (3, 3, 12, 5), mlp_ratio=4,
                 drop_path_rate=0.0, num_classes=1000, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.in_channel = 32
        self.out_indices = out_indices
        # Private copy so the (now immutable) default and any caller-owned
        # sequence are never shared with the instance.
        self.depths = list(depths)
        # stem layer
        self.stem = nn.Sequential(ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6())
        # stochastic depth: one rate per block, increasing linearly
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        # build stages
        self.stages = nn.ModuleList()
        cur = 0
        for i_layer, depth in enumerate(self.depths):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [Block(embed_dim, mlp_ratio, rate) for rate in dpr[cur:cur + depth]]
            cur += depth
            self.stages.append(nn.Sequential(down_sampler, *blocks))
        # No norm / pooling / linear head is created here, and _init_weights
        # is not applied automatically: layers keep their default
        # initialization unless the caller runs
        # ``model.apply(model._init_weights)``.

    def _init_weights(self, m):
        """Initialize one sub-module; intended for use with ``nn.Module.apply``.

        Linear and Conv2d weights get a truncated normal (std 0.02), Linear
        biases are zeroed, and LayerNorm / BatchNorm2d are reset to weight 1
        and bias 0.
        """
        # isinstance needs a tuple of types: ``nn.Linear or nn.Conv2d``
        # evaluates to ``nn.Linear`` alone and would never match Conv2d.
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Run the stem and all stages; return a tuple of the selected outputs.

        Outputs are collected in stage order for every stage index contained
        in ``self.out_indices``.
        """
        x = self.stem(x)
        # Collected stage outputs.
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
-
-
@MODELS.register_module()
def starnet_s1(pretrained=False, **kwargs):
    """Build the StarNet-S1 backbone (base_dim=24, depths=[2, 2, 8, 3]).

    Args:
        pretrained: When true, download the ``starnet_s1`` checkpoint listed
            in ``model_urls`` and load it into the model.
        **kwargs: Forwarded to ``StarNet``.

    Returns:
        The constructed ``StarNet`` instance.

    Raises:
        RuntimeError: If the checkpoint lacks a parameter of the model.
    """
    model = StarNet(24, (0, 1, 2), [2, 2, 8, 3], **kwargs)
    if pretrained:
        url = model_urls['starnet_s1']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        # StarNet in this file builds no norm/head layers, so a strict load
        # would reject a checkpoint that still carries such weights
        # (presumably the case for the released classification checkpoints —
        # confirm). Extra keys are tolerated; missing ones still fail.
        result = model.load_state_dict(checkpoint["state_dict"], strict=False)
        if result.missing_keys:
            raise RuntimeError(
                f"starnet_s1 checkpoint is missing keys: {result.missing_keys}")
    return model
-
-
@MODELS.register_module()
def starnet_s2(pretrained=False, **kwargs):
    """Build the StarNet-S2 backbone (base_dim=32, depths=[1, 2, 6, 2]).

    Args:
        pretrained: When true, download the ``starnet_s2`` checkpoint listed
            in ``model_urls`` and load it into the model.
        **kwargs: Forwarded to ``StarNet``.

    Returns:
        The constructed ``StarNet`` instance.

    Raises:
        RuntimeError: If the checkpoint lacks a parameter of the model.
    """
    model = StarNet(32, (0, 1, 2), [1, 2, 6, 2], **kwargs)
    if pretrained:
        url = model_urls['starnet_s2']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        # StarNet in this file builds no norm/head layers, so a strict load
        # would reject a checkpoint that still carries such weights
        # (presumably the case for the released classification checkpoints —
        # confirm). Extra keys are tolerated; missing ones still fail.
        result = model.load_state_dict(checkpoint["state_dict"], strict=False)
        if result.missing_keys:
            raise RuntimeError(
                f"starnet_s2 checkpoint is missing keys: {result.missing_keys}")
    return model
-
-
@MODELS.register_module()
def starnet_s3(pretrained=False, **kwargs):
    """Build the StarNet-S3 backbone (base_dim=32, depths=[2, 2, 8, 4]).

    Args:
        pretrained: When true, download the ``starnet_s3`` checkpoint listed
            in ``model_urls`` and load it into the model.
        **kwargs: Forwarded to ``StarNet``.

    Returns:
        The constructed ``StarNet`` instance.

    Raises:
        RuntimeError: If the checkpoint lacks a parameter of the model.
    """
    model = StarNet(32, (0, 1, 2), [2, 2, 8, 4], **kwargs)
    if pretrained:
        url = model_urls['starnet_s3']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        # StarNet in this file builds no norm/head layers, so a strict load
        # would reject a checkpoint that still carries such weights
        # (presumably the case for the released classification checkpoints —
        # confirm). Extra keys are tolerated; missing ones still fail.
        result = model.load_state_dict(checkpoint["state_dict"], strict=False)
        if result.missing_keys:
            raise RuntimeError(
                f"starnet_s3 checkpoint is missing keys: {result.missing_keys}")
    return model
-
-
@MODELS.register_module()
def starnet_s4(pretrained=False, **kwargs):
    """Build the StarNet-S4 backbone (base_dim=32, depths=[3, 3, 12, 5]).

    Args:
        pretrained: When true, download the ``starnet_s4`` checkpoint listed
            in ``model_urls`` and load it into the model.
        **kwargs: Forwarded to ``StarNet``.

    Returns:
        The constructed ``StarNet`` instance.

    Raises:
        RuntimeError: If the checkpoint lacks a parameter of the model.
    """
    model = StarNet(32, (0, 1, 2), [3, 3, 12, 5], **kwargs)
    if pretrained:
        url = model_urls['starnet_s4']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        # StarNet in this file builds no norm/head layers, so a strict load
        # would reject a checkpoint that still carries such weights
        # (presumably the case for the released classification checkpoints —
        # confirm). Extra keys are tolerated; missing ones still fail.
        result = model.load_state_dict(checkpoint["state_dict"], strict=False)
        if result.missing_keys:
            raise RuntimeError(
                f"starnet_s4 checkpoint is missing keys: {result.missing_keys}")
    return model
-
-
- # very small networks #
@MODELS.register_module()
def starnet_s050(pretrained=False, **kwargs):
    """Build the StarNet-S050 backbone (base_dim=16, depths=[1, 1, 3, 1], mlp_ratio=3).

    ``pretrained`` is accepted but not used: no checkpoint is loaded.
    Remaining keyword arguments are forwarded to ``StarNet``.
    """
    stage_depths = [1, 1, 3, 1]
    model = StarNet(16, (0, 1, 2), stage_depths, 3, **kwargs)
    return model
-
-
@MODELS.register_module()
def starnet_s100(pretrained=False, **kwargs):
    """Build the StarNet-S100 backbone (base_dim=20, depths=[1, 2, 4, 1], mlp_ratio=4).

    ``pretrained`` is accepted but not used: no checkpoint is loaded.
    Remaining keyword arguments are forwarded to ``StarNet``.
    """
    stage_depths = [1, 2, 4, 1]
    model = StarNet(20, (0, 1, 2), stage_depths, 4, **kwargs)
    return model
-
-
@MODELS.register_module()
def starnet_s150(pretrained=False, **kwargs):
    """Build the StarNet-S150 backbone (base_dim=24, depths=[1, 2, 4, 2], mlp_ratio=3).

    ``pretrained`` is accepted but not used: no checkpoint is loaded.
    Remaining keyword arguments are forwarded to ``StarNet``.
    """
    stage_depths = [1, 2, 4, 2]
    model = StarNet(24, (0, 1, 2), stage_depths, 3, **kwargs)
    return model
-
if __name__ == '__main__':
    # Smoke test: push one random 224x224 RGB image through a default
    # StarNet and print the shape of every returned feature map, so the run
    # produces visible output instead of silently discarding the result.
    model = StarNet()
    model.eval()
    input_tensor = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(input_tensor)
    for position, feature in enumerate(outputs):
        print(f'output {position}: {tuple(feature.shape)}')
- # Copyright (c) OpenMMLab. All rights reserved.
- from .base_backbone import BaseBackbone
- from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
- from .csp_resnet import PPYOLOECSPResNet
- from .cspnext import CSPNeXt
- from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
- from .yolov7_backbone import YOLOv7Backbone
- from .starnet import StarNet
# Names re-exported by this package, including the newly added StarNet.
__all__ = [
    'YOLOv5CSPDarknet',
    'BaseBackbone',
    'YOLOv6EfficientRep',
    'YOLOv6CSPBep',
    'YOLOXCSPDarknet',
    'CSPNeXt',
    'YOLOv7Backbone',
    'PPYOLOECSPResNet',
    'YOLOv8CSPDarknet',
    'StarNet',
]
# Base configs this file builds on.
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ======================== Frequently modified parameters ====================
# ----- data -----
data_root = 'data/coco/'  # dataset root directory
train_ann_file = 'annotations/instances_train2017.json'  # train annotations
train_data_prefix = 'train2017/'  # train image folder prefix
val_ann_file = 'annotations/instances_val2017.json'  # val annotations
val_data_prefix = 'val2017/'  # val image folder prefix

num_classes = 80  # number of object categories
train_batch_size_per_gpu = 16  # training batch size on each GPU
train_num_workers = 8  # data-loading workers per GPU during training
persistent_workers = True  # has to be False when num_workers is 0

# ----- model -----
# Base sizes of the multi-scale prior boxes, one row per output level.
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)],  # P5/32
]

# ----- train / val -----
base_lr = 0.01  # optimizer base LR; corresponds to 8xb16=128 batch size
max_epochs = 300  # total number of training epochs

# Test-time post-processing.
model_test_cfg = {
    'multi_label': True,  # multi-label setting for multi-class prediction
    'nms_pre': 30000,  # boxes kept before NMS
    'score_thr': 0.001,  # boxes below this score are filtered out
    'nms': {'type': 'nms', 'iou_threshold': 0.65},  # NMS type and threshold
    'max_per_img': 300,  # detections kept per image
}
-
# ======================== Possibly modified parameters ======================
# ----- data -----
img_scale = (640, 640)  # width, height
dataset_type = 'YOLOv5CocoDataset'  # dataset class used for train and val
val_batch_size_per_gpu = 1  # validation batch size on each GPU
val_num_workers = 2  # data-loading workers per GPU during validation

# Batch-shape policy, applied on validation only; None would disable it.
batch_shapes_cfg = {
    'type': 'BatchShapePolicy',
    'batch_size': val_batch_size_per_gpu,
    'img_size': img_scale[0],
    'size_divisor': 32,  # padded image scale must be divisible by this
    'extra_pad_ratio': 0.5,  # additional padding for pixel scale
}

# ----- model -----
deepen_factor = 0.33  # depth scaling factor of the network structure
widen_factor = 0.5  # width scaling factor of the network structure
strides = [8, 16, 32]  # strides of the multi-scale prior boxes
num_det_layers = 3  # number of model output scales
norm_cfg = {'type': 'BN', 'momentum': 0.03, 'eps': 0.001}  # normalization

# ----- train / val -----
affine_scale = 0.5  # scaling ratio of YOLOv5RandomAffine
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # prior-box matching threshold
obj_level_weights = [4., 1., 0.4]  # obj-loss weight of each output level
lr_factor = 0.01  # learning-rate scaling factor
weight_decay = 0.0005
save_checkpoint_intervals = 10  # checkpoint saving and validation interval
max_keep_ckpts = 3  # maximum number of checkpoints kept
# The original note recommends turning this on for single-scale training
# because it can speed training up.
env_cfg = {'cudnn_benchmark': True}
# StarNet variants. "channels" are the outputs of backbone stages 1-3
# (base_dim * 2, * 4, * 8), i.e. the stages whose strides are 8 / 16 / 32.
#
#   variant        base_dim  channels         depths          mlp_ratio
#   starnet_s1     24        [48, 96, 192]    [2, 2, 8, 3]    4
#   starnet_s2     32        [64, 128, 256]   [1, 2, 6, 2]    4
#   starnet_s3     32        [64, 128, 256]   [2, 2, 8, 4]    4
#   starnet_s4     32        [64, 128, 256]   [3, 3, 12, 5]   4
#   starnet_s050   16        [32, 64, 128]    [1, 1, 3, 1]    3
#   starnet_s100   20        [40, 80, 160]    [1, 2, 4, 1]    4
#   starnet_s150   24        [48, 96, 192]    [1, 2, 4, 2]    3
starnet_channel = [48, 96, 192]
# NOTE(review): base_dim=24 below is the starnet_s1 width, but these depths
# are the starnet_s2 row of the table (s1 uses [2, 2, 8, 3]). Left unchanged;
# confirm which variant is intended.
depths = [1, 2, 6, 2]
# =============================== Unmodified in most cases ===================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        type='StarNet',
        base_dim=24,
        # StarNet stage i has stride 2 ** (i + 2) and base_dim * 2 ** i
        # channels, so stages 1-3 are the ones that match strides
        # [8, 16, 32] and starnet_channel. Stages (0, 1, 2) would feed
        # stride-4/8/16 maps into a head configured for 8/16/32.
        out_indices=(1, 2, 3),
        depths=depths,
        mlp_ratio=4,
        num_classes=num_classes),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        # Width scaling is fixed to 1.0 here (instead of the file-level
        # widen_factor) so that in_channels stay equal to the real backbone
        # channels; the neck is expected to multiply in/out channels by this
        # factor — see the mmyolo YOLOv5PAFPN documentation.
        widen_factor=1.0,
        in_channels=starnet_channel,
        out_channels=starnet_channel,
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=starnet_channel,
            # Must agree with the neck so the head input channels equal the
            # neck output channels.
            widen_factor=1.0,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # Edit this entry to swap in a different IoU loss.
        # NOTE(review): `focal=True` presumably relies on a customised
        # IoULoss implementation — confirm the registered loss accepts it.
        loss_bbox=dict(
            type='IoULoss',
            focal=True,
            iou_mode='ciou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640) ** 2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)
-
# Albumentations transforms, each applied with probability 0.01.
albu_train_transforms = [
    dict(type=name, p=0.01)
    for name in ('Blur', 'MedianBlur', 'ToGray', 'CLAHE')
]

# Loading steps shared by the training pipeline and by Mosaic.
pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
]

# Training pipeline: load, Mosaic, random affine, Albumentations, HSV
# augmentation, random flip, then pack the inputs.
train_pipeline = pre_transform + [
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap=dict(img='image', gt_bboxes='bboxes')),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction')),
]

# Training data loader.
train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

# Validation / test pipeline: load, keep-ratio resize, letterbox without
# up-scaling, load annotations, then pack the inputs.
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param')),
]
-
# Validation data loader; the test loader reuses the same settings.
val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg),
)
test_dataloader = val_dataloader

# No top-level scheduler: the schedule is configured through the
# YOLOv5ParamSchedulerHook entry in default_hooks below.
param_scheduler = None

# SGD with Nesterov momentum, built by the YOLOv5 optimizer constructor.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor',
)

# Hooks: linear LR schedule and periodic checkpointing.
default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts),
)

# Exponential-moving-average hook for the model weights.
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49),
]

# COCO bbox evaluation; testing uses the same evaluator.
val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox',
)
test_evaluator = val_evaluator

# Epoch-based training; validation runs at the checkpoint interval.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals,
)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。