class ACmix(nn.Module):
    """Layer that mixes a windowed self-attention branch with a convolution branch.

    Three 1x1 convolutions produce q, k and v feature maps that are shared by
    both branches.  The attention branch attends over a ``kernel_att`` x
    ``kernel_att`` window around every position; the convolution branch
    re-projects the same q/k/v maps with ``fc`` and aggregates them with a
    grouped ``kernel_conv`` x ``kernel_conv`` convolution.  The two outputs are
    combined with the learnable scalars ``rate1`` and ``rate2``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels; split into ``head`` groups of
            ``out_planes // head`` channels each.
        kernel_att: Side length of the attention window.
        head: Number of attention heads.
        kernel_conv: Side length of the convolution-branch kernel.
        stride: Spatial stride applied by both branches.
        dilation: Dilation of the attention window.

    Relies on helpers defined elsewhere in the file: ``init_rate_half``,
    ``init_rate_0``, ``position`` and ``stride``.
    """

    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation

        # Mixing weights of the two branches.  Allocated uninitialised here and
        # handed to init_rate_half() in reset_parameters().
        self.rate1 = torch.nn.Parameter(torch.Tensor(1))
        self.rate2 = torch.nn.Parameter(torch.Tensor(1))
        self.head_dim = self.out_planes // self.head

        # 1x1 projections producing q, k and v.
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        # Projects the 2-channel output of position() to head_dim channels.
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        # Padding that keeps the (dilated) attention window centred on each pixel.
        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
        # The unfold must use the same dilation the padding was computed for;
        # otherwise the number of extracted windows does not match
        # (h_out, w_out) whenever dilation > 1.
        self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride,
                                dilation=self.dilation)
        self.softmax = torch.nn.Softmax(dim=1)

        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
        # padding = kernel_conv // 2 equals the former hard-coded 1 for the
        # default 3x3 kernel and keeps the output aligned with the attention
        # branch for other odd kernel sizes.
        self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,
                                  kernel_size=self.kernel_conv, bias=True, groups=self.head_dim,
                                  padding=self.kernel_conv // 2, stride=stride)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the mixing rates and the convolution-branch kernel."""
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        # Build kernel_conv**2 one-hot kernels: kernel i has a single 1 at
        # (row, col) = (i // kernel_conv, i % kernel_conv).
        kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
        for i in range(self.kernel_conv * self.kernel_conv):
            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
        kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
        # NOTE(review): the bias is replaced by whatever init_rate_0() returns.
        # If that helper fills in place and returns None, this removes the bias
        # from dep_conv -- confirm that is intended.
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def forward(self, x):
        """Return ``rate1 * attention_output + rate2 * convolution_output`` for ``x``."""
        # 1x1 convolutions give q, k, v; these maps are shared by both branches.
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        # Attention scaling factor: head_dim ** -0.5.
        scaling = float(self.head_dim) ** -0.5
        b, _, h, w = q.shape
        # NOTE(review): assumes h and w are divisible by stride -- confirm.
        h_out, w_out = h // self.stride, w // self.stride

        # ---- attention branch ----
        # Positional encoding (position() is an external helper).
        pe = self.conv_p(position(h, w, x.is_cuda))

        q_att = q.view(b * self.head, self.head_dim, h, w) * scaling
        k_att = k.view(b * self.head, self.head_dim, h, w)
        v_att = v.view(b * self.head, self.head_dim, h, w)

        if self.stride > 1:
            # stride() here is the module-level helper, not the constructor
            # argument; presumably it subsamples the map spatially -- confirm.
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        # Unfold the keys into per-position windows.
        unfold_k = self.unfold(self.pad_att(k_att)).view(
            b * self.head, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out
        )  # b*head, head_dim, k_att^2, h_out, w_out
        unfold_rpe = self.unfold(self.pad_att(pe)).view(
            1, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out
        )  # 1, head_dim, k_att^2, h_out, w_out

        # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out)
        #   -> summed over head_dim -> (b*head, k_att^2, h_out, w_out)
        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).view(
            b * self.head, self.head_dim, self.kernel_att * self.kernel_att, h_out, w_out
        )
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)

        # ---- convolution branch: reuse q, k, v for convolutional feature extraction ----
        f_all = self.fc(torch.cat(
            [q.view(b, self.head, self.head_dim, h * w),
             k.view(b, self.head, self.head_dim, h * w),
             v.view(b, self.head, self.head_dim, h * w)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
        out_conv = self.dep_conv(f_conv)

        # Fuse the two branches.
        return self.rate1 * out_att + self.rate2 * out_conv
class CB2d(nn.Module):
    """Context block: pool one context vector per channel and fuse it back into the input.

    The input is reduced to a ``(batch, channels, 1, 1)`` context either by a
    learned softmax-weighted sum over all positions (``pool='att'``) or by
    global average pooling (``pool='avg'``).  The context is then passed
    through bottleneck transforms and applied to the input as a sigmoid-gated
    per-channel scale (``'channel_mul'``) and/or a per-channel offset
    (``'channel_add'``).

    Args:
        inplanes: Number of input (and output) channels.
        pool: ``'att'`` or ``'avg'``.
        fusions: Non-empty collection drawn from ``'channel_add'`` and
            ``'channel_mul'``.

    Relies on ``kaiming_init`` and ``last_zero_init`` defined elsewhere in the file.
    """

    def __init__(self, inplanes, pool='att', fusions=('channel_add', 'channel_mul')):
        super().__init__()
        assert pool in ['avg', 'att']
        assert all(f in ['channel_add', 'channel_mul'] for f in fusions)
        assert len(fusions) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.planes = inplanes // 4  # bottleneck width of the channel transforms
        self.pool = pool
        self.fusions = fusions

        if pool == 'att':
            # One attention logit per spatial position, normalised over H*W.
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.channel_add_conv = self._make_channel_transform() if 'channel_add' in fusions else None
        self.channel_mul_conv = self._make_channel_transform() if 'channel_mul' in fusions else None

        self.reset_parameters()

    def _make_channel_transform(self):
        """Build the bottleneck (1x1 conv -> LayerNorm -> ReLU -> 1x1 conv) applied to the context."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1),
        )

    def reset_parameters(self):
        """Initialise the attention mask conv and the channel transforms via the external helpers."""
        if self.pool == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True
        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        """Reduce ``x`` of shape (batch, channel, H, W) to a (batch, channel, 1, 1) context."""
        batch, channel, height, width = x.size()
        if self.pool == 'att':
            # Flatten the spatial dimensions: (batch, 1, channel, H*W).
            input_x = x.view(batch, channel, height * width).unsqueeze(1)
            # Per-position weights: 1x1 conv followed by a softmax over H*W,
            # reshaped to (batch, 1, H*W, 1).
            context_mask = self.conv_mask(x).view(batch, 1, height * width)
            context_mask = self.softmax(context_mask).unsqueeze(3)
            # Weighted sum of the features over all positions.
            context = torch.matmul(input_x, context_mask)
            context = context.view(batch, channel, 1, 1)
        else:
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        """Apply the enabled channel-wise scale and/or offset derived from the pooled context."""
        context = self.spatial_pool(x)
        # Learn a per-channel weight and a per-channel bias from the context.
        if self.channel_mul_conv is not None:
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = x * channel_mul_term
        else:
            out = x
        if self.channel_add_conv is not None:
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out
class CA(nn.Module):
    """Coordinate attention ("Coordinate Attention for Efficient Mobile Network Design").

    Instead of squeezing the feature map to a single vector with 2D global
    pooling, the input is average-pooled separately along the width and along
    the height.  The two 1D encodings are concatenated, passed through a
    shared 1x1 conv + batch norm + activation, split again, and turned into
    two sigmoid attention maps (one per spatial axis) that are multiplied onto
    the input.

    Args:
        inp: Number of input channels.
        oup: Number of channels of the attention maps; they are multiplied
            with the input, so this must broadcast against the input channels.
        reduction: Channel reduction ratio of the shared bottleneck; the
            bottleneck width is ``max(8, inp // reduction)``.

    Relies on ``h_swish`` defined elsewhere in the file.
    """

    def __init__(self, inp, oup, reduction=32):
        super().__init__()
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` re-weighted by the height-wise and width-wise attention maps."""
        identity = x
        _, _, h, w = x.size()

        # Average-pool along W and along H separately.  The functional form is
        # the same op as nn.AdaptiveAvgPool2d without building a module per call.
        x_h = nn.functional.adaptive_avg_pool2d(x, (h, 1))                      # (n, c, h, 1)
        x_w = nn.functional.adaptive_avg_pool2d(x, (1, w)).permute(0, 1, 3, 2)  # (n, c, w, 1)

        # Concatenate both encodings along dim 2 and run the shared transform
        # so the two directions are processed jointly.
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        # Split back into the height part and the width part.
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        # Sigmoid gives the attention weights for each axis.
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        # Apply both axis-wise attention maps to the original features.
        return identity * a_w * a_h
class space_to_depth(nn.Module):
    """Rearrange 2x2 spatial neighbourhoods into separate slices and concatenate them.

    The last two dimensions are sampled at the four (row, column) parity
    offsets -- (even, even), (odd, even), (even, odd), (odd, odd) -- and the
    four half-resolution tensors are concatenated along ``dimension``.  With
    the default ``dimension=1`` an input of shape (N, C, H, W) becomes
    (N, 4*C, H/2, W/2).

    The four slices only have matching shapes when the last two dimensions are
    even; otherwise ``torch.cat`` raises.

    Args:
        dimension: Axis along which the four slices are concatenated.
    """

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """Return the four parity-offset sub-samplings of ``x`` concatenated along ``self.d``."""
        # Use the configured axis rather than a hard-coded 1 so the constructor
        # argument takes effect; the default keeps the channel axis.
        return torch.cat(
            [x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]],
            self.d,
        )
导入:from torchvision.ops import DeformConv2d
self.torchvision_offset = nn.Conv2d(512, 18, 3, padding=1, bias=True)
self.torchvision_dcn2d = DeformConv2d(512, 512, 3, stride=1, padding=1)
offset = self.torchvision_offset(x)
x = self.torchvision_dcn2d(x, offset)
