赞
踩
论文链接:https://arxiv.org/pdf/2311.17791.pdf
代码链接:https://github.com/yaoppeng/U-Net_v2/blob/master/unet_v2/UNet_v2.py
def forward(self, x): seg_outs = [] f1, f2, f3, f4 = self.encoder(x) f1 = self.ca_1(f1) * f1 f1 = self.sa_1(f1) * f1 f1 = self.Translayer_1(f1) f2 = self.ca_2(f2) * f2 f2 = self.sa_2(f2) * f2 f2 = self.Translayer_2(f2) f3 = self.ca_3(f3) * f3 f3 = self.sa_3(f3) * f3 f3 = self.Translayer_3(f3) f4 = self.ca_4(f4) * f4 f4 = self.sa_4(f4) * f4 f4 = self.Translayer_4(f4) f41 = self.sdi_4([f1, f2, f3, f4], f4) f31 = self.sdi_3([f1, f2, f3, f4], f3) f21 = self.sdi_2([f1, f2, f3, f4], f2) f11 = self.sdi_1([f1, f2, f3, f4], f1) class SDI(nn.Module): def __init__(self, channel): super().__init__() self.convs = nn.ModuleList( [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)]) def forward(self, xs, anchor): ans = torch.ones_like(anchor) target_size = anchor.shape[-1] for i, x in enumerate(xs): if x.shape[-1] > target_size: x = F.adaptive_avg_pool2d(x, (target_size, target_size)) elif x.shape[-1] < target_size: x = F.interpolate(x, size=(target_size, target_size), mode='bilinear', align_corners=True) ans = ans * self.convs[i](x) return ans
过去的UNet在上采样的过程中每次通过拼接的方式复用一个stage的特征
这里则是每个stage都会通过哈达玛积的方式复用编码器中所有stage的特征
在复用前会对编码器每个stage 串联通道、空间注意力做增强
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。