赞
踩
动态蛇形卷积(DSCONV)是一种深度学习中的卷积神经网络(CNN)技术,它旨在处理卷积操作中的某些限制,以提高网络处理不规则数据的能力。
适用场景:血管、道路等复杂网状管状结构
传统卷积的限制:
标准的CNN卷积层通常在处理规则、网格状数据(如图像)时效果最佳。然而,在面对不规则或扭曲的数据结构时,它们可能不那么有效。
不规则数据的挑战:
在实际应用中,如医学图像、三维形状处理或自然语言处理等领域,数据常常呈现不规则形态,这给传统的CNN带来了挑战。
动态调整:
DSCONV通过动态调整卷积核的形状和大小来适应数据的不规则性,这种调整是根据数据的特定特征和结构进行的。
灵活性:
这种方法比传统的卷积操作更加灵活,能更好地捕捉不规则数据中的局部特征。
1. YOLOv8 的主体代码来自官方仓库。
2. 在 ultralytics-main\ultralytics\nn\modules 目录下,新增 DySnakeConv.py 文件。
DySnakeConv.py
import torch
import torch.nn as nn
def autopad(k, p=None):
    """Return explicit padding *p* if given, otherwise 'same' padding for kernel size *k*."""
    if p is not None:
        return p
    # 'same' padding: half the kernel size, per-dimension when k is a sequence
    return k // 2 if isinstance(k, int) else [v // 2 for v in k]
class Conv(nn.Module):
    """Standard convolution block: Conv2d -> BatchNorm2d -> activation."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        """c1/c2: in/out channels; k: kernel; s: stride; p: padding (None = 'same'); g: groups."""
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # act=True -> SiLU, an nn.Module -> use as given, anything else -> no-op
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Conv -> BN -> activation."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Forward without BN (used after conv+BN weights have been fused)."""
        return self.act(self.conv(x))
class DySnakeConv(nn.Module):
    """Dynamic Snake Convolution head: fuses a standard conv with x- and y-morph DSConv branches."""

    def __init__(self, inc, ouc, k=3, act=True) -> None:
        """inc/ouc: in/out channels; k: kernel size of every branch."""
        super().__init__()
        self.conv_0 = Conv(inc, ouc, k, act=act)        # standard convolution branch
        self.conv_x = DSConv(inc, ouc, 0, k)            # snake conv along the x-axis (morph 0)
        self.conv_y = DSConv(inc, ouc, 1, k)            # snake conv along the y-axis (morph 1)
        self.conv_1x1 = Conv(ouc * 3, ouc, 1, act=act)  # 1x1 conv fusing the three branches

    def forward(self, x):
        """Run the three branches and fuse their channel-concatenated outputs."""
        branches = [self.conv_0(x), self.conv_x(x), self.conv_y(x)]
        return self.conv_1x1(torch.cat(branches, dim=1))
class DSConv(nn.Module):
    """Dynamic Snake Convolution for a single morphology.

    Learns per-kernel-point offsets, deforms the sampling grid via DSC,
    then applies a strided (k,1) or (1,k) convolution over the deformed map.
    """

    def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1, act=True):
        """
        :param in_ch: input channels
        :param out_ch: output channels
        :param morph: kernel morphology — 0: along the x-axis, 1: along the y-axis (see the paper)
        :param kernel_size: number of sampling points of the snake kernel
        :param if_offset: if False no deformation is applied (standard convolution kernel)
        :param extend_scope: the range the offsets may expand to (default 1 for this method)
        :param act: True -> SiLU, an nn.Module -> used as given, otherwise Identity
        """
        super(DSConv, self).__init__()
        # offset_conv learns one y-offset and one x-offset map per kernel point
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size
        # two variants of the DSConv (along the x-axis and the y-axis)
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )
        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset

    def forward(self, f):
        """Deform the input with learned offsets and convolve along the chosen morphology."""
        # offsets squashed into [-1, 1]; scaled by extend_scope inside DSC
        offset = torch.tanh(self.bn(self.offset_conv(f)))
        dsc = DSC(f.shape, self.kernel_size, self.extend_scope, self.morph)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
        # single conv selected by morphology (was two duplicated if/else branches)
        conv = self.dsc_conv_x if self.morph == 0 else self.dsc_conv_y
        return self.act(self.gn(conv(deformed_feature.type(f.dtype))))
