当前位置:   article > 正文

【干货教学】unet进阶,如何在unet中加入resnet(残差连接)_resunet

resunet


U-Net进阶教程:如何在U-Net中加入ResNet的残差连接

在本教程中,我们将探讨如何在经典的U-Net架构中融入ResNet的残差连接。这种结合了U-Net在图像分割领域的优势和ResNet的残差连接的混合模型,我们称之为ResUnet,旨在通过残差学习改善模型的训练效率和性能。

1.什么是残差连接

残差连接是一种允许数据直接从网络的较低层传递到较高层的结构。这种方式可以帮助解决深度神经网络训练过程中的梯度消失问题,使得网络能够学习到更加复杂的功能。

2.ResUnet架构

2.1 代码实现

在Res1Unet中,我们在每个下采样(编码)和上采样(解码)步骤中都加入了残差连接,本质上是通过一个核为1的卷积操作来实现维度匹配。以下是Python中的实现代码和相应的解释。

class ResUnet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Res1Unet, self).__init__()

        # down sampling
        # 假如输入 224*224*1 的图像
        # H = ((224 - 3 + 1 + 2 - 1) / 1) + 1 = 224  unet的卷积不会改变特征图的大小
        self.conv1 = DoubleConv(in_ch, 64)
        # to increase the dimensions
        self.w1 = nn.Conv2d(in_ch, 64, kernel_size=1, padding=0, stride=1)
        self.pool1 = nn.MaxPool2d(2)  # 224 -> 112

        self.conv2 = DoubleConv(64, 128)  # 不变
        # to increase the dimensions
        self.w2 = nn.Conv2d(64, 128, kernel_size=1, padding=0, stride=1)
        self.pool2 = nn.MaxPool2d(2)  # 56

        self.conv3 = DoubleConv(128, 256)
        # to increase the dimensions
        self.w3 = nn.Conv2d(128, 256, kernel_size=1, padding=0, stride=1)
        self.pool3 = nn.MaxPool2d(2)  # 28

        self.conv4 = DoubleConv(256, 512)
        # to increase the dimensions
        self.w4 = nn.Conv2d(256, 512, kernel_size=1, padding=0, stride=1)
        self.pool4 = nn.MaxPool2d(2)  # 14

        self.conv5 = DoubleConv(512, 1024)
        # to increase the dimensions
        self.w5 = nn.Conv2d(512, 1024, kernel_size=1, padding=0, stride=1)

        # up sampling
        # H_out = (14 - 1) * 2 + 2 = 28 往上反卷积
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)   # 28

        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)

        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)

        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)

        self.conv10 = nn.Conv2d(64, out_ch, 1)

        # 训练时尝试让神经元失活,加大泛化性,仅在训练时使用,pytorch自动补偿参数
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        # 下采样部分
        down0_res = self.w1(x)  # residual block,残差连接
        down0 = self.conv1(x) + down0_res
        down1 = self.pool1(down0)

        down1_res = self.w2(down1)  # residual block
        down1 = self.conv2(down1) + down1_res
        down2 = self.pool2(down1)

        down2_res = self.w3(down2)
        down2 = self.conv3(down2) + down2_res
        down3 = self.pool3(down2)

        down3_res = self.w4(down3)
        down3 = self.conv4(down3) + down3_res
        down4 = self.pool4(down3)

        down4_res = self.w5(down4)
        # 5 , 连接上采样部分前,双卷积卷积操作    [14, 14, 1024]
        down5 = self.conv5(down4) + down4_res

        # 上采样部分
        up_6 = self.up6(down5)   # [28, 28, 512]
        merge6 = torch.cat([up_6, down3], dim=1)    # cat之后又变为[28, 28, 1024]
        c6 = self.conv6(merge6)   # 重新双卷积变为[28, 28, 512]

        up_7 = self.up7(c6)   # [56, 56, 256]
        merge7 = torch.cat([up_7, down2], dim=1)
        c7 = self.conv7(merge7) # [56, 56, 256]

        up_8 = self.up8(c7)   # [112, 112, 128]
        merge8 = torch.cat([up_8, down1], dim=1)
        c8 = self.conv8(merge8) # [112, 112, 128]

        up_9 = self.up9(c8)   # [224, 224, 64]
        merge9 = torch.cat([up_9, down0], dim=1)
        c9 = self.conv9(merge9)  # [224, 224, 64]

        c10 = self.conv10(c9)  # 卷积输出最终图像   [224, 224, t]

        return c10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92

在这个例子中,w1w2w3w4w5是为了匹配维度而设置的1x1卷积,它们允许我们将原始输入或下采样后的特征添加到特征图中,这样就实现了残差连接。在U-Net的每个编码阶段之后,我们都会加上一个这样的残差连接。

2.2 图示

2.2.1 unet

在这里插入图片描述

2.2.2 加入残差连接改进

在这里插入图片描述

3.为什么要使用ResUnet?

3.1优势

  1. 改善梯度流通:通过加入残差连接,梯度可以直接流经较短的路径,减少训练过程中的梯度消失问题。
  2. 加速收敛:残差连接有助于网络更快地收敛,提高训练效率。
  3. 提高性能:Res1Unet可以更好地捕捉到图像的细节和上下文信息,提高分割的准确性。

3.2缺点

  1. 增加计算负担:虽然残差连接有很多优点,但它们也会稍微增加前向和后向传播时的计算负担。
  2. 可能导致过拟合:在一些小数据集上,过于复杂的模型可能会导致过拟合。

4.结论

Res1Unet是一个强大的网络架构,它结合了U-Net的优秀特性和ResNet的强大能力。虽然这可能会带来一些额外的计算成本,但在许多情况下,这种额外的成本是值得的,因为它可以显著提升模型性能。
希望本教程能够帮助你理解如何在U-Net中加入残差连接,并鼓励你尝试将这种方法应用到你自己的项目中。


往期精彩干货
基于mmdetection3d的单目3D目标检测模型,效果远超CenterNet3D
SSH?Termius?一篇文章教你使用远程服务器训练
Jetson nano开机自启动python程序
【代码实践】focal loss损失函数及其变形原理详细讲解和图像分割实践(含源码)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/626161
推荐阅读
相关标签
  

闽ICP备14008679号