
[PyTorch] U-Net for Medical Image Segmentation: A Code Walkthrough

U-Net for brain segmentation

A PyTorch implementation of the deep-learning U-Net segmentation architecture, applied to FLAIR abnormality segmentation in brain MRI.
GitHub code: U-Net for brain segmentation
Kaggle code: brain-segmentation-pytorch
Dataset download: Brain MRI segmentation
The dataset is small and the code is clean, so it is fairly easy to reproduce.

The U-Net model

(figure: the U-Net encoder-decoder architecture)

class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

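        # decoder path: upsample, concatenate the matching encoder output (skip connection), then convolve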
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
unet = UNet()
print(unet)
UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu1): ReLU(inplace=True)
    (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu2): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu1): ReLU(inplace=True)
    (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu2): ReLU(inplace=True)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder3): Sequential(
    (enc3conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc3norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc3relu1): ReLU(inplace=True)
    (enc3conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc3norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc3relu2): ReLU(inplace=True)
  )
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder4): Sequential(
    (enc4conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc4norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc4relu1): ReLU(inplace=True)
    (enc4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc4norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc4relu2): ReLU(inplace=True)
  )
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bottleneck): Sequential(
    (bottleneckconv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bottlenecknorm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bottleneckrelu1): ReLU(inplace=True)
    (bottleneckconv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bottlenecknorm2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bottleneckrelu2): ReLU(inplace=True)
  )
  (upconv4): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (decoder4): Sequential(
    (dec4conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec4norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec4relu1): ReLU(inplace=True)
    (dec4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec4norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec4relu2): ReLU(inplace=True)
  )
  (upconv3): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
  (decoder3): Sequential(
    (dec3conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec3norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec3relu1): ReLU(inplace=True)
    (dec3conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec3norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec3relu2): ReLU(inplace=True)
  )
  (upconv2): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  (decoder2): Sequential(
    (dec2conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec2relu1): ReLU(inplace=True)
    (dec2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec2relu2): ReLU(inplace=True)
  )
  (upconv1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
  (decoder1): Sequential(
    (dec1conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec1relu1): ReLU(inplace=True)
    (dec1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec1relu2): ReLU(inplace=True)
  )
  (conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
)
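
As a quick sanity check (a minimal sketch; 3×256×256 matches the model's default in_channels and a 256 image_size), we can push a dummy batch through the network:

x = torch.randn(1, 3, 256, 256)  # one 3-channel slice
y = unet(x)
print(y.shape)  # torch.Size([1, 1, 256, 256]) - a per-pixel probability map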


Data loading

Walking the files

First we traverse the file paths and names with os.walk. On each iteration, dirpath is the folder currently being visited, dirnames is the list of its subdirectories, and filenames is the list of files inside it.

for (dirpath, dirnames, filenames) in os.walk(images_dir):
    for filename in sorted(
        filter(lambda f: ".tif" in f, filenames),  # keep only the .tif files
        key=lambda x: int(x.split(".")[-2].split("_")[4]),  # within each patient folder, sort by the trailing slice number
    ):
        print(filename)
        filepath = os.path.join(dirpath, filename)
        print(filepath)
TCGA_CS_6665_20010817_1.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_1.tif
TCGA_CS_6665_20010817_1_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_1_mask.tif
TCGA_CS_6665_20010817_2_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_2_mask.tif
TCGA_CS_6665_20010817_2.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_2.tif
TCGA_CS_6665_20010817_3.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_3.tif
TCGA_CS_6665_20010817_3_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_3_mask.tif
TCGA_CS_6665_20010817_4_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_4_mask.tif
TCGA_CS_6665_20010817_4.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_4.tif
TCGA_CS_6665_20010817_5_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_5_mask.tif
TCGA_CS_6665_20010817_5.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_5.tif
TCGA_CS_6665_20010817_6.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_6.tif
TCGA_CS_6665_20010817_6_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_6_mask.tif
TCGA_CS_6665_20010817_7_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_7_mask.tif
TCGA_CS_6665_20010817_7.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_7.tif
TCGA_CS_6665_20010817_8.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_8.tif
TCGA_CS_6665_20010817_8_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_8_mask.tif
TCGA_CS_6665_20010817_9.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_9.tif
TCGA_CS_6665_20010817_9_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_9_mask.tif
TCGA_CS_6665_20010817_10_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_10_mask.tif
TCGA_CS_6665_20010817_10.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_10.tif
TCGA_CS_6665_20010817_11_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_11_mask.tif
TCGA_CS_6665_20010817_11.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_11.tif
TCGA_CS_6665_20010817_12.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_12.tif
TCGA_CS_6665_20010817_12_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_12_mask.tif
TCGA_CS_6665_20010817_13_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_13_mask.tif
TCGA_CS_6665_20010817_13.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_13.tif
TCGA_CS_6665_20010817_14_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_14_mask.tif
TCGA_CS_6665_20010817_14.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_14.tif
TCGA_CS_6665_20010817_15_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_15_mask.tif
TCGA_CS_6665_20010817_15.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_15.tif
TCGA_CS_6665_20010817_16.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_16.tif
TCGA_CS_6665_20010817_16_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_16_mask.tif
TCGA_CS_6665_20010817_17_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_17_mask.tif
TCGA_CS_6665_20010817_17.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_17.tif
TCGA_CS_6665_20010817_18_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_18_mask.tif
TCGA_CS_6665_20010817_18.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_18.tif
TCGA_CS_6665_20010817_19.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_19.tif
TCGA_CS_6665_20010817_19_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_19_mask.tif
TCGA_CS_6665_20010817_20_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_20_mask.tif
TCGA_CS_6665_20010817_20.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_20.tif
TCGA_CS_6665_20010817_21.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_21.tif
TCGA_CS_6665_20010817_21_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_21_mask.tif
TCGA_CS_6665_20010817_22.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_22.tif
TCGA_CS_6665_20010817_22_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_22_mask.tif
TCGA_CS_6665_20010817_23_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_23_mask.tif
TCGA_CS_6665_20010817_23.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_23.tif
TCGA_CS_6665_20010817_24_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_24_mask.tif
TCGA_CS_6665_20010817_24.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_24.tif
TCGA_CS_6669_20020102_1.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_1.tif
TCGA_CS_6669_20020102_1_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_1_mask.tif
TCGA_CS_6669_20020102_2_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_2_mask.tif
TCGA_CS_6669_20020102_2.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_2.tif
TCGA_CS_6669_20020102_3.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_3.tif
TCGA_CS_6669_20020102_3_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_3_mask.tif
TCGA_CS_6669_20020102_4.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_4.tif
TCGA_CS_6669_20020102_4_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_4_mask.tif
TCGA_CS_6669_20020102_5.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_5.tif
TCGA_CS_6669_20020102_5_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_5_mask.tif
TCGA_CS_6669_20020102_6_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_6_mask.tif
TCGA_CS_6669_20020102_6.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_6.tif
TCGA_CS_6669_20020102_7_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_7_mask.tif
TCGA_CS_6669_20020102_7.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_7.tif
TCGA_CS_6669_20020102_8.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_8.tif
TCGA_CS_6669_20020102_8_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_8_mask.tif
TCGA_CS_6669_20020102_9_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_9_mask.tif
TCGA_CS_6669_20020102_9.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_9.tif
TCGA_CS_6669_20020102_10_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_10_mask.tif
TCGA_CS_6669_20020102_10.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_10.tif
TCGA_CS_6669_20020102_11.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_11.tif
TCGA_CS_6669_20020102_11_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_11_mask.tif
TCGA_CS_6669_20020102_12_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_12_mask.tif
TCGA_CS_6669_20020102_12.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_12.tif
TCGA_CS_6669_20020102_13_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_13_mask.tif
TCGA_CS_6669_20020102_13.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_13.tif
TCGA_CS_6669_20020102_14_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_14_mask.tif
TCGA_CS_6669_20020102_14.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_14.tif
TCGA_CS_6669_20020102_15.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_15.tif
TCGA_CS_6669_20020102_15_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_15_mask.tif
TCGA_CS_6669_20020102_16_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_16_mask.tif
TCGA_CS_6669_20020102_16.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_16.tif
TCGA_CS_6669_20020102_17.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_17.tif
TCGA_CS_6669_20020102_17_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_17_mask.tif
TCGA_CS_6669_20020102_18_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_18_mask.tif
TCGA_CS_6669_20020102_18.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_18.tif
TCGA_CS_6669_20020102_19_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_19_mask.tif
TCGA_CS_6669_20020102_19.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_19.tif
TCGA_CS_6669_20020102_20_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_20_mask.tif
TCGA_CS_6669_20020102_20.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_20.tif
TCGA_CS_6669_20020102_21_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_21_mask.tif
TCGA_CS_6669_20020102_21.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_21.tif
TCGA_CS_6669_20020102_22_mask.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_22_mask.tif
TCGA_CS_6669_20020102_22.tif
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_22.tif
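
The sort key is worth unpacking: it pulls the trailing slice number out of the filename, so slices are visited in anatomical order rather than lexicographic order. A small check:

name = "TCGA_CS_6665_20010817_10_mask.tif"
stem = name.split(".")[-2]      # "TCGA_CS_6665_20010817_10_mask"
print(int(stem.split("_")[4]))  # 10 - the same key works for image and mask files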


Reading the data

# read images
volumes = {}
masks = {}
images_dir = './archive/lgg-mri-segmentation/kaggle_3m'
print("reading {} images...".format("train"))
for (dirpath, dirnames, filenames) in os.walk(images_dir):
    image_slices = []
    mask_slices = []
    for filename in sorted(
        filter(lambda f: ".tif" in f, filenames),
        key=lambda x: int(x.split(".")[-2].split("_")[4]),
    ):
        filepath = os.path.join(dirpath, filename)
        if "mask" in filename:
            mask_slices.append(imread(filepath, as_gray=True))
        else:
            image_slices.append(imread(filepath))
    if len(image_slices) > 0:
        patient_id = dirpath.split("/")[-1]
        volumes[patient_id] = np.array(image_slices[1:-1])
        masks[patient_id] = np.array(mask_slices[1:-1])

Every patient's data is stored in a dict keyed by patient ID. Note that the author drops each patient's first and last slice via image_slices[1:-1]; the reason is not stated (presumably the outermost slices contain little useful brain tissue).

print(len(volumes)) # 110 patients
print(len(volumes['TCGA_DU_8163_19961119'])) # this patient has 35 slices

Dataset split

Ten patients are chosen at random as the validation set; the remaining patients form the training set.

patients_list = sorted(volumes) # names of all patients; subset "all" would keep every one
seed = 42
subset = "train"
# select cases to subset
if not subset == "all":
    random.seed(seed)
    validation_patients = random.sample(patients_list, k=10) # draw k elements at random from the sequence, returned as a list
    if subset == "validation":
        patients_list = validation_patients
    else: # "train"
        patients_list = sorted(
            list(set(patients_list).difference(validation_patients)) # set.difference() returns the elements not in the other set
        )
print(patients_list)
['TCGA_CS_4941_19960909', 'TCGA_CS_4942_19970222', 'TCGA_CS_4943_20000902', 'TCGA_CS_5393_19990606', 'TCGA_CS_5395_19981004', 'TCGA_CS_5396_20010302', 'TCGA_CS_5397_20010315', 'TCGA_CS_6186_20000601', 'TCGA_CS_6188_20010812', 'TCGA_CS_6290_20000917', 'TCGA_CS_6665_20010817', 'TCGA_CS_6666_20011109', 'TCGA_CS_6669_20020102', 'TCGA_DU_5849_19950405', 'TCGA_DU_5852_19950709', 'TCGA_DU_5853_19950823', 'TCGA_DU_5854_19951104', 'TCGA_DU_5855_19951217', 'TCGA_DU_5871_19941206', 'TCGA_DU_5872_19950223', 'TCGA_DU_5874_19950510', 'TCGA_DU_6399_19830416', 'TCGA_DU_6400_19830518', 'TCGA_DU_6401_19831001', 'TCGA_DU_6405_19851005', 'TCGA_DU_6407_19860514', 'TCGA_DU_7008_19830723', 'TCGA_DU_7010_19860307', 'TCGA_DU_7013_19860523', 'TCGA_DU_7018_19911220', 'TCGA_DU_7019_19940908', 'TCGA_DU_7294_19890104', 'TCGA_DU_7298_19910324', 'TCGA_DU_7299_19910417', 'TCGA_DU_7300_19910814', 'TCGA_DU_7301_19911112', 'TCGA_DU_7302_19911203', 'TCGA_DU_7304_19930325', 'TCGA_DU_7306_19930512', 'TCGA_DU_7309_19960831', 'TCGA_DU_8162_19961029', 'TCGA_DU_8163_19961119', 'TCGA_DU_8164_19970111', 'TCGA_DU_8165_19970205', 'TCGA_DU_8166_19970322', 'TCGA_DU_8167_19970402', 'TCGA_DU_8168_19970503', 'TCGA_DU_A5TP_19970614', 'TCGA_DU_A5TR_19970726', 'TCGA_DU_A5TS_19970726', 'TCGA_DU_A5TT_19980318', 'TCGA_DU_A5TU_19980312', 'TCGA_DU_A5TW_19980228', 'TCGA_DU_A5TY_19970709', 'TCGA_EZ_7264_20010816', 'TCGA_FG_5962_20000626', 'TCGA_FG_5964_20010511', 'TCGA_FG_6688_20020215', 'TCGA_FG_6689_20020326', 'TCGA_FG_6690_20020226', 'TCGA_FG_6691_20020405', 'TCGA_FG_6692_20020606', 'TCGA_FG_7634_20000128', 'TCGA_FG_7637_20000922', 'TCGA_FG_7643_20021104', 'TCGA_FG_8189_20030516', 'TCGA_FG_A4MT_20020212', 'TCGA_FG_A4MU_20030903', 'TCGA_FG_A60K_20040224', 'TCGA_HT_7473_19970826', 'TCGA_HT_7475_19970918', 'TCGA_HT_7602_19951103', 'TCGA_HT_7605_19950916', 'TCGA_HT_7608_19940304', 'TCGA_HT_7680_19970202', 'TCGA_HT_7684_19950816', 'TCGA_HT_7686_19950629', 'TCGA_HT_7690_19960312', 'TCGA_HT_7693_19950520', 'TCGA_HT_7694_19950404', 'TCGA_HT_7855_19951020', 'TCGA_HT_7856_19950831', 'TCGA_HT_7860_19960513', 'TCGA_HT_7874_19950902', 'TCGA_HT_7877_19980917', 'TCGA_HT_7881_19981015', 'TCGA_HT_7882_19970125', 'TCGA_HT_7884_19980913', 'TCGA_HT_8018_19970411', 'TCGA_HT_8105_19980826', 'TCGA_HT_8106_19970727', 'TCGA_HT_8107_19980708', 'TCGA_HT_8111_19980330', 'TCGA_HT_8113_19930809', 'TCGA_HT_8114_19981030', 'TCGA_HT_8563_19981209', 'TCGA_HT_A5RC_19990831', 'TCGA_HT_A616_19991226', 'TCGA_HT_A61A_20000127', 'TCGA_HT_A61B_19991127']

Preprocessing

These helper functions, written by the author, apply each transformation to the image volume and its mask together.

def crop_sample(x):
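    # zero out low intensities, then crop volume and mask to the smallest box containing nonzero voxels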
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )


def pad_sample(x):
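    # zero-pad the shorter of the two spatial dimensions so every slice becomes square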
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask


def resize_sample(x, size=256):
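    # resize each slice to size x size; order=0 (nearest neighbor) keeps the mask binary, order=2 smooths the image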
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    return volume, mask


def normalize_volume(volume):
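    # rescale intensities into the 10th-99th percentile window, then standardize to zero mean / unit std per channel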
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume

print("preprocessing {} volumes...".format(subset))
# create list of tuples (volume, mask)
volumes_list = [(volumes[k], masks[k]) for k in patients_list]

print("cropping {} volumes...".format(subset))
# crop to smallest enclosing volume
volumes_list = [crop_sample(v) for v in volumes_list]

print("padding {} volumes...".format(subset))
# pad to square
volumes_list = [pad_sample(v) for v in volumes_list]

print("resizing {} volumes...".format(subset))
# resize
volumes_list = [resize_sample(v, size=image_size) for v in volumes_list]

print("normalizing {} volumes...".format(subset))
# normalize channel-wise
volumes_list = [(normalize_volume(v), m) for v, m in volumes_list]
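
A hypothetical shape trace for one patient makes the pipeline concrete (the crop bounds depend on the data; judging by the shapes printed further below, image_size was set to 224 here):

# raw volume: (24, 256, 256, 3), masks: (24, 256, 256)
# crop_sample      -> e.g. (24, 180, 196, 3)  smallest box around the brain
# pad_sample       -> (24, 196, 196, 3)       square slices
# resize_sample    -> (24, 224, 224, 3)       with size=image_size=224
# normalize_volume -> same shape, percentile-clipped, zero mean / unit std per channel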


Slice-sampling probabilities are computed from the masks; they are used later when slices are sampled at random.

# probabilities for sampling slices based on masks
slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in volumes_list]
slice_weights = [
    (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in slice_weights
]
print(len(slice_weights)) # 100 training patients
print(len(slice_weights[0])) # 21 slices for the first patient
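
The smoothing deserves a note: with S = s.sum() and n = len(s), each weight becomes (s_i + 0.1*S/n) / (1.1*S). The weights sum to (S + 0.1*S) / (1.1*S) = 1, so they form a valid probability distribution, and slices whose masks are empty still keep a small nonzero sampling probability. A quick check:

s = np.array([0.0, 10.0, 30.0])  # hypothetical per-slice mask areas
w = (s + s.sum() * 0.1 / len(s)) / (s.sum() * 1.1)
print(w, w.sum())  # [0.0303... 0.2576... 0.7121...] 1.0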

The masks are only three-dimensional, so we add a channel dimension:

# add channel dimension to masks
volumes_list = [(v, m[..., np.newaxis]) for (v, m) in volumes_list]  # the Ellipsis `...` is equivalent to [:, :, :]

Index list

One index identifies the patient, the other the slice within that patient.

# create global index for patient and slice (idx -> (p_idx, s_idx))
num_slices = [v.shape[0] for v, m in volumes_list] # number of slices per patient
patient_slice_index = list(
    zip(
        sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
        sum([list(range(x)) for x in num_slices], []),
    )
)

zip() pairs up corresponding elements of its arguments into tuples. sum(list_of_lists, []) is used here to flatten a list of lists into one flat list. The final index is therefore ordered as: patient 0 slice 0, patient 0 slice 1, patient 0 slice 2, ..., patient 1 slice 0, ...
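A small sketch with hypothetical slice counts makes this concrete:

num_slices = [2, 3]  # hypothetical: patient 0 has 2 slices, patient 1 has 3
p_idx = sum([[i] * num_slices[i] for i in range(len(num_slices))], [])  # [0, 0, 1, 1, 1]
s_idx = sum([list(range(x)) for x in num_slices], [])                   # [0, 1, 0, 1, 2]
print(list(zip(p_idx, s_idx)))  # [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]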

__getitem__

__len__ returns the length of patient_slice_index, i.e. how many image/mask pairs there are. During training, __getitem__ yields one 2D pair at a time: the image is the input and the mask is the label.

idx = 50
patient = patient_slice_index[idx][0]
slice_n = patient_slice_index[idx][1]
v, m = volumes_list[patient]
image = v[slice_n]
mask = m[slice_n]
print(len(volumes_list)) # 100 patients
print(v.shape) # this patient's volume (slices, H, W, channels)
print(m.shape) # the corresponding masks
print(image.shape)
print(mask.shape)
100
(18, 224, 224, 3)
(18, 224, 224, 1)
(224, 224, 3)
(224, 224, 1)

Here the index deterministically picks a (patient, slice) pair; when random_sampling is enabled, __getitem__ instead draws a patient at random and then samples one of that patient's slices according to slice_weights.

# fix dimensions (C, H, W)
image = image.transpose(2, 0, 1)
mask = mask.transpose(2, 0, 1)

image_tensor = torch.from_numpy(image.astype(np.float32))
mask_tensor = torch.from_numpy(mask.astype(np.float32))

Finally, since training needs tensors, the channel dimension is moved to the front (C, H, W) and the arrays are converted to torch tensors.

Training and validation

DSC

The standard evaluation metric for segmentation here is the Dice Similarity Coefficient (DSC). The Dice coefficient is a set-similarity measure commonly used to compare two samples; it ranges from 0 to 1, with 1 for a perfect segmentation and 0 for the worst case.

More details: common segmentation metrics such as DSC, Hausdorff_95, IoU, PPV, etc.
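
In set notation, DSC(X, Y) = 2|X ∩ Y| / (|X| + |Y|); for binary masks this equals 2·TP / (2·TP + FP + FN), which is exactly what the dsc() helper in the full code below computes.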

Validation

There is not much to explain about the training loop itself; it follows the standard PyTorch recipe.
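
A minimal sketch of that loop, assuming the DiceLoss and log_loss_summary defined in the full code below and the reference repo's defaults (Adam with lr=1e-4; treat the exact hyperparameters as assumptions):

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
unet = UNet(in_channels=3, out_channels=1).to(device)
criterion = DiceLoss()
optimizer = optim.Adam(unet.parameters(), lr=0.0001)

epochs = 100  # hypothetical
for epoch in range(epochs):
    unet.train()
    epoch_loss = []
    for x, y_true in loader_train:
        x, y_true = x.to(device), y_true.to(device)
        optimizer.zero_grad()
        y_pred = unet(x)                  # sigmoid probabilities, shape (N, 1, H, W)
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    log_loss_summary(epoch_loss, epoch)   # prints "epoch N | loss: ..."

The resulting output: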

reading train images...
preprocessing train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...
normalizing train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset
epoch 1 | loss: 0.8733632518694951
epoch 1 | val_loss: 0.9460358023643494
epoch 1 | val_dsc: 0.1948670436467475
epoch 2 | loss: 0.8402993633196905
epoch 2 | val_loss: 0.931992749373118
epoch 2 | val_dsc: 0.4118475790023708
epoch 3 | loss: 0.8270544547301072
epoch 3 | val_loss: 0.9293850064277649
epoch 3 | val_dsc: 0.4527188003808261
epoch 4 | loss: 0.8215367862811456
epoch 4 | val_loss: 0.9270628492037455
epoch 4 | val_dsc: 0.6138052787773625
epoch 5 | loss: 0.8171369204154382
epoch 5 | val_loss: 0.9256295363108317
epoch 5 | val_dsc: 0.4911489058649566
epoch 6 | loss: 0.8134041795363793
epoch 6 | val_loss: 0.9244295557339987
epoch 6 | val_dsc: 0.6963948791553534
epoch 7 | loss: 0.8091526673390315
epoch 7 | val_loss: 0.9240273038546244
epoch 7 | val_dsc: 0.7177726293762504
epoch 8 | loss: 0.8052632143864265
epoch 8 | val_loss: 0.9218348860740662
epoch 8 | val_dsc: 0.7100711862449238
epoch 9 | loss: 0.801162777038721
epoch 9 | val_loss: 0.9223942359288534
epoch 9 | val_dsc: 0.7482642381530232
epoch 10 | loss: 0.7974189153084388
epoch 10 | val_loss: 0.9184039433797201
.....
Best validation mean DSC: 0.855153

Let's look at how the DSC is computed during validation.

validation_true = []
for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)
    print(y_true.shape)
    y_true_np = y_true.detach().cpu().numpy()
    print(y_true_np.shape)
    validation_true.extend(
        [y_true_np[s] for s in range(y_true_np.shape[0])]
    )
    print(len(validation_true))
    print(validation_true[0].shape)
    break
torch.Size([128, 1, 256, 256])
(128, 1, 256, 256)
128
(1, 256, 256)

Here I chose a batch_size of 128, so each batch yields 128 slices, which are accumulated into validation_true and validation_pred.

mean_dsc = np.mean(
    dsc_per_volume(
        validation_pred,
        validation_true,
        loader_valid.dataset.patient_slice_index,
    )
)

At the end of each epoch, the mean DSC over all the epoch's validation data is computed.

if mean_dsc > best_validation_dsc:
    best_validation_dsc = mean_dsc
    torch.save(unet.state_dict(), os.path.join(weights, "unet.pt"))

Whenever a better result appears, the model weights are saved.

Prediction

We use the previously saved model to run predictions on images.

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

loader_train, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle)
loaders = {"train": loader_train, "valid": loader_valid}
unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
unet.to(device)
state_dict = torch.load(os.path.join('./', "unet.pt"))
unet.load_state_dict(state_dict)
unet.eval()

input_list = []
pred_list = []
true_list = []

for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)
    with torch.set_grad_enabled(False):
        y_pred = unet(x)
        y_pred_np = y_pred.detach().cpu().numpy()
        pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
        y_true_np = y_true.detach().cpu().numpy()
        true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
        x_np = x.detach().cpu().numpy()
        input_list.extend([x_np[s] for s in range(x_np.shape[0])])


The predictions are then post-processed with the function below. At this point we only have per-slice results; they need to be regrouped by patient so that the segmentation can be inspected volume by volume.

def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes
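
A usage sketch consistent with the validation loader built above (both attributes exist on BrainSegmentationDataset):

volumes = postprocess_per_volume(
    input_list,
    pred_list,
    true_list,
    loader_valid.dataset.patient_slice_index,
    loader_valid.dataset.patients,
)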

Let's walk through what this function does.
First, we count how many slices each validation/test patient has. During data loading we already inspected loader_valid.dataset.patient_slice_index: like 2D coordinates, it pins down each slice as patient 0 slice 0, patient 0 slice 1, and so on.
np.bincount over the first coordinate of those pairs then gives the number of slices per patient (see the sketch below), and the predictions are regrouped patient by patient according to those counts.
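A small np.bincount sketch with hypothetical counts:

import numpy as np
patient_slice_index = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
num_slices = np.bincount([p[0] for p in patient_slice_index])
print(num_slices)  # [3 2]: patient 0 has 3 slices, patient 1 has 2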
Then the Dice score of each patient is computed and plotted:

dsc_dist = dsc_distribution(volumes)

dsc_dist_plot = plot_dsc(dsc_dist)
imsave("./dsc.png", dsc_dist_plot)

(figure: per-patient DSC bar plot; the red line marks the mean, the green line the median)
Finally, let's look at the segmentation results themselves.

for p in volumes:
    x = volumes[p][0]
    y_pred = volumes[p][1]
    y_true = volumes[p][2]
    for s in range(x.shape[0]):
        image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
        image = outline(image, y_pred[s, 0], color=[255, 0, 0])
        image = outline(image, y_true[s, 0], color=[0, 255, 0])
        filename = "{}-{}.png".format(p, str(s).zfill(2))
        filepath = os.path.join("./resultat", filename)
        imsave(filepath, image)

This draws the predicted contour (red) and the ground truth (green) on top of the original image.

Full code

Dependencies

import os
import random

from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from tqdm import tqdm
from skimage.exposure import rescale_intensity
from skimage.io import imread, imsave
from skimage.transform import resize, rescale, rotate
from torch.utils.data import Dataset
from torchvision.transforms import Compose

Preprocessing and augmentation functions

def crop_sample(x):
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )


def pad_sample(x):
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask


def resize_sample(x, size=256):
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    return volume, mask


def normalize_volume(volume):
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume
def transforms(scale=None, angle=None, flip_prob=None):
    transform_list = []

    if scale is not None:
        transform_list.append(Scale(scale))
    if angle is not None:
        transform_list.append(Rotate(angle))
    if flip_prob is not None:
        transform_list.append(HorizontalFlip(flip_prob))

    return Compose(transform_list)


class Scale(object):

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, sample):
        image, mask = sample

        img_size = image.shape[0]

        scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)

        image = rescale(
            image,
            (scale, scale),
            multichannel=True,  # deprecated in scikit-image >= 0.19; newer versions use channel_axis=-1 instead
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        mask = rescale(
            mask,
            (scale, scale),
            order=0,
            multichannel=True,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )

        if scale < 1.0:
            diff = (img_size - image.shape[0]) / 2.0
            padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),)
            image = np.pad(image, padding, mode="constant", constant_values=0)
            mask = np.pad(mask, padding, mode="constant", constant_values=0)
        else:
            x_min = (image.shape[0] - img_size) // 2
            x_max = x_min + img_size
            image = image[x_min:x_max, x_min:x_max, ...]
            mask = mask[x_min:x_max, x_min:x_max, ...]

        return image, mask


class Rotate(object):

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, sample):
        image, mask = sample

        angle = np.random.uniform(low=-self.angle, high=self.angle)
        image = rotate(image, angle, resize=False, preserve_range=True, mode="constant")
        mask = rotate(
            mask, angle, resize=False, order=0, preserve_range=True, mode="constant"
        )
        return image, mask


class HorizontalFlip(object):

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, sample):
        image, mask = sample

        if np.random.rand() > self.flip_prob:
            return image, mask

        image = np.fliplr(image).copy()
        mask = np.fliplr(mask).copy()

        return image, mask
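
For reference, a plausible way to build the training transform (aug_scale=0.05 and aug_angle=15 are the reference repo's defaults; treat them as assumptions):

train_transform = transforms(scale=0.05, angle=15, flip_prob=0.5)
image_aug, mask_aug = train_transform((image, mask))  # image and mask as (H, W, C) numpy arrays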

Data loading

class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation"""

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for (dirpath, dirnames, filenames) in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:
                    mask_slices.append(imread(filepath, as_gray=True))
                else:
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]
                volumes[patient_id] = np.array(image_slices[1:-1])
                masks[patient_id] = np.array(mask_slices[1:-1])

        self.patients = sorted(volumes)

        # select cases to subset
        if not subset == "all":
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=10)
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume
        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))
        # pad to square
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        # probabilities for sampling slices based on masks
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
        self.slice_weights = [
            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights
        ]

        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling

        self.transform = transform

    def __len__(self):
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        patient = self.patient_slice_index[idx][0]
        slice_n = self.patient_slice_index[idx][1]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        # return tensors
        return image_tensor, mask_tensor
def data_loaders(batch_size, workers, image_size, aug_scale, aug_angle):
    dataset_train, dataset_valid = datasets('./archive/lgg-mri-segmentation/kaggle_3m', image_size, aug_scale, aug_angle)

    def worker_init(worker_id):
        np.random.seed(42 + worker_id)

    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=batch_size,
        drop_last=False,
        num_workers=workers,
        worker_init_fn=worker_init,
    )

    return loader_train, loader_valid
def datasets(images, image_size, aug_scale, aug_angle):
    train = BrainSegmentationDataset(
        images_dir=images,
        subset="train",
        image_size=image_size,
        transform=transforms(scale=aug_scale, angle=aug_angle, flip_prob=0.5),
    )
    valid = BrainSegmentationDataset(
        images_dir=images,
        subset="validation",
        image_size=image_size,
        random_sampling=False,
    )
    return train, valid


Network definition

class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

Metric computation

class DiceLoss(nn.Module):

    def __init__(self):
        super(DiceLoss, self).__init__()
        self.smooth = 1.0

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        # Flatten the single-channel prediction and target, then compute the
        # soft Dice coefficient; `smooth` avoids division by zero on empty masks.
        y_pred = y_pred[:, 0].contiguous().view(-1)
        y_true = y_true[:, 0].contiguous().view(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1. - dsc
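
A quick check on toy tensors (a minimal sketch): perfect overlap should give a loss of 0, and an empty prediction a loss near 1 (not exactly 1 because of the smoothing term):

loss_fn = DiceLoss()
y = torch.zeros(1, 1, 4, 4)
y[..., :2, :2] = 1.0
print(loss_fn(y, y).item())                    # 0.0: perfect overlap
print(loss_fn(torch.zeros_like(y), y).item())  # 0.8: dsc = (2*0+1)/(0+4+1) = 0.2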


def log_images(x, y_true, y_pred, channel=1):
    images = []
    x_np = x[:, channel].cpu().numpy()
    y_true_np = y_true[:, 0].cpu().numpy()
    y_pred_np = y_pred[:, 0].cpu().numpy()
    for i in range(x_np.shape[0]):
        image = gray2rgb(np.squeeze(x_np[i]))
        image = outline(image, y_pred_np[i], color=[255, 0, 0])
        image = outline(image, y_true_np[i], color=[0, 255, 0])
        images.append(image)
    return images


def gray2rgb(image):
    # Shift intensities by |min| (so roughly zero-centered inputs become
    # non-negative), normalize to [0, 1], then replicate the single gray
    # channel into an RGB uint8 image.
    w, h = image.shape
    image += np.abs(np.min(image))
    image_max = np.abs(np.max(image))
    if image_max > 0:
        image /= image_max
    ret = np.empty((w, h, 3), dtype=np.uint8)
    ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255
    return ret


def outline(image, mask, color):
    # Draw `color` on mask boundary pixels: a pixel is on the boundary when
    # its 3x3 neighborhood contains both foreground and background.
    mask = np.round(mask)
    yy, xx = np.nonzero(mask)
    for y, x in zip(yy, xx):
        if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0:
            image[max(0, y) : y + 1, max(0, x) : x + 1] = color
    return image
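
Putting the two helpers together on a toy image and mask (a minimal sketch, just to see a contour drawn):

import numpy as np

img = gray2rgb(np.random.randn(64, 64))
square = np.zeros((64, 64), dtype=np.float32)
square[16:48, 16:48] = 1.0
img = outline(img, square, color=[255, 0, 0])  # red contour around the square
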
def dsc(y_pred, y_true):
    # Hard Dice coefficient on binarized masks (no smoothing term, so it is
    # undefined when both prediction and ground truth are empty).
    y_pred = np.round(y_pred).astype(int)
    y_true = np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))
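
For intuition, a tiny worked example (a sketch): one true positive out of two predicted and two actual positives gives 2·1 / (2 + 2) = 0.5:

import numpy as np

y_pred = np.array([1, 1, 0, 0])
y_true = np.array([1, 0, 0, 1])
print(dsc(y_pred, y_true))  # 2 * 1 / (2 + 2) = 0.5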


def dsc_distribution(volumes):
    dsc_dict = {}
    for p in volumes:
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        dsc_dict[p] = dsc(y_pred, y_true)
    return dsc_dict


def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list
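
patient_slice_index stores a (patient, slice) pair for each validation sample, so np.bincount over the patient ids recovers how many consecutive slices belong to each volume. For example:

import numpy as np

patient_slice_index = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
print(np.bincount([p[0] for p in patient_slice_index]))  # [3 2]: 3 slices for patient 0, 2 for patient 1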


def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes


def log_loss_summary(loss, step, prefix=""):
    print("epoch {} | {}: {}".format(step + 1, prefix + "loss", np.mean(loss)))

def log_scalar_summary(tag, value, step):
    print("epoch {} | {}: {}".format(step + 1, tag, value))


def plot_dsc(dsc_dist):
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    plt.close()
    s, (width, height) = canvas.print_to_buffer()
    # np.fromstring is deprecated for binary data; frombuffer reads the RGBA buffer
    return np.frombuffer(s, np.uint8).reshape((height, width, 4))

Training and evaluation

batch_size = 128
epochs = 300
lr = 0.0001
workers = 8          # data-loader worker processes
weights = "./"       # directory where the best checkpoint (unet.pt) is saved
image_size = 256
aug_scale = 0.05     # random scale augmentation range
aug_angle = 15       # random rotation augmentation, in degrees
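
train_validate below relies on a data_loaders helper that is not reproduced in this excerpt. A minimal sketch of what it looks like, assuming the BrainSegmentationDataset and transforms definitions from the linked repository (the images_dir path and flip_prob value here are illustrative):

from torch.utils.data import DataLoader

def data_loaders(batch_size, workers, image_size, aug_scale, aug_angle):
    # Assumes BrainSegmentationDataset and transforms from the repository above.
    dataset_train = BrainSegmentationDataset(
        images_dir="./kaggle_3m",  # hypothetical path to the Kaggle dataset
        subset="train",
        image_size=image_size,
        transform=transforms(scale=aug_scale, angle=aug_angle, flip_prob=0.5),
    )
    dataset_valid = BrainSegmentationDataset(
        images_dir="./kaggle_3m",
        subset="validation",
        image_size=image_size,
        random_sampling=False,
    )
    loader_train = DataLoader(
        dataset_train, batch_size=batch_size, shuffle=True,
        drop_last=True, num_workers=workers,
    )
    loader_valid = DataLoader(
        dataset_valid, batch_size=batch_size,
        drop_last=False, num_workers=workers,
    )
    return loader_train, loader_valid
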
def train_validate():
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    
    loader_train, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle)
    loaders = {"train": loader_train, "valid": loader_valid}
    
    unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
    unet.to(device)
    
    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0
    
    optimizer = optim.Adam(unet.parameters(), lr=lr)
    
    loss_train = []
    loss_valid = []
    
    step = 0
    
    for epoch in range(epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()
    
            validation_pred = []
            validation_true = []
    
            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1
    
                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)
    
                optimizer.zero_grad()
    
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)
    
                    loss = dsc_loss(y_pred, y_true)
    
                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )
                        
                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()
    
            if phase == "train":
                log_loss_summary(loss_train, epoch)
                loss_train = []

            if phase == "valid":
                log_loss_summary(loss_valid, epoch, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                log_scalar_summary("val_dsc", mean_dsc, epoch)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(), os.path.join(weights, "unet.pt"))
                loss_valid = []
    
    print("\nBest validation mean DSC: {:4f}\n".format(best_validation_dsc))
    
    state_dict = torch.load(os.path.join(weights, "unet.pt"))
    unet.load_state_dict(state_dict)
    unet.eval()
    
    input_list = []
    pred_list = []
    true_list = []
    
    for i, data in enumerate(loader_valid):
        x, y_true = data
        x, y_true = x.to(device), y_true.to(device)
        with torch.set_grad_enabled(False):
            y_pred = unet(x)
            y_pred_np = y_pred.detach().cpu().numpy()
            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
            y_true_np = y_true.detach().cpu().numpy()
            true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
            x_np = x.detach().cpu().numpy()
            input_list.extend([x_np[s] for s in range(x_np.shape[0])])
            
    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader_valid.dataset.patient_slice_index,
        loader_valid.dataset.patients,
    )
    
    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave("./dsc.png", dsc_dist_plot)

    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            filepath = os.path.join("./resultat", filename)
            imsave(filepath, image)
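
Note the design choice for model selection: the checkpoint is saved on the best mean per-volume DSC (via dsc_per_volume), not on the per-slice validation loss. Slices from one patient are highly correlated, so averaging Dice over whole volumes weights each patient equally; the best checkpoint is then reloaded for the final visualization pass.
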
train_validate()