A PyTorch U-Net implementation of a deep-learning segmentation algorithm for FLAIR abnormality segmentation in brain MRI
GitHub code: U-Net for brain segmentation
Kaggle code: brain-segmentation-pytorch
Dataset download: Brain MRI segmentation
The dataset is small and the code is clear, so it is fairly easy to reproduce.
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
unet = UNet()
print(unet)
UNet( (encoder1): Sequential( (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc1relu1): ReLU(inplace=True) (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc1relu2): ReLU(inplace=True) ) (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (encoder2): Sequential( (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc2relu1): ReLU(inplace=True) (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc2relu2): ReLU(inplace=True) ) (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (encoder3): Sequential( (enc3conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc3norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc3relu1): ReLU(inplace=True) (enc3conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc3norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc3relu2): ReLU(inplace=True) ) (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (encoder4): Sequential( (enc4conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc4norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc4relu1): ReLU(inplace=True) (enc4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (enc4norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (enc4relu2): ReLU(inplace=True) ) (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (bottleneck): Sequential( (bottleneckconv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bottlenecknorm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (bottleneckrelu1): ReLU(inplace=True) (bottleneckconv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bottlenecknorm2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (bottleneckrelu2): ReLU(inplace=True) ) (upconv4): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2)) (decoder4): Sequential( (dec4conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec4norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec4relu1): ReLU(inplace=True) (dec4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec4norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec4relu2): ReLU(inplace=True) ) (upconv3): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2)) (decoder3): Sequential( (dec3conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec3norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec3relu1): ReLU(inplace=True) (dec3conv2): 
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec3norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec3relu2): ReLU(inplace=True) ) (upconv2): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2)) (decoder2): Sequential( (dec2conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec2relu1): ReLU(inplace=True) (dec2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec2relu2): ReLU(inplace=True) ) (upconv1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2)) (decoder1): Sequential( (dec1conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec1relu1): ReLU(inplace=True) (dec1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (dec1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (dec1relu2): ReLU(inplace=True) ) (conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1)) )
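As a quick sanity check, a dummy forward pass (with a made-up random input; any spatial size divisible by 16 works) shows that the network returns a single-channel map of the same spatial size, squashed to (0, 1) by the final sigmoid:

import torch

x = torch.randn(1, 3, 256, 256)  # hypothetical batch of one 3-channel 256x256 slice
with torch.no_grad():
    y = unet(x)
print(y.shape)  # torch.Size([1, 1, 256, 256])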
First, traverse the file paths and names with os.walk. dirpath is the folder currently being visited, dirnames is the list of subdirectories in that folder, and filenames is the list of files in that folder.
for (dirpath, dirnames, filenames) in os.walk(images_dir):
    for filename in sorted(
        filter(lambda f: ".tif" in f, filenames),  # keep only the .tif files
        key=lambda x: int(x.split(".")[-2].split("_")[4]),  # within each patient folder, sort by the slice number at the end of the filename
    ):
        print(filename)
        filepath = os.path.join(dirpath, filename)
        print(filepath)
TCGA_CS_6665_20010817_1.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_1.tif TCGA_CS_6665_20010817_1_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_1_mask.tif TCGA_CS_6665_20010817_2_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_2_mask.tif TCGA_CS_6665_20010817_2.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_2.tif TCGA_CS_6665_20010817_3.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_3.tif TCGA_CS_6665_20010817_3_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_3_mask.tif TCGA_CS_6665_20010817_4_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_4_mask.tif TCGA_CS_6665_20010817_4.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_4.tif TCGA_CS_6665_20010817_5_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_5_mask.tif TCGA_CS_6665_20010817_5.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_5.tif TCGA_CS_6665_20010817_6.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_6.tif TCGA_CS_6665_20010817_6_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_6_mask.tif TCGA_CS_6665_20010817_7_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_7_mask.tif TCGA_CS_6665_20010817_7.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_7.tif TCGA_CS_6665_20010817_8.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_8.tif TCGA_CS_6665_20010817_8_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_8_mask.tif TCGA_CS_6665_20010817_9.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_9.tif TCGA_CS_6665_20010817_9_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_9_mask.tif TCGA_CS_6665_20010817_10_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_10_mask.tif TCGA_CS_6665_20010817_10.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_10.tif TCGA_CS_6665_20010817_11_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_11_mask.tif TCGA_CS_6665_20010817_11.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_11.tif TCGA_CS_6665_20010817_12.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_12.tif TCGA_CS_6665_20010817_12_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_12_mask.tif TCGA_CS_6665_20010817_13_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_13_mask.tif TCGA_CS_6665_20010817_13.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_13.tif TCGA_CS_6665_20010817_14_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_14_mask.tif TCGA_CS_6665_20010817_14.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_14.tif TCGA_CS_6665_20010817_15_mask.tif 
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_15_mask.tif TCGA_CS_6665_20010817_15.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_15.tif TCGA_CS_6665_20010817_16.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_16.tif TCGA_CS_6665_20010817_16_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_16_mask.tif TCGA_CS_6665_20010817_17_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_17_mask.tif TCGA_CS_6665_20010817_17.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_17.tif TCGA_CS_6665_20010817_18_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_18_mask.tif TCGA_CS_6665_20010817_18.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_18.tif TCGA_CS_6665_20010817_19.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_19.tif TCGA_CS_6665_20010817_19_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_19_mask.tif TCGA_CS_6665_20010817_20_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_20_mask.tif TCGA_CS_6665_20010817_20.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_20.tif TCGA_CS_6665_20010817_21.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_21.tif TCGA_CS_6665_20010817_21_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_21_mask.tif TCGA_CS_6665_20010817_22.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_22.tif TCGA_CS_6665_20010817_22_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_22_mask.tif TCGA_CS_6665_20010817_23_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_23_mask.tif TCGA_CS_6665_20010817_23.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_23.tif TCGA_CS_6665_20010817_24_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_24_mask.tif TCGA_CS_6665_20010817_24.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6665_20010817/TCGA_CS_6665_20010817_24.tif TCGA_CS_6669_20020102_1.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_1.tif TCGA_CS_6669_20020102_1_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_1_mask.tif TCGA_CS_6669_20020102_2_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_2_mask.tif TCGA_CS_6669_20020102_2.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_2.tif TCGA_CS_6669_20020102_3.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_3.tif TCGA_CS_6669_20020102_3_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_3_mask.tif TCGA_CS_6669_20020102_4.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_4.tif TCGA_CS_6669_20020102_4_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_4_mask.tif TCGA_CS_6669_20020102_5.tif 
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_5.tif TCGA_CS_6669_20020102_5_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_5_mask.tif TCGA_CS_6669_20020102_6_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_6_mask.tif TCGA_CS_6669_20020102_6.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_6.tif TCGA_CS_6669_20020102_7_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_7_mask.tif TCGA_CS_6669_20020102_7.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_7.tif TCGA_CS_6669_20020102_8.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_8.tif TCGA_CS_6669_20020102_8_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_8_mask.tif TCGA_CS_6669_20020102_9_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_9_mask.tif TCGA_CS_6669_20020102_9.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_9.tif TCGA_CS_6669_20020102_10_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_10_mask.tif TCGA_CS_6669_20020102_10.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_10.tif TCGA_CS_6669_20020102_11.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_11.tif TCGA_CS_6669_20020102_11_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_11_mask.tif TCGA_CS_6669_20020102_12_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_12_mask.tif TCGA_CS_6669_20020102_12.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_12.tif TCGA_CS_6669_20020102_13_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_13_mask.tif TCGA_CS_6669_20020102_13.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_13.tif TCGA_CS_6669_20020102_14_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_14_mask.tif TCGA_CS_6669_20020102_14.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_14.tif TCGA_CS_6669_20020102_15.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_15.tif TCGA_CS_6669_20020102_15_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_15_mask.tif TCGA_CS_6669_20020102_16_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_16_mask.tif TCGA_CS_6669_20020102_16.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_16.tif TCGA_CS_6669_20020102_17.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_17.tif TCGA_CS_6669_20020102_17_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_17_mask.tif TCGA_CS_6669_20020102_18_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_18_mask.tif TCGA_CS_6669_20020102_18.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_18.tif TCGA_CS_6669_20020102_19_mask.tif 
./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_19_mask.tif TCGA_CS_6669_20020102_19.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_19.tif TCGA_CS_6669_20020102_20_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_20_mask.tif TCGA_CS_6669_20020102_20.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_20.tif TCGA_CS_6669_20020102_21_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_21_mask.tif TCGA_CS_6669_20020102_21.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_21.tif TCGA_CS_6669_20020102_22_mask.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_22_mask.tif TCGA_CS_6669_20020102_22.tif ./archive/lgg-mri-segmentation/kaggle_3m/TCGA_CS_6669_20020102/TCGA_CS_6669_20020102_22.tif
# read images
volumes = {}
masks = {}
images_dir = './archive/lgg-mri-segmentation/kaggle_3m'
print("reading {} images...".format("train"))
for (dirpath, dirnames, filenames) in os.walk(images_dir):
    image_slices = []
    mask_slices = []
    for filename in sorted(
        filter(lambda f: ".tif" in f, filenames),
        key=lambda x: int(x.split(".")[-2].split("_")[4]),
    ):
        filepath = os.path.join(dirpath, filename)
        if "mask" in filename:
            mask_slices.append(imread(filepath, as_gray=True))
        else:
            image_slices.append(imread(filepath))
    if len(image_slices) > 0:
        patient_id = dirpath.split("/")[-1]
        volumes[patient_id] = np.array(image_slices[1:-1])
        masks[patient_id] = np.array(mask_slices[1:-1])
All patient data is stored in dicts. Note that the author discards the first and last slice of each patient (image_slices[1:-1]); it is not entirely clear why.
print(len(volumes))  # 110 patients
print(len(volumes['TCGA_DU_8163_19961119']))  # this patient has 35 slices
Ten patients are randomly selected as the validation set; the remaining patients form the training set.
patients_list = sorted(volumes)  # names of all patients; with subset "all" everything is kept
seed = 42
subset = "train"
# select cases to subset
if not subset == "all":
    random.seed(seed)
    validation_patients = random.sample(patients_list, k=10)  # draw k elements at random and return them as a list
    if subset == "validation":
        patients_list = validation_patients
    else:  # "train"
        patients_list = sorted(
            list(set(patients_list).difference(validation_patients))  # difference() returns the set difference
        )
print(patients_list)
['TCGA_CS_4941_19960909', 'TCGA_CS_4942_19970222', 'TCGA_CS_4943_20000902', 'TCGA_CS_5393_19990606', 'TCGA_CS_5395_19981004', 'TCGA_CS_5396_20010302', 'TCGA_CS_5397_20010315', 'TCGA_CS_6186_20000601', 'TCGA_CS_6188_20010812', 'TCGA_CS_6290_20000917', 'TCGA_CS_6665_20010817', 'TCGA_CS_6666_20011109', 'TCGA_CS_6669_20020102', 'TCGA_DU_5849_19950405', 'TCGA_DU_5852_19950709', 'TCGA_DU_5853_19950823', 'TCGA_DU_5854_19951104', 'TCGA_DU_5855_19951217', 'TCGA_DU_5871_19941206', 'TCGA_DU_5872_19950223', 'TCGA_DU_5874_19950510', 'TCGA_DU_6399_19830416', 'TCGA_DU_6400_19830518', 'TCGA_DU_6401_19831001', 'TCGA_DU_6405_19851005', 'TCGA_DU_6407_19860514', 'TCGA_DU_7008_19830723', 'TCGA_DU_7010_19860307', 'TCGA_DU_7013_19860523', 'TCGA_DU_7018_19911220', 'TCGA_DU_7019_19940908', 'TCGA_DU_7294_19890104', 'TCGA_DU_7298_19910324', 'TCGA_DU_7299_19910417', 'TCGA_DU_7300_19910814', 'TCGA_DU_7301_19911112', 'TCGA_DU_7302_19911203', 'TCGA_DU_7304_19930325', 'TCGA_DU_7306_19930512', 'TCGA_DU_7309_19960831', 'TCGA_DU_8162_19961029', 'TCGA_DU_8163_19961119', 'TCGA_DU_8164_19970111', 'TCGA_DU_8165_19970205', 'TCGA_DU_8166_19970322', 'TCGA_DU_8167_19970402', 'TCGA_DU_8168_19970503', 'TCGA_DU_A5TP_19970614', 'TCGA_DU_A5TR_19970726', 'TCGA_DU_A5TS_19970726', 'TCGA_DU_A5TT_19980318', 'TCGA_DU_A5TU_19980312', 'TCGA_DU_A5TW_19980228', 'TCGA_DU_A5TY_19970709', 'TCGA_EZ_7264_20010816', 'TCGA_FG_5962_20000626', 'TCGA_FG_5964_20010511', 'TCGA_FG_6688_20020215', 'TCGA_FG_6689_20020326', 'TCGA_FG_6690_20020226', 'TCGA_FG_6691_20020405', 'TCGA_FG_6692_20020606', 'TCGA_FG_7634_20000128', 'TCGA_FG_7637_20000922', 'TCGA_FG_7643_20021104', 'TCGA_FG_8189_20030516', 'TCGA_FG_A4MT_20020212', 'TCGA_FG_A4MU_20030903', 'TCGA_FG_A60K_20040224', 'TCGA_HT_7473_19970826', 'TCGA_HT_7475_19970918', 'TCGA_HT_7602_19951103', 'TCGA_HT_7605_19950916', 'TCGA_HT_7608_19940304', 'TCGA_HT_7680_19970202', 'TCGA_HT_7684_19950816', 'TCGA_HT_7686_19950629', 'TCGA_HT_7690_19960312', 'TCGA_HT_7693_19950520', 'TCGA_HT_7694_19950404', 'TCGA_HT_7855_19951020', 'TCGA_HT_7856_19950831', 'TCGA_HT_7860_19960513', 'TCGA_HT_7874_19950902', 'TCGA_HT_7877_19980917', 'TCGA_HT_7881_19981015', 'TCGA_HT_7882_19970125', 'TCGA_HT_7884_19980913', 'TCGA_HT_8018_19970411', 'TCGA_HT_8105_19980826', 'TCGA_HT_8106_19970727', 'TCGA_HT_8107_19980708', 'TCGA_HT_8111_19980330', 'TCGA_HT_8113_19930809', 'TCGA_HT_8114_19981030', 'TCGA_HT_8563_19981209', 'TCGA_HT_A5RC_19990831', 'TCGA_HT_A616_19991226', 'TCGA_HT_A61A_20000127', 'TCGA_HT_A61B_19991127']
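A quick check on the split sizes (110 patients were read above, 10 of them held out for validation):

print(len(validation_patients))  # 10
print(len(patients_list))        # 100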
The author also wrote his own preprocessing helpers, which always transform the image and the mask together; the functions below handle cropping, padding, resizing, and normalization (the random scale/rotate/flip augmentations are defined in the transforms of the full code at the end).
def crop_sample(x):
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )

def pad_sample(x):
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask

def resize_sample(x, size=256):
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    return volume, mask

def normalize_volume(volume):
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume
print("preprocessing {} volumes...".format(subset)) # create list of tuples (volume, mask) volumes_list = [(volumes[k], masks[k]) for k in patients_list] print("cropping {} volumes...".format(subset)) # crop to smallest enclosing volume volumes_list = [crop_sample(v) for v in volumes_list] print("padding {} volumes...".format(subset)) # pad to square volumes_list = [pad_sample(v) for v in volumes_list] print("resizing {} volumes...".format(subset)) # resize volumes_list = [resize_sample(v, size=image_size) for v in volumes_list] print("normalizing {} volumes...".format(subset)) # normalize channel-wise volumes_list = [(normalize_volume(v), m) for v, m in volumes_list]
Slice-sampling probabilities are computed from the masks; they are used later for weighted random sampling.
# probabilities for sampling slices based on masks
slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in volumes_list]
slice_weights = [
    (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in slice_weights
]
print(len(slice_weights))     # 100 (one weight vector per training patient)
print(len(slice_weights[0]))  # 21 (one weight per slice of the first patient)
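The weight formula is easier to see on a tiny made-up example: each slice gets its mask area plus a smoothing term of 10% of the mean mask area, normalized so the weights sum to 1, which keeps tumor-free slices at a small but nonzero sampling probability:

import numpy as np

# Hypothetical per-slice mask areas for one patient: two empty slices, two with tumor.
s = np.array([0.0, 120.0, 80.0, 0.0])
w = (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1)
print(w)        # [0.02272727 0.56818182 0.38636364 0.02272727]
print(w.sum())  # 1.0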
The masks are only three-dimensional, so a channel dimension needs to be added.
# add channel dimension to masks
volumes_list = [(v, m[..., np.newaxis]) for (v, m) in volumes_list]  # "..." is equivalent to [:, :, :]
One index is the patient index, the other is the slice index.
# create global index for patient and slice (idx -> (p_idx, s_idx))
num_slices = [v.shape[0] for v, m in volumes_list]  # number of slices per patient
patient_slice_index = list(
    zip(
        sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
        sum([list(range(x)) for x in num_slices], []),
    )
)
zip() pairs up the corresponding elements of the two sequences into tuples and returns a list of those tuples.
Here sum(..., []) is used to flatten a list of lists into one flat list.
The final index therefore looks like: (first patient, first slice), (first patient, second slice), (first patient, third slice), ..., (second patient, first slice), ...
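A small self-contained example with made-up slice counts shows how the two sum(..., []) calls flatten the nested lists and how zip pairs them up:

# Hypothetical slice counts for three patients.
num_slices = [3, 2, 4]

patient_ids = sum([[i] * num_slices[i] for i in range(len(num_slices))], [])
slice_ids = sum([list(range(x)) for x in num_slices], [])
print(patient_ids)  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
print(slice_ids)    # [0, 1, 2, 0, 1, 0, 1, 2, 3]

patient_slice_index = list(zip(patient_ids, slice_ids))
print(patient_slice_index)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (2, 3)]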
__len__ returns the length of patient_slice_index, i.e. how many image/mask pairs we have. During training, __getitem__ produces one 2D pair at a time: the image is the input and the mask is the label.
idx = 50
patient = patient_slice_index[idx][0]
slice_n = patient_slice_index[idx][1]
v, m = volumes_list[patient]
image = v[slice_n]
mask = m[slice_n]
print(len(volumes_list))  # 100 patients
print(v.shape)  # (slices, H, W, channels) for this patient
print(m.shape)  # the corresponding masks
print(image.shape)
print(mask.shape)
100
(18, 224, 224, 3)
(18, 224, 224, 1)
(224, 224, 3)
(224, 224, 1)
During training, a patient is first chosen at random, and then one of that patient's slices is sampled at random, weighted by slice_weights.
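In terms of the walkthrough variables, the random-sampling branch of __getitem__ (shown in the dataset class in the full code below) boils down to roughly this sketch:

# Weighted random sampling, mirroring the dataset's __getitem__ when random_sampling=True.
patient = np.random.randint(len(volumes_list))       # pick a patient uniformly at random
slice_n = np.random.choice(
    range(volumes_list[patient][0].shape[0]),         # that patient's slice indices
    p=slice_weights[patient],                         # mask-area-based probabilities
)
v, m = volumes_list[patient]
image, mask = v[slice_n], m[slice_n]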
# fix dimensions (C, H, W)
image = image.transpose(2, 0, 1)
mask = mask.transpose(2, 0, 1)
image_tensor = torch.from_numpy(image.astype(np.float32))
mask_tensor = torch.from_numpy(mask.astype(np.float32))
Finally, since training requires tensors, the channel dimension is moved to the front (C, H, W) and the arrays are converted to torch tensors.
The standard evaluation metric for segmentation here is the Dice similarity coefficient (DSC). Dice is a set-similarity measure commonly used to compare two samples; it ranges from 0 to 1, with 1 for a perfect segmentation and 0 for the worst case.
Details: commonly used segmentation metrics such as DSC, Hausdorff_95, IoU, PPV, etc.
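For binary masks this is DSC = 2|X∩Y| / (|X|+|Y|), which is exactly what the dsc helper in the full code computes; a tiny made-up example:

import numpy as np

def dsc(y_pred, y_true):
    # Binarize and compute 2*|intersection| / (|pred| + |true|).
    y_pred = np.round(y_pred).astype(int)
    y_true = np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))

# 3 of the 4 predicted pixels overlap the 4 true pixels -> DSC = 2*3 / (4+4) = 0.75
pred = np.array([1, 1, 1, 1, 0, 0])
true = np.array([1, 1, 1, 0, 1, 0])
print(dsc(pred, true))  # 0.75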
There is not much to explain about the training loop itself; it follows the standard PyTorch workflow. Training results:
reading train images...
preprocessing train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...
normalizing train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset
epoch 1 | loss: 0.8733632518694951
epoch 1 | val_loss: 0.9460358023643494
epoch 1 | val_dsc: 0.1948670436467475
epoch 2 | loss: 0.8402993633196905
epoch 2 | val_loss: 0.931992749373118
epoch 2 | val_dsc: 0.4118475790023708
epoch 3 | loss: 0.8270544547301072
epoch 3 | val_loss: 0.9293850064277649
epoch 3 | val_dsc: 0.4527188003808261
epoch 4 | loss: 0.8215367862811456
epoch 4 | val_loss: 0.9270628492037455
epoch 4 | val_dsc: 0.6138052787773625
epoch 5 | loss: 0.8171369204154382
epoch 5 | val_loss: 0.9256295363108317
epoch 5 | val_dsc: 0.4911489058649566
epoch 6 | loss: 0.8134041795363793
epoch 6 | val_loss: 0.9244295557339987
epoch 6 | val_dsc: 0.6963948791553534
epoch 7 | loss: 0.8091526673390315
epoch 7 | val_loss: 0.9240273038546244
epoch 7 | val_dsc: 0.7177726293762504
epoch 8 | loss: 0.8052632143864265
epoch 8 | val_loss: 0.9218348860740662
epoch 8 | val_dsc: 0.7100711862449238
epoch 9 | loss: 0.801162777038721
epoch 9 | val_loss: 0.9223942359288534
epoch 9 | val_dsc: 0.7482642381530232
epoch 10 | loss: 0.7974189153084388
epoch 10 | val_loss: 0.9184039433797201
.....
Best validation mean DSC: 0.855153
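The loss values in this log are the soft Dice loss (1 − DSC with a smoothing term), computed on the sigmoid probabilities by the DiceLoss class from the full code at the end of the post:

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
        self.smooth = 1.0

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        y_pred = y_pred[:, 0].contiguous().view(-1)
        y_true = y_true[:, 0].contiguous().view(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1. - dsc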
Let's look at how the DSC is computed during validation.
validation_true = []
for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)
    print(y_true.shape)
    y_true_np = y_true.detach().cpu().numpy()
    print(y_true_np.shape)
    validation_true.extend(
        [y_true_np[s] for s in range(y_true_np.shape[0])]
    )
    print(len(validation_true))
    print(validation_true[0].shape)
    break
torch.Size([128, 1, 256, 256])
(128, 1, 256, 256)
128
(1, 256, 256)
Here I chose a batch_size of 128, so each iteration yields 128 slices, which are stored in validation_true and validation_pred.
mean_dsc = np.mean(
    dsc_per_volume(
        validation_pred,
        validation_true,
        loader_valid.dataset.patient_slice_index,
    )
)
At the end of each epoch, the per-volume DSC values over all validation data are averaged.
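dsc_per_volume (from the full code) splits the flat per-slice lists back into per-patient groups using np.bincount on patient_slice_index and computes one DSC per patient:

def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list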
if mean_dsc > best_validation_dsc:
    best_validation_dsc = mean_dsc
    torch.save(unet.state_dict(), os.path.join(weights, "unet.pt"))
Whenever a better result appears, the model weights are saved.
We then use the saved model to run predictions.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

loader_train, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle)
loaders = {"train": loader_train, "valid": loader_valid}

unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
unet.to(device)

state_dict = torch.load(os.path.join('./', "unet.pt"))
unet.load_state_dict(state_dict)
unet.eval()

input_list = []
pred_list = []
true_list = []

for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)
    with torch.set_grad_enabled(False):
        y_pred = unet(x)
        y_pred_np = y_pred.detach().cpu().numpy()
        pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
        y_true_np = y_true.detach().cpu().numpy()
        true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
        x_np = x.detach().cpu().numpy()
        input_list.extend([x_np[s] for s in range(x_np.shape[0])])
The predictions are then processed with the following function. So far we only have per-slice results; we need to group them back by patient so the segmentation results can be inspected more easily.
def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes
Let's look at how this function is used.
First, count how many slices each patient has in the validation/test set. We already looked at loader_valid.dataset.patient_slice_index during data loading: like a 2D coordinate it identifies every slice, (first patient, first slice), (first patient, second slice), and so on.
np.bincount over the first coordinate of that index then gives the number of slices per patient.
With those counts, the per-slice predictions can be grouped back together by patient.
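A quick illustration of the bincount step, with a hypothetical index for three patients:

import numpy as np

# Hypothetical (patient, slice) index: patient 0 has 3 slices, patient 1 has 2, patient 2 has 4.
patient_slice_index = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (2, 3)]
num_slices = np.bincount([p[0] for p in patient_slice_index])
print(num_slices)  # [3 2 4] -> per-patient slice counts, used to split the flat prediction list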
Then the Dice score of each patient is computed and plotted.
dsc_dist = dsc_distribution(volumes)
dsc_dist_plot = plot_dsc(dsc_dist)
imsave("./dsc.png", dsc_dist_plot)
The red line is the mean, the green line is the median.
Finally, let's look at the segmentation results.
for p in volumes:
    x = volumes[p][0]
    y_pred = volumes[p][1]
    y_true = volumes[p][2]
    for s in range(x.shape[0]):
        image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
        image = outline(image, y_pred[s, 0], color=[255, 0, 0])
        image = outline(image, y_true[s, 0], color=[0, 255, 0])
        filename = "{}-{}.png".format(p, str(s).zfill(2))
        filepath = os.path.join("./resultat", filename)
        imsave(filepath, image)
The predicted contour (red) and the ground truth (green) are drawn on top of the original image.
import os
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from tqdm import tqdm
from skimage.exposure import rescale_intensity
from skimage.io import imread, imsave
from skimage.transform import resize, rescale, rotate
from torch.utils.data import Dataset
from torchvision.transforms import Compose
def crop_sample(x):
    volume, mask = x
    volume[volume < np.max(volume) * 0.1] = 0
    z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1)
    z_nonzero = np.nonzero(z_projection)
    z_min = np.min(z_nonzero)
    z_max = np.max(z_nonzero) + 1
    y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1)
    y_nonzero = np.nonzero(y_projection)
    y_min = np.min(y_nonzero)
    y_max = np.max(y_nonzero) + 1
    x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1)
    x_nonzero = np.nonzero(x_projection)
    x_min = np.min(x_nonzero)
    x_max = np.max(x_nonzero) + 1
    return (
        volume[z_min:z_max, y_min:y_max, x_min:x_max],
        mask[z_min:z_max, y_min:y_max, x_min:x_max],
    )

def pad_sample(x):
    volume, mask = x
    a = volume.shape[1]
    b = volume.shape[2]
    if a == b:
        return volume, mask
    diff = (max(a, b) - min(a, b)) / 2.0
    if a > b:
        padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff))))
    else:
        padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0))
    mask = np.pad(mask, padding, mode="constant", constant_values=0)
    padding = padding + ((0, 0),)
    volume = np.pad(volume, padding, mode="constant", constant_values=0)
    return volume, mask

def resize_sample(x, size=256):
    volume, mask = x
    v_shape = volume.shape
    out_shape = (v_shape[0], size, size)
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    out_shape = out_shape + (v_shape[3],)
    volume = resize(
        volume,
        output_shape=out_shape,
        order=2,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )
    return volume, mask

def normalize_volume(volume):
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))
    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))
    volume = (volume - m) / s
    return volume
def transforms(scale=None, angle=None, flip_prob=None):
    transform_list = []
    if scale is not None:
        transform_list.append(Scale(scale))
    if angle is not None:
        transform_list.append(Rotate(angle))
    if flip_prob is not None:
        transform_list.append(HorizontalFlip(flip_prob))
    return Compose(transform_list)

class Scale(object):
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, sample):
        image, mask = sample
        img_size = image.shape[0]
        scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)
        image = rescale(
            image,
            (scale, scale),
            multichannel=True,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        mask = rescale(
            mask,
            (scale, scale),
            order=0,
            multichannel=True,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        if scale < 1.0:
            diff = (img_size - image.shape[0]) / 2.0
            padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),)
            image = np.pad(image, padding, mode="constant", constant_values=0)
            mask = np.pad(mask, padding, mode="constant", constant_values=0)
        else:
            x_min = (image.shape[0] - img_size) // 2
            x_max = x_min + img_size
            image = image[x_min:x_max, x_min:x_max, ...]
            mask = mask[x_min:x_max, x_min:x_max, ...]
        return image, mask

class Rotate(object):
    def __init__(self, angle):
        self.angle = angle

    def __call__(self, sample):
        image, mask = sample
        angle = np.random.uniform(low=-self.angle, high=self.angle)
        image = rotate(image, angle, resize=False, preserve_range=True, mode="constant")
        mask = rotate(
            mask, angle, resize=False, order=0, preserve_range=True, mode="constant"
        )
        return image, mask

class HorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, sample):
        image, mask = sample
        if np.random.rand() > self.flip_prob:
            return image, mask
        image = np.fliplr(image).copy()
        mask = np.fliplr(mask).copy()
        return image, mask
class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation"""

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for (dirpath, dirnames, filenames) in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:
                    mask_slices.append(imread(filepath, as_gray=True))
                else:
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]
                volumes[patient_id] = np.array(image_slices[1:-1])
                masks[patient_id] = np.array(mask_slices[1:-1])

        self.patients = sorted(volumes)

        # select cases to subset
        if not subset == "all":
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=10)
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume
        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))
        # pad to square
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        # probabilities for sampling slices based on masks
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
        self.slice_weights = [
            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1)
            for s in self.slice_weights
        ]

        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling

        self.transform = transform

    def __len__(self):
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        patient = self.patient_slice_index[idx][0]
        slice_n = self.patient_slice_index[idx][1]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        # return tensors
        return image_tensor, mask_tensor
def data_loaders(batch_size, workers, image_size, aug_scale, aug_angle):
    dataset_train, dataset_valid = datasets(
        './archive/lgg-mri-segmentation/kaggle_3m', image_size, aug_scale, aug_angle
    )

    def worker_init(worker_id):
        np.random.seed(42 + worker_id)

    loader_train = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=batch_size,
        drop_last=False,
        num_workers=workers,
        worker_init_fn=worker_init,
    )

    return loader_train, loader_valid
def datasets(images, image_size, aug_scale, aug_angle):
    train = BrainSegmentationDataset(
        images_dir=images,
        subset="train",
        image_size=image_size,
        transform=transforms(scale=aug_scale, angle=aug_angle, flip_prob=0.5),
    )
    valid = BrainSegmentationDataset(
        images_dir=images,
        subset="validation",
        image_size=image_size,
        random_sampling=False,
    )
    return train, valid
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
        self.smooth = 1.0

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        y_pred = y_pred[:, 0].contiguous().view(-1)
        y_true = y_true[:, 0].contiguous().view(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return 1. - dsc

def log_images(x, y_true, y_pred, channel=1):
    images = []
    x_np = x[:, channel].cpu().numpy()
    y_true_np = y_true[:, 0].cpu().numpy()
    y_pred_np = y_pred[:, 0].cpu().numpy()
    for i in range(x_np.shape[0]):
        image = gray2rgb(np.squeeze(x_np[i]))
        image = outline(image, y_pred_np[i], color=[255, 0, 0])
        image = outline(image, y_true_np[i], color=[0, 255, 0])
        images.append(image)
    return images

def gray2rgb(image):
    w, h = image.shape
    image += np.abs(np.min(image))
    image_max = np.abs(np.max(image))
    if image_max > 0:
        image /= image_max
    ret = np.empty((w, h, 3), dtype=np.uint8)
    ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255
    return ret

def outline(image, mask, color):
    mask = np.round(mask)
    yy, xx = np.nonzero(mask)
    for y, x in zip(yy, xx):
        if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0:
            image[max(0, y) : y + 1, max(0, x) : x + 1] = color
    return image
def dsc(y_pred, y_true):
    y_pred = np.round(y_pred).astype(int)
    y_true = np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))

def dsc_distribution(volumes):
    dsc_dict = {}
    for p in volumes:
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        dsc_dict[p] = dsc(y_pred, y_true)
    return dsc_dict

def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list

def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        volume_in = np.array(input_list[index : index + num_slices[p]])
        volume_pred = np.round(
            np.array(pred_list[index : index + num_slices[p]])
        ).astype(int)
        volume_true = np.array(true_list[index : index + num_slices[p]])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index += num_slices[p]
    return volumes

def log_loss_summary(loss, step, prefix=""):
    print("epoch {} | {}: {}".format(step + 1, prefix + "loss", np.mean(loss)))

def log_scalar_summary(tag, value, step):
    print("epoch {} | {}: {}".format(step + 1, tag, value))

def plot_dsc(dsc_dist):
    y_positions = np.arange(len(dsc_dist))
    dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
    values = [x[1] for x in dsc_dist]
    labels = [x[0] for x in dsc_dist]
    labels = ["_".join(l.split("_")[1:-1]) for l in labels]
    fig = plt.figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    plt.barh(y_positions, values, align="center", color="skyblue")
    plt.yticks(y_positions, labels)
    plt.xticks(np.arange(0.0, 1.0, 0.1))
    plt.xlim([0.0, 1.0])
    plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
    plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
    plt.xlabel("Dice coefficient", fontsize="x-large")
    plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    plt.tight_layout()
    canvas.draw()
    plt.close()
    s, (width, height) = canvas.print_to_buffer()
    return np.fromstring(s, np.uint8).reshape((height, width, 4))
batch_size = 128
epochs = 300
lr = 0.0001
workers = 8
weights = "./"
image_size = 256
aug_scale = 0.05
aug_angle = 15
def train_validate():
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

    loader_train, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=BrainSegmentationDataset.in_channels, out_channels=BrainSegmentationDataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=lr)

    loss_train = []
    loss_valid = []

    step = 0

    for epoch in range(epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)
                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

            if phase == "train":
                log_loss_summary(loss_train, epoch)
                loss_train = []

            if phase == "valid":
                log_loss_summary(loss_valid, epoch, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                log_scalar_summary("val_dsc", mean_dsc, epoch)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(), os.path.join(weights, "unet.pt"))
                loss_valid = []

    print("\nBest validation mean DSC: {:4f}\n".format(best_validation_dsc))

    state_dict = torch.load(os.path.join(weights, "unet.pt"))
    unet.load_state_dict(state_dict)
    unet.eval()

    input_list = []
    pred_list = []
    true_list = []

    for i, data in enumerate(loader_valid):
        x, y_true = data
        x, y_true = x.to(device), y_true.to(device)
        with torch.set_grad_enabled(False):
            y_pred = unet(x)
            y_pred_np = y_pred.detach().cpu().numpy()
            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
            y_true_np = y_true.detach().cpu().numpy()
            true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
            x_np = x.detach().cpu().numpy()
            input_list.extend([x_np[s] for s in range(x_np.shape[0])])

    volumes = postprocess_per_volume(
        input_list,
        pred_list,
        true_list,
        loader_valid.dataset.patient_slice_index,
        loader_valid.dataset.patients,
    )

    dsc_dist = dsc_distribution(volumes)

    dsc_dist_plot = plot_dsc(dsc_dist)
    imsave("./dsc.png", dsc_dist_plot)

    for p in volumes:
        x = volumes[p][0]
        y_pred = volumes[p][1]
        y_true = volumes[p][2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            filepath = os.path.join("./resultat", filename)
            imsave(filepath, image)
train_validate()