class DSC(object):
    """Coordinate-map builder + bilinear sampler for dynamic snake convolution.

    Given learned offsets, builds the deformed sampling coordinates of a 1-D
    "snake" kernel and bilinearly interpolates the input feature map at them.
    NOTE(review): the code names input dim 2 "width" and dim 3 "height"
    (input_shape is (N, C, W, H) in this convention) — kept as-is.
    """

    def __init__(self, input_shape, kernel_size, extend_scope, morph):
        self.num_points = kernel_size        # number of kernel sampling points
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph                   # 0: kernel along x-axis, 1: along y-axis
        self.extend_scope = extend_scope     # scale applied to the tanh-squashed offsets
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    def _coordinate_map_3D(self, offset, if_offset):
        """Build the (y, x) sampling coordinate maps from the offset tensor.

        offset: (N, 2K, W, H) — first K channels are y-offsets, last K x-offsets.
        Returns y_new, x_new of shape (N, K*W, H) for morph 0 or (N, W, K*H) for morph 1.
        """
        device = offset.device
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

        # integer pixel-centre grids: y_center[i, j] = i, x_center[i, j] = j
        y_center = torch.arange(0, self.width).repeat([self.height])
        y_center = y_center.reshape(self.height, self.width)
        y_center = y_center.permute(1, 0)
        y_center = y_center.reshape([-1, self.width, self.height])
        y_center = y_center.repeat([self.num_points, 1, 1]).float()
        y_center = y_center.unsqueeze(0)

        x_center = torch.arange(0, self.height).repeat([self.width])
        x_center = x_center.reshape(self.width, self.height)
        x_center = x_center.permute(0, 1)
        x_center = x_center.reshape([-1, self.width, self.height])
        x_center = x_center.repeat([self.num_points, 1, 1]).float()
        x_center = x_center.unsqueeze(0)

        if self.morph == 0:
            # Kernel laid out along x: y spread is 0, x spreads -K//2 .. K//2.
            y = torch.linspace(0, 0, 1)
            x = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)  # [1, K, W, H]
            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)  # [1, K, W, H]

            y_new = (y_center + y_grid).repeat(self.num_batch, 1, 1, 1).to(device)
            x_new = (x_center + x_grid).repeat(self.num_batch, 1, 1, 1).to(device)

            y_offset_new = y_offset.detach().clone()
            if if_offset:
                y_offset = y_offset.permute(1, 0, 2, 3)
                y_offset_new = y_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                # Accumulate offsets outward from the kernel centre so the
                # sampled points form a continuous "snake".
                # NOTE(review): range(1, center) leaves the outermost kernel
                # points without accumulated offsets — matches the reference
                # implementation; confirm if full coverage is intended.
                y_offset_new[center] = 0
                for index in range(1, center):
                    y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
                    y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
                y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
                y_new = y_new.add(y_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            return y_new, x_new
        else:
            # Kernel laid out along y: y spreads -K//2 .. K//2, x spread is 0.
            y = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            x = torch.linspace(0, 0, 1)
            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)
            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)

            y_new = (y_center + y_grid).repeat(self.num_batch, 1, 1, 1).to(device)
            x_new = (x_center + x_grid).repeat(self.num_batch, 1, 1, 1).to(device)

            x_offset_new = x_offset.detach().clone()
            if if_offset:
                x_offset = x_offset.permute(1, 0, 2, 3)
                x_offset_new = x_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                x_offset_new[center] = 0
                for index in range(1, center):
                    # BUGFIX: accumulate x_offset here (the pasted code used
                    # y_offset, contradicting the reference implementation)
                    x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
                    x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
                x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
                # BUGFIX: add the offsets to x_new (the pasted code did
                # `x_new = y_new.add(...)`, clobbering the x-coordinates)
                x_new = x_new.add(x_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            return y_new, x_new

    def _bilinear_interpolate_3D(self, input_feature, y, x):
        """Bilinearly sample input_feature (N, C, W, H) at coordinates (y, x).

        Returns the deformed feature map of shape (N, C, K*W, H) for morph 0
        or (N, C, W, K*H) for morph 1.
        """
        device = input_feature.device
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()
        zero = torch.zeros([]).int()
        max_y = self.width - 1
        max_x = self.height - 1

        # 4 corner locations of each sampling point, clamped into the map
        # (the pasted code computed this identical block twice; the duplicate
        # dead computation has been removed)
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)

        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)

        dimension = self.height * self.width
        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()
        repeat = torch.ones([self.num_points * self.width * self.height
                             ]).unsqueeze(0)
        repeat = repeat.float()
        base = torch.matmul(base, repeat)
        base = base.reshape([-1])
        base = base.to(device)

        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height
        # NOTE(review): `- base` cancels the per-batch offset, so for
        # num_batch > 1 every batch samples from batch 0's features. Preserved
        # from the reference code, but this looks suspicious — confirm upstream.
        # top rectangle of the neighbourhood volume
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1
        # bottom rectangle of the neighbourhood volume
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1

        # gather the 4 corner values
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)

        # recompute the corner coordinates with a one-larger clamp for the
        # interpolation weights (matches the reference implementation)
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1
        y0 = torch.clamp(y0, zero, max_y + 1)
        y1 = torch.clamp(y1, zero, max_y + 1)
        x0 = torch.clamp(x0, zero, max_x + 1)
        x1 = torch.clamp(x1, zero, max_x + 1)

        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()
        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
        outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        """Return the feature map bilinearly sampled at the snake coordinates."""
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature
if __name__ == '__main__':
    # Smoke test: DySnakeConv keeps spatial size and maps 128 -> 256 channels.
    # (renamed the local from `input`, which shadowed the builtin)
    x = torch.randn(1, 128, 8, 8)
    dsconv = DySnakeConv(128, 256)
    out = dsconv(x)
    print(out.shape)  # expected: torch.Size([1, 256, 8, 8])
