Learning to Upsample by Learning to Sample
Although recent kernel-based dynamic upsamplers such as CARAFE, FADE, and SAPA deliver impressive performance gains, they add considerable overhead, mainly from the time-consuming dynamic convolution and the extra sub-network used to generate the dynamic kernels. Moreover, FADE and SAPA require high-resolution guidance features, which limits their application scenarios. To address these issues, the authors bypass dynamic convolution and formulate upsampling from a point-sampling perspective, which is more resource-efficient and can be implemented with standard built-in PyTorch functions. Compared with earlier kernel-based dynamic upsamplers, DySample needs no custom CUDA package and has far fewer parameters, FLOPs, GPU memory, and latency. Besides being lightweight, DySample outperforms other upsamplers on five dense prediction tasks: semantic segmentation, object detection, instance segmentation, panoptic segmentation, and monocular depth estimation. Because it needs only the feature being upsampled (no high-resolution guidance), it also slots into a wider range of architectures, such as the YOLOv9 neck used below.
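To make the point-sampling view concrete, here is a small illustrative sketch written for this post (not taken from the paper's code): plain bilinear upsampling is itself just F.grid_sample over a fixed, evenly spaced grid of sampling points. DySample keeps this sampling machinery and lets a lightweight 1x1 convolution shift each sampling point, as the full implementation further down shows.

import torch
import torch.nn.functional as F

def fixed_grid_upsample(x, scale=2):
    # bilinear upsampling expressed as point sampling on a fixed, evenly spaced grid
    B, C, H, W = x.shape
    ys = (torch.arange(scale * H, dtype=x.dtype, device=x.device) + 0.5) / (scale * H)
    xs = (torch.arange(scale * W, dtype=x.dtype, device=x.device) + 0.5) / (scale * W)
    gy, gx = torch.meshgrid(ys, xs, indexing='ij')
    grid = torch.stack((gx, gy), dim=-1) * 2 - 1    # (sH, sW, 2), normalized to [-1, 1]
    grid = grid.unsqueeze(0).expand(B, -1, -1, -1)  # same fixed grid for every batch item
    return F.grid_sample(x, grid, mode='bilinear', align_corners=False, padding_mode='border')

x = torch.rand(1, 8, 4, 4)
print(fixed_grid_upsample(x).shape)  # torch.Size([1, 8, 8, 8])

Learning small per-point offsets on top of this fixed grid is essentially what DySample's upsampling step does.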
Paper: Learning to Upsample by Learning to Sample
Code: dysample/dysample.py at main · tiny-smart/dysample · GitHub
import torch
import torch.nn as nn
import torch.nn.functional as F


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class DySample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        # 1x1 conv that predicts per-point (x, y) sampling offsets
        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)
        if dyscope:
            # optional branch that learns a per-point scaling factor for the offsets
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        # initial sampling positions: a uniform scale x scale grid inside each input pixel
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        # bilinearly sample x at the (offset-shifted) positions to build the upsampled output
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1  # normalize to [-1, 1] for grid_sample
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        # 'lp' style: predict offsets at the input resolution, pixel-shuffle them up inside sample()
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        # 'pl' style: pixel-shuffle the feature first, predict offsets at the output resolution
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)


if __name__ == '__main__':
    x = torch.rand(2, 64, 4, 7)
    dys = DySample(64)
    print(dys(x).shape)  # torch.Size([2, 64, 8, 14])
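Besides the defaults used in the quick test above, the constructor also exposes the 'pl' style (pixel-shuffle the feature first, then predict offsets) and the optional dynamic-scope branch; both keep the channel count and only change the spatial size:

x = torch.rand(2, 64, 16, 16)
up_lp = DySample(64, scale=2, style='lp', groups=4)                # 'lp': predict offsets at the input resolution
up_pl = DySample(64, scale=2, style='pl', groups=4, dyscope=True)  # 'pl' variant with the learned offset scope
print(up_lp(x).shape)  # torch.Size([2, 64, 32, 32])
print(up_pl(x).shape)  # torch.Size([2, 64, 32, 32])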
Snippet to insert into parse_model in yolo.py:
elif m in (DySample,):
    args.insert(0, ch[f])
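For orientation, this elif sits next to the other per-module branches inside parse_model in models/yolo.py. The surrounding structure looks roughly like the sketch below (paraphrased from the stock YOLOv5/YOLOv9 parser; the exact module sets and extra bookkeeping differ between forks):

# inside parse_model(), where each layer's channel arguments are resolved
if m in (Conv, ADown, RepNCSPELAN4, SPPELAN):   # conv-like modules: (c1, c2, ...) arguments
    c1, c2 = ch[f], args[0]
    args = [c1, c2, *args[1:]]
elif m is Concat:
    c2 = sum(ch[x] for x in f)
elif m in (DySample,):
    # DySample preserves the channel count; it only needs in_channels,
    # which the yaml leaves empty ([]) and which is prepended here from the previous layer
    args.insert(0, ch[f])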
YAML configuration file:
# YOLOv9
# Powered by https://blog.csdn.net/StopAndGoyyy
# parameters
nc: 11  # number of classes
#depth_multiple: 0.33  # model depth multiple
depth_multiple: 1  # model depth multiple
#width_multiple: 0.25  # layer channel multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],

   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, ADown, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   #[-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [-1, 1, DySample, []],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, DySample, []],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)


   # multi-level reversible auxiliary branch

   # routing
   [5, 1, CBLinear, [[256]]],  # 23
   [7, 1, CBLinear, [[256, 512]]],  # 24
   [9, 1, CBLinear, [[256, 512, 512]]],  # 25

   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]],  # 30

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]],  # 33

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]],  # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37


   # detection head

   # detect
   [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]
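With the parse_model branch above, each [-1, 1, DySample, []] entry is built as DySample(ch[f]) with the default arguments (scale=2, style='lp', groups=4, dyscope=False), so it drops in one-for-one for the commented-out [-1, 1, nn.Upsample, [None, 2, 'nearest']] line and the layer indices in the head stay unchanged. If you want to try the other variants, the remaining constructor arguments can be listed in the yaml; a hypothetical example (relying on ch[f] being prepended as in_channels, as shown earlier) would be:

   [-1, 1, DySample, [2, 'lp', 4, True]]  # scale, style, groups, dyscope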