当前位置:   article > 正文

【BraTS】Brain Tumor Segmentation 脑部肿瘤分割2(UNet的复现)_tbrats: trusted brain tumor segmentation

tbrats: trusted brain tumor segmentation

Brats
医学领域采集的图像,一般都是灰白色的,相比于现实空间中的彩色图像,存储的信息简单了很多。一个是1个维度的灰度信息,一个是3个维度组合到一起的彩色信息。

脑部肿瘤分割的这批数据,相较于其他的数据,有一个特别的地方,就是它是多模态的。尽管是4个不同的模态,但是都对应到一个标签。

此时不禁联想到:既然RGB三个通道的彩色图像,可以直接作为输入送进网络,能够学习到色彩的信息。那么,脑肿瘤采集的4个维度的信息,是不是也可以简单的把它拍在一起,构成一个每一层4个channel的图像,一次检查155层呢。

  • 这样输入图像变成: 4 * 155 * 240 * 240,155层,每一层4个channel,每一个channel是240*240大小
  • 输出还是原来的形式:155 * 240 * 240

此时,这个任务就没有了模态的概念了,就当做他是一个彩色多通道的图像来处理。也不要有体积的概念,他就是一个个二维的图像。尽管损失了很多内在联系的信息,但是,简单了很多啊

  • 没有上下层信息
  • 淡化模态信息

简单的,直接用医学领域常用的unet网络模型作为训练的网络。后面我们就着重搭建复现unet网络模型

复现unet

unet从整体结构上进行划分,大体可以分成两个阶段:

  • 下采样的阶段,也就是U的左边
  • 上采样的阶段,也就是U的右边

unet

而下采样阶段,我们根据数据流动的方式,我们又分为5个的横向layer,每一个layer分别是有以下3个层串联组成:

  • 1个红色箭头pool的层,标号A
  • 2个3*3的卷积层,标号B

上采样阶段,也可以看成是4个横向的,具有相似结构的layer,每一个layer分别是有以下3个层串联组成:

  • 1个向上变大的反卷积层(转置卷积),标号C
  • 2个3*3的卷积层,标号B

输出阶段,经过一个1*1的卷积层,将特征映射到输出维度的类别特征上。如果有5个类别,那输出channel就是等于5的,这个也要定义。

除上面两个之外,还有一个跨层连接的残差结构,用于将下采样的数据传递给上采样阶段使用。避免下采样时候损失信息太多,帮助它恢复。

此时,再简单一些:我们不考虑跨层连接的残差结构,假设就是一个完整的串行,定义网络模型简单骨架版,大概是这样的:

class UNet2D(nn.Module):
    def __init__(self, ):
        super(UNet2D, self).__init__()

        self.downLayer1 = B
        self.downLayer2 = nn.Sequential(A,
                                        B)

        self.downLayer3 = nn.Sequential(A,
                                        B)

        self.downLayer4 = nn.Sequential(A,
                                        B)

        self.bottomLayer = nn.Sequential(A,
                                        B)

        self.upLayer1 = nn.Sequential(C,
                                      B)
        self.upLayer2 = nn.Sequential(C,
                                      B)
        self.upLayer3 = nn.Sequential(C,
                                      B)
        self.upLayer4 = nn.Sequential(C,
                                      B)

        self.outLayer = nn.Conv2d()

    def forward(self, x):
        x = self.downLayer1(x)  
        x = self.downLayer2(x) 
        x = self.downLayer3(x)
        x = self.downLayer4(x) 

        x = self.bottomLayer(x) 

        x = self.upLayer1(x)  
        x = self.upLayer2(x)  
        x = self.upLayer3(x) 
        x = self.upLayer4(x) 
        x = self.outLayer(x)  

        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43

有了整体的结构,我们现在定义 B(2个3*3的卷积层)的结构,也就是上图中绿色矩形框的部分。如下:

class ConvBlock2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvBlock2d, self).__init__()

        # 第1个3*3的卷积层
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

        # 第2个3*3的卷积层
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    # 定义数据前向流动形式
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

B的结构比较的简单,这里我就不赘述了,一看便知。卷积之后,接数据归一化,再接激活函数,重复上述过程两次,就是B干的事情。

在卷积神经网络的卷积层之后总会添加BatchNorm2d进行数据的归一化处理,这使得数据在进行Relu之前不会因为数据过大而导致网络性能的不稳定

为啥要单独定义B?因为,单独定义的目的是为了复用,避免出现重复书写的繁琐过程。尤其是卷积后面都喜欢配上Norm和激活函数,这样就能少些一些代码。(简洁、也就是懒)

受此启发,C的反卷积结构,是不是也要单独定义下。如下:

class ConvTrans2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvTrans2d, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),        # 转置卷积、反卷积
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

有了B的基础结构,下采样阶段的A和B都定义清楚了。前文定义上采样的模型时候不考虑跨层连接的残差结构,假设就是一个完整的串行。此时需要看到他从U的左边传过来的数据,那实施上采样阶段的一个简单模块是这样的:

  1. 反卷积
  2. 与跨层输入进行连接
  3. 传入B结构

最终上采样的一个模块,可以写成这样:

