
UNet Semantic Segmentation Network


1. References

[炼丹术]UNet图像分割模型相关总结_animalslin的技术博客_51CTO博客

https://cuijiahua.com/blog/2019/11/dl-14.html

Pytorch 深度学习实战教程(二):UNet语义分割网络 - 腾讯云开发者社区-腾讯云

2. Introduction to the UNet Network

UNet is a network used for semantic segmentation.

Semantic segmentation assigns a label to every pixel belonging to a target class, so that different kinds of objects are separated in the image. It can be understood as a pixel-level classification task: every pixel is classified.

Suppose there are five classes: Person, Purse, Plants/Grass, Sidewalk, and Building/Structures. We create a one-hot encoded target, i.e., one output channel per class. Since there are 5 classes, the network output also has 5 channels, as illustrated below:

[Figure: one-hot encoded target with one output channel per class]
Because the same pixel should not be 1 in more than one channel (arguably), the prediction can be collapsed into a single segmentation map by taking the argmax over the channel dimension at each pixel; each object can then be observed by overlaying this map on the image.
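As a minimal sketch (not from the original post), the argmax collapse looks like this in PyTorch:

import torch

# Hypothetical network output: batch of 1, 5 class channels, a 4x4 image
pred = torch.randn(1, 5, 4, 4)

# argmax over the channel dimension yields one class index per pixel
seg_map = torch.argmax(pred, dim=1)  # shape (1, 4, 4), values in 0..4
print(seg_map.shape, seg_map.unique())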

The UNet architecture is shown below (implementations keep the same idea but make slight adjustments):

[Figure: UNet architecture from the original paper]

3. Overall UNet Training Scheme

(1) Annotate the images semantically with labelme, which produces JSON files.

(2) Write code that reads the points information from each JSON file and generates a mask image from the original image.

(3) Feed 3-channel images into the UNet and predict a 1-channel mask (assuming a single target class). Fit by computing BCELoss between the predicted mask and the ground-truth mask (see the sketch below), and report accuracy and Dice score as monitoring metrics.
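A minimal sketch of the loss in step (3), assuming the network outputs raw logits (so BCEWithLogitsLoss is used instead of BCELoss on sigmoid outputs, matching train.py below):

import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + BCE in one numerically stable op

logits = torch.randn(1, 1, 160, 240)                    # predicted 1-channel mask (logits)
target = torch.randint(0, 2, (1, 1, 160, 240)).float()  # binary ground-truth mask

print(loss_fn(logits, target).item())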

4. UNet Implementation Analysis

(1) Polygon annotation with labelme

After annotation is finished, a JSON file is generated in the same directory as the image.
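For reference, an abridged labelme JSON file has roughly this structure (the field names are labelme's; the values here are made up):

{
  "shapes": [
    {
      "label": "cat",
      "points": [[102.0, 88.5], [240.3, 90.1], [233.7, 210.4]],
      "shape_type": "polygon"
    }
  ],
  "imagePath": "cat_001.jpg",
  "imageHeight": 480,
  "imageWidth": 640
}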

(2) Generating mask images from the JSON files

File: json2mask.py

import os
import json

import cv2
import numpy as np
from PIL import Image, ImageDraw

CLASS_NAMES = ['dog', 'cat']


def make_mask(image_dir, save_dir):
    # Collect the labelme JSON files in the image directory
    json_files = [f for f in os.listdir(image_dir) if f.split('.')[-1] == 'json']
    for js in json_files:
        with open(os.path.join(image_dir, js), 'r') as f:
            json_data = json.load(f)
        shapes_ = json_data['shapes']
        # Mode 'P' creates an 8-bit palette image the same size as the photo
        mask = Image.new('P', Image.open(os.path.join(image_dir, js.replace('json', 'jpg'))).size)
        mask_draw = ImageDraw.Draw(mask)
        for shape_ in shapes_:
            label = shape_['label']
            points = tuple(tuple(p) for p in shape_['points'])
            # Fill each polygon with (class index + 1); 0 is background
            mask_draw.polygon(points, fill=CLASS_NAMES.index(label) + 1)
        mask = np.array(mask) * 255  # scale label 1 up to 255 (single-class case)
        cv2.imshow('mask', mask)
        cv2.waitKey(0)  # press any key to continue to the next image
        cv2.imwrite(os.path.join(save_dir, js.replace('json', 'jpg')), mask)


def vis_label(img):
    # Print the set of pixel values that actually occur in a mask image
    img = np.array(Image.open(img))
    print(set(img.reshape(-1).tolist()))


if __name__ == '__main__':
    make_mask('D:\\ai_data\\cat\\val', 'D:\\ai_data\\cat\\val_mask')

Notes:

  • In Image.new, mode='P' produces an 8-bit palette image, which is well suited to mask images.
  • A pixel value of 255 is white, so the masked region is white and the background is black. Because the masks are saved as JPEG (lossy compression), the foreground actually contains values in the (249, 255) range, so they need to be thresholded again before use; see the sketch below.
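A small sketch of that cleanup step (binarize_mask is our own hypothetical helper, not part of the original code):

import numpy as np
from PIL import Image

def binarize_mask(path, threshold=200):
    # Map near-white JPEG pixels to 1 and everything else to 0
    mask = np.array(Image.open(path).convert('L'))
    return (mask > threshold).astype(np.float32)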

(3) UNet network construction

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # padding=0 as in the paper, so each conv shrinks H and W by 2
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Note: bilinear upsampling keeps the channel count, so the
            # channel arithmetic below only works out with bilinear=False
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW; pad x1 so it matches the larger skip feature x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)
    x = torch.randn([1, 3, 572, 572])
    out = net(x)
    print(out.shape)  # torch.Size([1, 1, 564, 564]): smaller than the input because padding=0

Notes:

  • This code follows the UNet paper. It is not the code actually used in this project; it needs some modification first.
  • It is suitable for studying the UNet architecture. For anything unclear, see: Pytorch 深度学习实战教程(二):UNet语义分割网络 - 腾讯云开发者社区-腾讯云
  • Downsampling is done with convolutions and max pooling.
  • Upsampling is done with transposed convolutions plus concatenation with the corresponding downsampling features.
  • Upsampling halves the channel dimension, e.g., from 1024 to 512; after concatenating with the downsampling feature (also 512 channels) along dim=1 (the channel dimension), the channel count becomes 1024 again.
  • In the paper, the output width and height differ from the input's, which causes problems when comparing against mask images, so we want the input and output sizes to match. Setting padding=1 keeps w and h unchanged through double_conv, so only pooling halves them and upsampling doubles them back (see the sanity check below).
  • The code above is only for understanding UNet and is not part of the project.
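A quick standalone sanity check of the padding=1 point (not part of the project code):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 160, 240)
conv_same = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
print(conv_same(x).shape)   # torch.Size([1, 64, 160, 240]): H and W preserved

conv_valid = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0)
print(conv_valid(x).shape)  # torch.Size([1, 64, 158, 238]): shrinks by 2, as in the paper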

(4) Main script: train.py

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
# from unet_model_new import UNet
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

# Hyperparameters
learning_rate = 1e-4
device = 'cpu'
batch_size = 1
num_epochs = 30
num_workers = 0
image_height = 160
image_width = 240
pin_memory = False
load_model = False
train_img_dir = "D:\\ai_data\\cat\\train2"
train_mask_dir = "D:\\ai_data\\cat\\train2_mask"
val_img_dir = "D:\\ai_data\\cat\\val2"
val_mask_dir = "D:\\ai_data\\cat\\val2_mask"


def train_fn(loader, model, optimizer, loss_fn):
    for batch_idx, (data, targets) in enumerate(tqdm(loader)):
        data = data.to(device=device)
        # masks come out of the dataset as HxW; add the channel dimension
        targets = targets.float().unsqueeze(1).to(device=device)
        predictions = model(data)
        loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # missing in the original listing; without it no weights are updated


def main():
    train_transform = A.Compose(
        [
            A.Resize(height=image_height, width=image_width),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0
            ),
            ToTensorV2(),
        ],
    )
    val_transform = A.Compose(
        [
            A.Resize(height=image_height, width=image_width),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0
            ),
            ToTensorV2(),
        ],
    )
    model = UNET(in_channels=3, out_channels=1).to(device)
    # BCEWithLogitsLoss applies sigmoid internally, so the model outputs raw logits
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_loader, val_loader = get_loaders(
        train_img_dir,
        train_mask_dir,
        val_img_dir,
        val_mask_dir,
        batch_size,
        train_transform,
        val_transform,
        num_workers,
        pin_memory
    )
    if load_model:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)
        check_accuracy(-1, "val", val_loader, model, device=device)
    for epoch in range(num_epochs):
        train_fn(train_loader, model, optimizer, loss_fn)
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)
        check_accuracy(epoch, "train", train_loader, model, device=device)
        check_accuracy(epoch, "val", val_loader, model, device=device)
        save_predictions_as_imgs(val_loader, model, folder="saved_images/", device=device)


if __name__ == "__main__":
    main()

(5) Data loading: dataset.py

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        # masks are assumed to be named <image>_mask.jpg; adjust this to match
        # however json2mask.py actually names them
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.jpg"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        # after grayscale conversion the foreground is not uniformly 255, so threshold
        mask[mask > 200.0] = 1.0
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask
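A quick usage check (the paths are placeholders, reusing the directories from train.py above):

from dataset import CarvanaDataset

ds = CarvanaDataset('D:\\ai_data\\cat\\val2', 'D:\\ai_data\\cat\\val2_mask')
image, mask = ds[0]
print(image.shape, mask.shape)  # e.g. (H, W, 3) and (H, W) before any transform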

(6) Model: model.py

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # padding=1 keeps H and W unchanged through each conv
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                # with padding=1 this branch is only reached when the input
                # H/W is not divisible by 16 (pooling uses floor division)
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)
        return self.final_conv(x)


def test():
    x = torch.randn((3, 1, 572, 572))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    assert preds.shape == x.shape  # with padding=1, output size matches the input


if __name__ == "__main__":
    test()

(7) Utilities: utils.py

import torch
import torchvision
from torch.utils.data import DataLoader

from dataset import CarvanaDataset


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])


def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )
    val_ds = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )
    return train_loader, val_loader


def check_accuracy(epoch, attr, loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            # threshold the sigmoid output at 0.5 to get a binary mask
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
    print(f"{attr}_{epoch+1}: Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
    print(f"{attr}_{epoch+1}: Dice score: {dice_score/len(loader)}")
    model.train()


def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(preds, f"{folder}pred_{idx}.png")
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
    model.train()

(8) Notes on the Dice score monitoring metric

Reference: 关于图像分割的评价指标dice_Pierce_KK的博客-CSDN博客_dice评价指标

The Dice coefficient is also used in machine learning. Its expression is:

    Dice = 2|X ∩ Y| / (|X| + |Y|)

This is the same as the F1 evaluation metric in machine learning.

Precision:

    P = TP / (TP + FP)

Recall:

    R = TP / (TP + FN)

F1 is the harmonic mean of precision and recall:

    F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN)
The Dice coefficient is a common metric in medical imaging and is often used to judge how good an image segmentation algorithm is. Understood intuitively from the formula, it is twice the area where the two regions (prediction and ground truth) overlap, divided by the sum of their areas; a perfect segmentation gives a value of 1.

[Figure: Dice as the overlap between two regions]
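A standalone sketch of the Dice computation for binary masks, mirroring the formula above and the computation in check_accuracy:

import torch

def dice_score(pred, target, eps=1e-8):
    # pred and target are binary {0, 1} tensors of the same shape
    intersection = (pred * target).sum()
    return (2 * intersection) / (pred.sum() + target.sum() + eps)

pred = torch.tensor([[1., 1., 0.], [0., 1., 0.]])
target = torch.tensor([[1., 0., 0.], [0., 1., 1.]])
print(dice_score(pred, target))  # tensor(0.6667) = 2*2 / (3 + 3)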

 

In this experiment, accuracy reached 60%+ while the Dice score was only 0.4+, so the overall result is not good.

 

5. Subsequent Development of UNet

(1) The core ideas of the UNet network:

  • Downsampling + upsampling as the overall network structure (Encoder-Decoder)
  • Multi-scale feature fusion
  • The way information flows through the network
  • Producing a pixel-level segmentation map

(2) For views on improving UNet, see: 谈一谈UNet图像分割_3D视觉工坊的博客-CSDN博客

Many people like to build improvements on top of UNet: swap in a better encoder, then hand-implement a matching decoder. As for why improvements are so often built on UNet, it is probably because the structure is simple, and its results in many scenarios are at least passable.

Compared with the later series of variants, a weakness of the original UNet design lies in information fusion and in keeping features positionally aligned.
