Advanced U-Net Tutorial: Adding ResNet's Residual Connections to U-Net
In this tutorial we explore how to incorporate ResNet-style residual connections into the classic U-Net architecture. The resulting hybrid model, which we call ResUnet, combines U-Net's strengths in image segmentation with ResNet's residual connections, aiming to improve training efficiency and performance through residual learning.
A residual connection is a structure that lets data flow directly from lower layers of a network to higher layers. It helps alleviate the vanishing-gradient problem when training deep neural networks, enabling the network to learn more complex functions.
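As a warm-up, here is a minimal sketch of a single residual block with a 1x1 projection shortcut, the same idea the full network below repeats at every encoder stage. The class and variable names here are illustrative, not taken from the original post.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Minimal residual block with a 1x1 projection shortcut.

    When the channel count changes, the input cannot be added to the
    output directly, so a 1x1 convolution projects it to the output
    dimension first.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )
        # 1x1 conv matches in_ch -> out_ch so the shortcut can be added
        self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)

block = ResidualBlock(1, 64)
y = block(torch.randn(1, 1, 224, 224))
```

The 1x1 convolution changes only the channel count, never the spatial size, so the element-wise addition is always well defined.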
In ResUnet, we add a residual connection at each encoder (down-sampling) stage and at the bottleneck; dimension matching is done with a 1x1 convolution. The Python implementation and an explanation follow.
import torch
import torch.nn as nn

# Standard U-Net double-convolution block (assumed from the base U-Net
# tutorial): two 3x3 convs with BatchNorm and ReLU; spatial size unchanged.
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class ResUnet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ResUnet, self).__init__()
        # Down-sampling (encoder)
        # For a 224x224x1 input:
        # H_out = ((224 - 3 + 2*1) / 1) + 1 = 224 -- U-Net's 3x3 convs with
        # padding=1 do not change the feature-map size
        self.conv1 = DoubleConv(in_ch, 64)
        # 1x1 conv to raise the channel dimension for the residual addition
        self.w1 = nn.Conv2d(in_ch, 64, kernel_size=1, padding=0, stride=1)
        self.pool1 = nn.MaxPool2d(2)  # 224 -> 112
        self.conv2 = DoubleConv(64, 128)  # size unchanged
        self.w2 = nn.Conv2d(64, 128, kernel_size=1, padding=0, stride=1)
        self.pool2 = nn.MaxPool2d(2)  # 112 -> 56
        self.conv3 = DoubleConv(128, 256)
        self.w3 = nn.Conv2d(128, 256, kernel_size=1, padding=0, stride=1)
        self.pool3 = nn.MaxPool2d(2)  # 56 -> 28
        self.conv4 = DoubleConv(256, 512)
        self.w4 = nn.Conv2d(256, 512, kernel_size=1, padding=0, stride=1)
        self.pool4 = nn.MaxPool2d(2)  # 28 -> 14
        self.conv5 = DoubleConv(512, 1024)
        self.w5 = nn.Conv2d(512, 1024, kernel_size=1, padding=0, stride=1)
        # Up-sampling (decoder)
        # H_out = (14 - 1) * 2 + 2 = 28 -- each transposed conv doubles the size
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)  # 28
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)
        # Dropout deactivates neurons at random to improve generalization.
        # It is active only in training mode; PyTorch rescales the surviving
        # activations automatically. Note it is defined here but not applied
        # in the forward pass below.
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        # Encoder: each DoubleConv output gets a residual shortcut
        down0_res = self.w1(x)  # residual branch (1x1 conv)
        down0 = self.conv1(x) + down0_res
        down1 = self.pool1(down0)
        down1_res = self.w2(down1)  # residual branch
        down1 = self.conv2(down1) + down1_res
        down2 = self.pool2(down1)
        down2_res = self.w3(down2)
        down2 = self.conv3(down2) + down2_res
        down3 = self.pool3(down2)
        down3_res = self.w4(down3)
        down3 = self.conv4(down3) + down3_res
        down4 = self.pool4(down3)
        down4_res = self.w5(down4)
        # Bottleneck: double conv before the decoder, [14, 14, 1024]
        down5 = self.conv5(down4) + down4_res
        # Decoder
        up_6 = self.up6(down5)  # [28, 28, 512]
        merge6 = torch.cat([up_6, down3], dim=1)  # concat -> [28, 28, 1024]
        c6 = self.conv6(merge6)  # double conv back down to [28, 28, 512]
        up_7 = self.up7(c6)  # [56, 56, 256]
        merge7 = torch.cat([up_7, down2], dim=1)
        c7 = self.conv7(merge7)  # [56, 56, 256]
        up_8 = self.up8(c7)  # [112, 112, 128]
        merge8 = torch.cat([up_8, down1], dim=1)
        c8 = self.conv8(merge8)  # [112, 112, 128]
        up_9 = self.up9(c8)  # [224, 224, 64]
        merge9 = torch.cat([up_9, down0], dim=1)
        c9 = self.conv9(merge9)  # [224, 224, 64]
        c10 = self.conv10(c9)  # final 1x1 conv -> [224, 224, out_ch]
        return c10
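Since the dropout layer above is defined but easy to misuse, the following small sketch illustrates its train/eval behavior: PyTorch uses "inverted dropout", scaling surviving activations by 1/(1-p) during training so that no rescaling is needed at inference. The variable names here are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.3)
x = torch.ones(1, 8)

drop.train()            # training mode: dropout is active
y_train = drop(x)       # dropped entries are 0, survivors scaled by 1/0.7

drop.eval()             # eval mode: dropout is the identity
y_eval = drop(x)
```

This is why `model.train()` and `model.eval()` must be called appropriately when a network such as ResUnet contains dropout layers.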
In this example, w1, w2, w3, w4, and w5 are 1x1 convolutions set up to match dimensions: they let us add the original input (or the down-sampled features) to the feature map, which is what realizes the residual connection. A residual connection of this kind is added after each encoder stage of the U-Net.
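Note that two different kinds of shortcut appear in the network: residual connections *add* feature maps (channel count unchanged), while U-Net's skip connections *concatenate* them along the channel dimension (channel count doubles, which is why each decoder DoubleConv halves the channels again). A quick sketch with illustrative tensors:

```python
import torch

up = torch.randn(1, 512, 28, 28)    # decoder feature after an up-conv
skip = torch.randn(1, 512, 28, 28)  # matching encoder feature

added = up + skip                      # residual-style: still 512 channels
merged = torch.cat([up, skip], dim=1)  # U-Net skip: 1024 channels
```

The addition requires identical shapes (hence the 1x1 projections), whereas the concatenation only requires matching spatial sizes.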
ResUnet is a powerful architecture that combines U-Net's strengths with the capabilities of ResNet. Although the extra 1x1 convolutions add some computational cost, in many cases that cost is worthwhile because residual learning can significantly improve model performance.
We hope this tutorial helps you understand how to add residual connections to U-Net, and encourages you to try the approach in your own projects.