class UpBlock2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UpBlock2d, self).__init__()
        self.up_conv = ConvTrans2d(in_ch, out_ch)
        self.conv = ConvBlock2d(2 * out_ch, out_ch)

    def forward(self, x, down_features):
        x = self.up_conv(x)
        x = torch.cat([x, down_features], dim=1)
        x = self.conv(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

至此,缺失的部分,我们都给填补上去了,网络结构中还有一些信息是需要我们罗列下的:

  1. channel的变化是1 -> 64 -> 128 -> 256 -> 512 -> 1024 -> 512 -> 256 -> 125 -> 64 -> 2
  2. 卷积核是3*3
  3. pool是max pool
  4. 激活函数是relu

这样,我们就可以改写前面简单版本定义模型的类了,如下:

class UNet2D(nn.Module):
    def __init__(self, in_ch=4, out_ch=2, degree=64):
        super(UNet2D, self).__init__()

        chs = []
        for i in range(5):
            chs.append((2 ** i) * degree)   # [64, 128, 256, 512, 1024]

        self.downLayer1 = ConvBlock2d(in_ch, chs[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[0], chs[1]))

        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[1], chs[2]))

        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[2], chs[3]))

        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        ConvBlock2d(chs[3], chs[4]))

        self.upLayer1 = UpBlock2d(chs[4], chs[3])
        self.upLayer2 = UpBlock2d(chs[3], chs[2])
        self.upLayer3 = UpBlock2d(chs[2], chs[1])
        self.upLayer4 = UpBlock2d(chs[1], chs[0])

        self.outLayer = nn.Conv2d(chs[0], out_ch, kernel_size=1, stride=1)

        # # Params initialization
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        #         m.weight.data.normal_(0, math.sqrt(2. / n))
        #     elif isinstance(m, nn.BatchNorm2d):
        #         m.weight.data.fill_(1)
        #         m.bias.data.zero_()

    def forward(self, x):
        """
        :param x:   4D Tensor    BatchSize * 4(modal) * W * H
        :return:    4D Tensor    BatchSize * 2        * W * H
        """
        x1 = self.downLayer1(x)     # degree(32)   * 16    * W    * H
        x2 = self.downLayer2(x1)    # degree(64)   * 16/2  * W/2  * H/2
        x3 = self.downLayer3(x2)    # degree(128)  * 16/4  * W/4  * H/4
        x4 = self.downLayer4(x3)    # degree(256)  * 16/8  * W/8  * H/8

        x5 = self.bottomLayer(x4)   # degree(512)  * 16/16 * W/16 * H/16

        x = self.upLayer1(x5, x4)   # degree(256)  * 16/8 * W/8 * H/8
        x = self.upLayer2(x, x3)    # degree(128)  * 16/4 * W/4 * H/4
        x = self.upLayer3(x, x2)    # degree(64)   * 16/2 * W/2 * H/2
        x = self.upLayer4(x, x1)    # degree(32)   * 16   * W   * H
        x = self.outLayer(x)        # out_ch(2 )   * 16   * W   * H
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

到这里,unet结构的复线部分就构建完毕了,不知道你有没有理解上面分拆的一个个结构。可以把构建网络的过程,理解为搭积木

  1. 把长相一样的,给整理到一起
  2. 把有方向信息的,整理到一起
  3. 实在很独特的,就单独定义了插进来

定义好了模型还不算完,分阶段测试下构建的网络是不是和我们所预想的一样。我们给他一个输入,测试下是否与我们最初的想法是一致的,是否报错等等问题,如下:

if __name__ == "__main__":
    net = UNet2D(4, 5, degree=64)
    print(net)
    print("total parameter:" + str(netSize(net)))

    batch_size = 4
    a = torch.randn(batch_size, 4, 240, 240)
    print(a.shape)     # (batch_size, 4, 240, 240)
    b = net(a)
    print(b.shape)      # (batch_size, 5, 240, 240)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

这时,你就可以看看,打印的网络模型,是不是和这张图的结构式完全一样的。改变网络的输入层,输出类别,或者每一层的channel数,看看参数量的变化。如下这样

net = UNet2D(4, 5, degree=64)
  • 1

打印结果:

total parameter:34530437
torch.Size([4, 4, 240, 240])
torch.Size([4, 5, 240, 240])
  • 1
  • 2
  • 3

也可以输入的channel设定为1,如下:

net = UNet2D(1, 5, degree=64)
  • 1

打印结果:

total parameter:34528709
torch.Size([4, 1, 240, 240])
torch.Size([4, 5, 240, 240])
  • 1
  • 2
  • 3

当然,你也可以更改degree的值,如下:

net = UNet2D(4, 5, degree=128)
  • 1

打印结果:

total parameter:138070277
torch.Size([4, 4, 240, 240])
torch.Size([4, 5, 240, 240])
  • 1
  • 2
  • 3

同理,也可以改输出类别,把之前的5类,给改成3类,如下:

net = UNet2D(4, 3, degree=64)
  • 1

输出结果:

total parameter:34529283
torch.Size([4, 4, 240, 240])
torch.Size([4, 3, 240, 240])
  • 1
  • 2
  • 3

上面几处内容的改变,会发现都伴随着模型参数量的改变。有些改变比较大,有些比较小,值得关注下。网络验证通过,就是根据网络的输入,构建一个数据流,传入网络


最后,如果您觉得本篇文章对你有帮助,欢迎点赞

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/97913?site
推荐阅读
相关标签