3. 修改 block.py 配置
(a) 新增 C2f_DySnakeConv 类
class C2f_DySnakeConv(nn.Module):
    """Faster CSP Bottleneck with 2 convolutions, using DySnakeConv bottlenecks."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize with ch_in c1, ch_out c2, n repeats, shortcut flag, groups g and expansion e."""
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(
            Bottleneck_DySnakeConv(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0)
            for _ in range(n)
        )

    def forward(self, x):
        """Forward pass through the C2f layer."""
        parts = list(self.cv1(x).chunk(2, 1))
        for bottleneck in self.m:
            parts.append(bottleneck(parts[-1]))
        return self.cv2(torch.cat(parts, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        parts = list(self.cv1(x).split((self.c, self.c), 1))
        for bottleneck in self.m:
            parts.append(bottleneck(parts[-1]))
        return self.cv2(torch.cat(parts, 1))
(b) 新增 Bottleneck_DySnakeConv 类
class Bottleneck_DySnakeConv(nn.Module):
    """Standard bottleneck with both convolutions replaced by DySnakeConv."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """c1/c2: in/out channels; e: hidden-channel expansion. g and k are accepted for signature compatibility."""
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = DySnakeConv(c1, hidden)
        self.cv2 = DySnakeConv(hidden, c2)
        # residual connection only when requested and the shapes allow it
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply the two snake convs, with a residual add when enabled."""
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out
(c) import 自定义模块 DySnakeConv
from ultralytics.nn.modules.DySnakeConv import DySnakeConv
4. 导入到 tasks.py 文件
(a) 修改 \ultralytics-main\ultralytics\nn\tasks.py 文件
from ultralytics.nn.modules.block import Bottleneck_DySnakeConv,C2f_DySnakeConv
(b) 修改 parse_model 函数
def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): parsed model.yaml dictionary.
        ch (int): number of input channels.
        verbose (bool): if True, print a per-layer summary table.

    Returns:
        (nn.Sequential, list): the assembled model and the sorted savelist of
        layer indices whose outputs are consumed by later layers.

    Extended with the dynamic snake convolution modules C2f_DySnakeConv and
    Bottleneck_DySnakeConv.
    """
    import ast

    # Args
    max_channels = float('inf')
    nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
    if scales:
        scale = d.get('scale')
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        # dynamic snake conv: C2f_DySnakeConv and Bottleneck_DySnakeConv registered below
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, Bottleneck_DySnakeConv, GhostBottleneck, SPP,
                 SPPF, DWConv, Focus, BottleneckCSP, C1, C2, C2f, C2f_DySnakeConv, C3, C3TR, C3Ghost,
                 nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            # BUGFIX: C2f_DySnakeConv added here so the repeat count n is passed into the
            # block (as for C2f) instead of stacking n whole blocks in an nn.Sequential
            if m in (BottleneckCSP, C1, C2, C2f, C2f_DySnakeConv, C3, C3TR, C3Ghost, C3x, RepC3):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in (HGStem, HGBlock):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (Detect, Segment, Pose):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
5. 在 yaml 文件中用 C2f_DySnakeConv 替换 C2f。
优点:提供更好的灵活性和适应性,在处理复杂、不规则数据时可能比标准CNN更有效。
挑战:实现复杂,计算成本可能更高,需要更多的调优和实验来优化其性能。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。