
U_Net Semantic Segmentation: The Complete Version


1、Background

Since nobody reads my blog....., I've decided to post a small project, along with an eye-catching cover image.

(cover image)

2、The Complete U_Net

Most U_Net implementations published online are aimed at grayscale images. Color RGB images carry additional color information, so this project takes the more informative color images as the network input and performs a 3-class (background included) segmentation of the target images.

First, a look at the project file structure:

1、dataprocess.py   ---->> defines the data loading; transforms and other preprocessing can be applied while reading.

2、metrics.py   ---->> defines the semantic-segmentation evaluation metric, miou.

3、model.py  ---->> defines the U_Net model architecture.

4、train.py  ---->> defines the full training loop.

5、utils  ---->> holds scripts for label-data preprocessing, speed-testing the trained model, visualization, etc.

3、Data Loading

dataprocess.py:

from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import os


class Mydataset(Dataset):
    # Class ids present in the label maps (0 is background).
    CLASSES = [0, 1, 2]

    def __init__(self, images_dir: str, masks_dir: str, nb_classes, classes=None, transform=None):
        super().__init__()
        classes = classes if classes is not None else self.CLASSES
        self.class_values = [self.CLASSES.index(cls) for cls in classes]
        self.nb_classes = nb_classes
        # Each image <name>.<ext> is paired with a label map <name>.npy.
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id.split('.')[0] + '.npy') for image_id in self.ids]
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        image = Image.open(self.images_fps[i])
        mask = np.load(self.masks_fps[i])
        # Anything outside the valid class range is folded into the background.
        mask[mask > self.nb_classes - 1] = 0
        # Resize the label map with nearest-neighbour interpolation so the
        # class ids stay integer-valued (bilinear would mix labels).
        mask = Image.fromarray(mask)
        mask = transforms.Resize((48, 64), Image.NEAREST)(mask)
        mask = np.array(mask)
        if self.transform is not None:
            image = self.transform(image)
        return image, mask


def to_categorical(y, num_classes=None, dtype='float32'):
    """One-hot encode an integer label array (Keras-style helper)."""
    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical
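For reference, a minimal usage sketch, assuming the listing above is saved as dataprocess.py and that the directory paths below are placeholders for your own data, not the author's:

# Usage sketch (placeholder paths; adjust to your own data layout).
from torchvision import transforms
from torch.utils.data import DataLoader
from dataprocess import Mydataset

transform = transforms.Compose([
    transforms.Resize((48, 64)),
    transforms.ToTensor(),
])
dataset = Mydataset(images_dir="train_new/images", masks_dir="train_new/masks",
                    nb_classes=3, classes=[0, 1, 2], transform=transform)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
image_batch, mask_batch = next(iter(loader))
print(image_batch.shape)  # torch.Size([2, 3, 48, 64]) -- normalized image tensors
print(mask_batch.shape)   # torch.Size([2, 48, 64])    -- integer class ids 0..2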

4、Evaluation Metric

metrics.py:

import torch
import torch.nn as nn
import numpy as np
from dataprocess import to_categorical


class IoUMetric(nn.Module):
    __name__ = 'iou'

    def __init__(self, eps=1e-7, threshold=0.5, activation='sigmoid'):
        super().__init__()
        self.activation = activation
        self.eps = eps
        self.threshold = threshold

    def forward(self, y_pr, y_gt):
        return iou(y_pr, y_gt, self.eps, self.threshold, self.activation)


def iou(pr, gt, eps=1e-7, threshold=None, activation='sigmoid'):
    # Map the raw network output to probabilities.
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = torch.nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = torch.nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d")
    pr = activation_fn(pr)

    iou_all = 0
    smooth = 1
    # Hard class assignment, then one-hot encode both prediction and ground truth.
    pr = torch.argmax(pr, dim=1)
    pr = pr.cpu().numpy()
    gt = gt.cpu().numpy()
    pr = to_categorical(pr, num_classes=3)
    gt = to_categorical(gt, num_classes=3)
    nb_classes = 3
    # Per-class IoU, averaged over the batch.
    for i in range(nb_classes):
        res_true = gt[:, :, :, i:i + 1].astype(np.float64)
        res_pred = pr[:, :, :, i:i + 1].astype(np.float64)
        intersection = np.sum(np.abs(res_true * res_pred), axis=(1, 2, 3))
        union = np.sum(res_true, axis=(1, 2, 3)) + np.sum(res_pred, axis=(1, 2, 3)) - intersection
        iou_all += np.mean((intersection + smooth) / (union + smooth), axis=0)
    # Mean IoU over the classes (background included).
    return iou_all / nb_classes
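A quick sanity check of the metric on random tensors. This is only a sketch: it assumes the listing above is saved as metrics.py next to dataprocess.py, and follows the (batch, classes, H, W) prediction and (batch, H, W) label layout used by the model and dataset.

# Sanity-check sketch: random predictions and labels, 3 classes, 48x64 masks.
import torch
from metrics import IoUMetric

metric = IoUMetric(eps=1., activation="softmax2d")
y_pred = torch.rand(2, 3, 48, 64)          # raw network output, one channel per class
y_true = torch.randint(0, 3, (2, 48, 64))  # integer class ids
print(metric(y_pred, y_true))  # mean IoU over the 3 classes; typically around 0.2 for random input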

5、U_Net Model Architecture

model.py:

import torch
from torch import nn


class block_down(nn.Module):
    """Encoder block: two 3x3 convs, each followed by the (shared) BN and ReLU6."""
    def __init__(self, inp_channel, out_channel):
        super(block_down, self).__init__()
        self.conv1 = nn.Conv2d(inp_channel, out_channel, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding=1)
        # Note: the same BatchNorm layer is reused after both convs, as in the original.
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn(self.conv1(x)))
        x = self.relu(self.bn(self.conv2(x)))
        return x


class block_up(nn.Module):
    """Decoder block: transposed-conv upsampling, skip concatenation, then two 3x3 convs."""
    def __init__(self, inp_channel, out_channel):
        super(block_up, self).__init__()
        self.up = nn.ConvTranspose2d(inp_channel, out_channel, 2, stride=2)
        # After concatenation with the skip tensor the channel count is back to inp_channel.
        self.conv1 = nn.Conv2d(inp_channel, out_channel, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x, y):
        x = self.up(x)
        x = torch.cat([x, y], dim=1)  # skip connection from the encoder
        x = self.relu(self.bn(self.conv1(x)))
        x = self.relu(self.bn(self.conv2(x)))
        return x


class U_net(nn.Module):
    def __init__(self, out_channel):
        super(U_net, self).__init__()
        self.out = nn.Conv2d(64, out_channel, 1)
        self.maxpool = nn.MaxPool2d(2)
        self.block1 = block_down(3, 64)
        self.block2 = block_down(64, 128)
        self.block3 = block_down(128, 256)
        self.block4 = block_down(256, 512)
        self.block5 = block_down(512, 1024)
        self.block6 = block_up(1024, 512)
        self.block7 = block_up(512, 256)
        self.block8 = block_up(256, 128)
        self.block9 = block_up(128, 64)

    def forward(self, x):
        x1_use = self.block1(x)
        x1 = self.maxpool(x1_use)
        x2_use = self.block2(x1)
        x2 = self.maxpool(x2_use)
        x3_use = self.block3(x2)
        x3 = self.maxpool(x3_use)
        x4_use = self.block4(x3)
        x4 = self.maxpool(x4_use)
        x5 = self.block5(x4)
        x6 = self.block6(x5, x4_use)
        x7 = self.block7(x6, x3_use)
        x8 = self.block8(x7, x2_use)
        x9 = self.block9(x8, x1_use)
        x10 = self.out(x9)
        # Note: train.py uses nn.CrossEntropyLoss, which expects raw logits,
        # so this final sigmoid (kept from the original) is arguably redundant.
        out = torch.sigmoid(x10)
        return out


if __name__ == "__main__":
    test_input = torch.rand(1, 3, 48, 64).to("cuda")
    print("input_size:", test_input.size())
    model = U_net(out_channel=3)
    model.cuda()
    output = model(test_input)
    print("output_size:", output.size())
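One detail worth keeping in mind: the encoder halves the spatial resolution four times, so the input height and width must both be divisible by 16, otherwise the upsampled feature maps will not line up with the skip tensors inside torch.cat. A small sketch of that check (check_input_size is an illustrative helper, not part of the project):

# Sketch: the encoder halves H and W four times, so both must be multiples of 2**4 = 16
# for the decoder's skip connections to match up.
def check_input_size(h, w, num_pools=4):
    factor = 2 ** num_pools
    if h % factor or w % factor:
        raise ValueError("Input {}x{} must be divisible by {}".format(h, w, factor))
    return True

check_input_size(48, 64)    # OK: 48 = 3*16, 64 = 4*16
# check_input_size(50, 64)  # would raise ValueError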

6、Main Training Script

train.py:

import os
import datetime
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from model import U_net
from dataprocess import Mydataset
from metrics import IoUMetric
from tensorboardX import SummaryWriter
from torchvision import transforms
from torch.utils.data import DataLoader

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
max_score = 0
torch.backends.cudnn.benchmark = True


def val(model, device, val_loader, loss, metrics, epoch, timestamp):
    global max_score
    model.eval()
    test_loss = 0
    test_miou = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            x, y = data
            x = x.to(device)
            y = y.to(device).long()
            y_hat = model(x)
            test_loss += loss(y_hat, y).item()  # sum up batch loss
            test_miou += metrics(y_hat, y)
    test_loss /= len(val_loader)
    test_miou /= len(val_loader)
    writer.add_scalar('Val/Loss', test_loss, epoch)
    writer.add_scalar('Val/Miou', test_miou, epoch)
    print('\nTest set: Average loss: {:.4f}, Miou: {:.4f}\n'.format(test_loss, test_miou))
    # Keep the checkpoint with the best validation miou so far.
    if max_score < test_miou:
        max_score = test_miou
        os.makedirs('tmp/{}'.format(timestamp), exist_ok=True)
        torch.save(model, 'tmp/{}/{:.4f}_model.pth'.format(timestamp, max_score))
    return test_miou


def train(model, device, train_loader, epoch, optimizer, loss, metrics):
    total_trainloss = 0
    total_trainmiou = 0
    model.train()
    for batch_idx, data in enumerate(train_loader):
        x, y = data
        x = x.to(device)
        y = y.to(device).long()
        optimizer.zero_grad()
        y_hat = model(x)
        train_miou = metrics(y_hat, y)
        L = loss(y_hat, y)
        L.backward()
        optimizer.step()
        total_trainloss += float(L)
        total_trainmiou += float(train_miou)
        print("batch{}: train_miou:{:.4f} loss:{:.4f}".format(batch_idx, train_miou, L))
        if batch_idx % 10 == 0:
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/Loss', L, niter)
            writer.add_scalar('Train/Miou', train_miou, niter)
    total_trainloss /= len(train_loader)
    total_trainmiou /= len(train_loader)
    print('Train Epoch: {}\t Loss: {:.6f}, Miou: {:.4f}'.format(epoch, total_trainloss, total_trainmiou))


if __name__ == '__main__':
    DEVICE = 'cuda'
    nb_classes = 3
    batch_size = 2
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    writer = SummaryWriter('log/{}'.format(timestamp))
    # data locations
    x_train_dir = r"/home/anchao/桌面/U_Net/train_new/images"
    y_train_dir = r"/home/anchao/桌面/U_Net/train_new/masks"
    x_valid_dir = r"/home/anchao/桌面/U_Net/valid_new/images"
    y_valid_dir = r"/home/anchao/桌面/U_Net/valid_new/masks"
    # data loading
    train_transform = transforms.Compose([
        transforms.Resize((48, 64), 2),
        transforms.ToTensor(),
        transforms.Normalize([0.519401, 0.359217, 0.310136], [0.061113, 0.048637, 0.041166]),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((48, 64), 2),
        transforms.ToTensor(),
        transforms.Normalize([0.517446, 0.360147, 0.310427], [0.061526, 0.049087, 0.041330]),
    ])
    train_dataset = Mydataset(images_dir=x_train_dir, masks_dir=y_train_dir, nb_classes=3,
                              classes=[0, 1, 2], transform=train_transform)
    valid_dataset = Mydataset(images_dir=x_valid_dir, masks_dir=y_valid_dir, nb_classes=3,
                              classes=[0, 1, 2], transform=valid_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
    model = U_net(out_channel=3)
    model.cuda()
    criterion = nn.CrossEntropyLoss()
    metrics = IoUMetric(eps=1., activation="softmax2d")
    optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.001, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                           patience=5, verbose=True,
                                                           threshold=0.0001, threshold_mode='rel',
                                                           cooldown=0, min_lr=0, eps=1e-08)
    # train the model
    for epoch in range(0, 60):
        train(model=model, device=DEVICE, train_loader=train_loader, epoch=epoch,
              optimizer=optimizer, loss=criterion, metrics=metrics)
        test_miou = val(model=model, device=DEVICE, val_loader=valid_loader, loss=criterion,
                        metrics=metrics, epoch=epoch, timestamp=timestamp)
        scheduler.step(test_miou)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)
        print("current lr: {}".format(optimizer.param_groups[0]['lr']))
    writer.close()
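The per-channel mean/std values passed to Normalize look dataset specific; the sketch below shows one common way such statistics are computed over the training images. This is an assumption about how they were obtained, and channel_stats plus the path are placeholders, not the author's script:

# Sketch: per-channel mean/std over the training images, computed on [0, 1] tensors.
import os
import torch
from PIL import Image
from torchvision import transforms

def channel_stats(images_dir):
    to_tensor = transforms.Compose([transforms.Resize((48, 64)), transforms.ToTensor()])
    pixels = []
    for name in os.listdir(images_dir):
        img = to_tensor(Image.open(os.path.join(images_dir, name)).convert("RGB"))
        pixels.append(img.reshape(3, -1))
    pixels = torch.cat(pixels, dim=1)            # shape (3, total_pixels)
    return pixels.mean(dim=1), pixels.std(dim=1)

# mean, std = channel_stats("train_new/images")  # placeholder path
# print(mean, std)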

7、Utility Files

................. not included   ----->>> because the project still has a few rough spots, though it does run......
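One of the missing pieces is the label preprocessing: Mydataset expects one .npy label map per image. A minimal sketch of what such a conversion utility could look like (png_masks_to_npy, the file layout, and the assumption that the PNGs already store class ids 0..2 are all illustrative, not the author's utils):

# Sketch: convert single-channel label PNGs (pixel value = class id 0..2) into
# the .npy masks that Mydataset expects. Paths and naming are placeholders.
import os
import numpy as np
from PIL import Image

def png_masks_to_npy(png_dir, npy_dir):
    os.makedirs(npy_dir, exist_ok=True)
    for name in os.listdir(png_dir):
        mask = np.array(Image.open(os.path.join(png_dir, name)), dtype=np.uint8)
        out_name = os.path.splitext(name)[0] + ".npy"
        np.save(os.path.join(npy_dir, out_name), mask)

# png_masks_to_npy("train_new/masks_png", "train_new/masks")  # placeholder paths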

As you can see, by the second round of training the train-set miou reaches 0.7, which is quite respectable, but the test-set miou is only 0.45..... and it keeps dropping... hhhh. Possible reasons:

1、The images are too small. My graphics card is a GTX 1050, so the slightest carelessness triggers OOM; hence the batch size is 2 and the image size is (48, 64), which means the downsampling may well have turned into blind sampling.

2、To be determined.

For the polished version, please follow my GitHub, please follow me: https://github.com/2anchao
