当前位置:   article > 正文

Python Unet++:医学图像分割,医学细胞分割,Unet医学图像处理,语义分割

from LovaszSoftmax.pytorch import lovasz_losses as L

一,语义分割:分割领域前几年的发展

图像分割是机器视觉任务的一个重要基础任务,在图像分析、自动驾驶、视频监控等方面都有很重要的作用。图像分割可以被看成一个分类任务,需要给每个像素进行分类,所以就比图像分类任务更加复杂。此处主要介绍 Deep Learning-based 相关方法。

 

 

 

 主要介绍unet和unet++

 

二,数据介绍---医学细胞分割任务

原数据:

标签数据: 

 

 三,代码部分

模型包含以下文件:

archs.py为模型的主体部分:

  1. import torch
  2. from torch import nn
  3. __all__ = ['UNet', 'NestedUNet']
  4. class VGGBlock(nn.Module):
  5. def __init__(self, in_channels, middle_channels, out_channels):
  6. super().__init__()
  7. self.relu = nn.ReLU(inplace=True)
  8. self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
  9. self.bn1 = nn.BatchNorm2d(middle_channels)
  10. self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
  11. self.bn2 = nn.BatchNorm2d(out_channels)
  12. def forward(self, x):
  13. out = self.conv1(x)
  14. out = self.bn1(out)
  15. out = self.relu(out)
  16. out = self.conv2(out)
  17. out = self.bn2(out)
  18. out = self.relu(out)
  19. return out
  20. class UNet(nn.Module):
  21. def __init__(self, num_classes, input_channels=3, **kwargs):
  22. super().__init__()
  23. nb_filter = [32, 64, 128, 256, 512]
  24. self.pool = nn.MaxPool2d(2, 2)
  25. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)#scale_factor:放大的倍数 插值
  26. self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
  27. self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
  28. self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
  29. self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
  30. self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
  31. self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
  32. self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
  33. self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
  34. self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
  35. self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
  36. def forward(self, input):
  37. x0_0 = self.conv0_0(input)
  38. x1_0 = self.conv1_0(self.pool(x0_0))
  39. x2_0 = self.conv2_0(self.pool(x1_0))
  40. x3_0 = self.conv3_0(self.pool(x2_0))
  41. x4_0 = self.conv4_0(self.pool(x3_0))
  42. x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
  43. x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
  44. x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
  45. x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))
  46. output = self.final(x0_4)
  47. return output
  48. class NestedUNet(nn.Module):
  49. def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
  50. super().__init__()
  51. nb_filter = [32, 64, 128, 256, 512]
  52. self.deep_supervision = deep_supervision
  53. self.pool = nn.MaxPool2d(2, 2)
  54. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  55. self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
  56. self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
  57. self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
  58. self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
  59. self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
  60. self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
  61. self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
  62. self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
  63. self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
  64. self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
  65. self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
  66. self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
  67. self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
  68. self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
  69. self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
  70. if self.deep_supervision:
  71. self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
  72. self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
  73. self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
  74. self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
  75. else:
  76. self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
  77. def forward(self, input):
  78. print('input:',input.shape)
  79. x0_0 = self.conv0_0(input)
  80. print('x0_0:',x0_0.shape)
  81. x1_0 = self.conv1_0(self.pool(x0_0))
  82. print('x1_0:',x1_0.shape)
  83. x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
  84. print('x0_1:',x0_1.shape)
  85. x2_0 = self.conv2_0(self.pool(x1_0))
  86. print('x2_0:',x2_0.shape)
  87. x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
  88. print('x1_1:',x1_1.shape)
  89. x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
  90. print('x0_2:',x0_2.shape)
  91. x3_0 = self.conv3_0(self.pool(x2_0))
  92. print('x3_0:',x3_0.shape)
  93. x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
  94. print('x2_1:',x2_1.shape)
  95. x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
  96. print('x1_2:',x1_2.shape)
  97. x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
  98. print('x0_3:',x0_3.shape)
  99. x4_0 = self.conv4_0(self.pool(x3_0))
  100. print('x4_0:',x4_0.shape)
  101. x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
  102. print('x3_1:',x3_1.shape)
  103. x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
  104. print('x2_2:',x2_2.shape)
  105. x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
  106. print('x1_3:',x1_3.shape)
  107. x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
  108. print('x0_4:',x0_4.shape)
  109. if self.deep_supervision:
  110. output1 = self.final1(x0_1)
  111. output2 = self.final2(x0_2)
  112. output3 = self.final3(x0_3)
  113. output4 = self.final4(x0_4)
  114. return [output1, output2, output3, output4]
  115. else:
  116. output = self.final(x0_4)
  117. return output

dataset.py为数据的预处理部分

  1. import os
  2. import cv2
  3. import numpy as np
  4. import torch
  5. import torch.utils.data
  6. class Dataset(torch.utils.data.Dataset):
  7. def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
  8. """
  9. Args:
  10. img_ids (list): Image ids.
  11. img_dir: Image file directory.
  12. mask_dir: Mask file directory.
  13. img_ext (str): Image file extension.
  14. mask_ext (str): Mask file extension.
  15. num_classes (int): Number of classes.
  16. transform (Compose, optional): Compose transforms of albumentations. Defaults to None.
  17. Note:
  18. Make sure to put the files as the following structure:
  19. <dataset name>
  20. ├── images
  21. | ├── 0a7e06.jpg
  22. │ ├── 0aab0a.jpg
  23. │ ├── 0b1761.jpg
  24. │ ├── ...
  25. |
  26. └── masks
  27. ├── 0
  28. | ├── 0a7e06.png
  29. | ├── 0aab0a.png
  30. | ├── 0b1761.png
  31. | ├── ...
  32. |
  33. ├── 1
  34. | ├── 0a7e06.png
  35. | ├── 0aab0a.png
  36. | ├── 0b1761.png
  37. | ├── ...
  38. ...
  39. """
  40. self.img_ids = img_ids
  41. self.img_dir = img_dir
  42. self.mask_dir = mask_dir
  43. self.img_ext = img_ext
  44. self.mask_ext = mask_ext
  45. self.num_classes = num_classes
  46. self.transform = transform
  47. def __len__(self):
  48. return len(self.img_ids)
  49. def __getitem__(self, idx):
  50. img_id = self.img_ids[idx]
  51. img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
  52. mask = []
  53. for i in range(self.num_classes):
  54. mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
  55. img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
  56. mask = np.dstack(mask)
  57. if self.transform is not None:
  58. augmented = self.transform(image=img, mask=mask)#这个包比较方便,能把mask也一并做掉
  59. img = augmented['image']#参考https://github.com/albumentations-team/albumentations
  60. mask = augmented['mask']
  61. img = img.astype('float32') / 255
  62. img = img.transpose(2, 0, 1)
  63. mask = mask.astype('float32') / 255
  64. mask = mask.transpose(2, 0, 1)
  65. return img, mask, {'img_id': img_id}

losses.py

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. try:
  5. from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
  6. except ImportError:
  7. pass
  8. __all__ = ['BCEDiceLoss', 'LovaszHingeLoss']
  9. class BCEDiceLoss(nn.Module):
  10. def __init__(self):
  11. super().__init__()
  12. def forward(self, input, target):
  13. bce = F.binary_cross_entropy_with_logits(input, target)
  14. smooth = 1e-5
  15. input = torch.sigmoid(input)
  16. num = target.size(0)
  17. input = input.view(num, -1)
  18. target = target.view(num, -1)
  19. intersection = (input * target)
  20. dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
  21. dice = 1 - dice.sum() / num
  22. return 0.5 * bce + dice
  23. class LovaszHingeLoss(nn.Module):
  24. def __init__(self):
  25. super().__init__()
  26. def forward(self, input, target):
  27. input = input.squeeze(1)
  28. target = target.squeeze(1)
  29. loss = lovasz_hinge(input, target, per_image=True)
  30. return loss

metrics.py 模型效果评价指标

  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. def iou_score(output, target):
  5. smooth = 1e-5
  6. if torch.is_tensor(output):
  7. output = torch.sigmoid(output).data.cpu().numpy()
  8. if torch.is_tensor(target):
  9. target = target.data.cpu().numpy()
  10. output_ = output > 0.5
  11. target_ = target > 0.5
  12. intersection = (output_ & target_).sum()
  13. union = (output_ | target_).sum()
  14. return (intersection + smooth) / (union + smooth)
  15. def dice_coef(output, target):
  16. smooth = 1e-5
  17. output = torch.sigmoid(output).view(-1).data.cpu().numpy()
  18. target = target.view(-1).data.cpu().numpy()
  19. intersection = (output * target).sum()
  20. return (2. * intersection + smooth) / \
  21. (output.sum() + target.sum() + smooth)

preprocess.py 数据标签的合并处理,将同一张图的多个标签数据合并为一张

  1. import os
  2. from glob import glob
  3. import cv2
  4. import numpy as np
  5. from tqdm import tqdm
  6. def main():
  7. img_size = 96
  8. paths = glob('inputs/stage1_train/*')
  9. os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
  10. os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)
  11. for i in tqdm(range(len(paths))):
  12. path = paths[i]
  13. img = cv2.imread(os.path.join(path, 'images',
  14. os.path.basename(path) + '.png'))
  15. mask = np.zeros((img.shape[0], img.shape[1]))
  16. for mask_path in glob(os.path.join(path, 'masks', '*')):
  17. mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
  18. mask[mask_] = 1
  19. if len(img.shape) == 2:
  20. img = np.tile(img[..., None], (1, 1, 3))
  21. if img.shape[2] == 4:
  22. img = img[..., :3]
  23. img = cv2.resize(img, (img_size, img_size))
  24. mask = cv2.resize(mask, (img_size, img_size))
  25. cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,
  26. os.path.basename(path) + '.png'), img)
  27. cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,
  28. os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))
  29. if __name__ == '__main__':
  30. main()

 utils.py 其它设置

  1. import argparse
  2. def str2bool(v):
  3. if v.lower() in ['true', 1]:
  4. return True
  5. elif v.lower() in ['false', 0]:
  6. return False
  7. else:
  8. raise argparse.ArgumentTypeError('Boolean value expected.')
  9. def count_params(model):
  10. return sum(p.numel() for p in model.parameters() if p.requires_grad)
  11. class AverageMeter(object):
  12. """Computes and stores the average and current value"""
  13. def __init__(self):
  14. self.reset()
  15. def reset(self):
  16. self.val = 0
  17. self.avg = 0
  18. self.sum = 0
  19. self.count = 0
  20. def update(self, val, n=1):
  21. self.val = val
  22. self.sum += val * n
  23. self.count += n
  24. self.avg = self.sum / self.count

train.py 模型训练

  1. import argparse
  2. import os
  3. from collections import OrderedDict
  4. from glob import glob
  5. import pandas as pd
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. import yaml
  11. import albumentations as albu
  12. from albumentations.augmentations import transforms
  13. from albumentations.core.composition import Compose, OneOf
  14. from sklearn.model_selection import train_test_split
  15. from torch.optim import lr_scheduler
  16. from tqdm import tqdm
  17. import archs
  18. import losses
  19. from dataset import Dataset
  20. from metrics import iou_score
  21. from utils import AverageMeter, str2bool
  22. ARCH_NAMES = archs.__all__
  23. LOSS_NAMES = losses.__all__
  24. LOSS_NAMES.append('BCEWithLogitsLoss')
  25. """
  26. 指定参数:
  27. --dataset dsb2018_96
  28. --arch NestedUNet
  29. """
  30. def parse_args():
  31. parser = argparse.ArgumentParser()
  32. parser.add_argument('--name', default=None,
  33. help='model name: (default: arch+timestamp)')
  34. parser.add_argument('--epochs', default=100, type=int, metavar='N',
  35. help='number of total epochs to run')
  36. parser.add_argument('-b', '--batch_size', default=8, type=int,
  37. metavar='N', help='mini-batch size (default: 16)')
  38. # model
  39. parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
  40. choices=ARCH_NAMES,
  41. help='model architecture: ' +
  42. ' | '.join(ARCH_NAMES) +
  43. ' (default: NestedUNet)')
  44. parser.add_argument('--deep_supervision', default=False, type=str2bool)
  45. parser.add_argument('--input_channels', default=3, type=int,
  46. help='input channels')
  47. parser.add_argument('--num_classes', default=1, type=int,
  48. help='number of classes')
  49. parser.add_argument('--input_w', default=96, type=int,
  50. help='image width')
  51. parser.add_argument('--input_h', default=96, type=int,
  52. help='image height')
  53. # loss
  54. parser.add_argument('--loss', default='BCEDiceLoss',
  55. choices=LOSS_NAMES,
  56. help='loss: ' +
  57. ' | '.join(LOSS_NAMES) +
  58. ' (default: BCEDiceLoss)')
  59. # dataset
  60. parser.add_argument('--dataset', default='dsb2018_96',
  61. help='dataset name')
  62. parser.add_argument('--img_ext', default='.png',
  63. help='image file extension')
  64. parser.add_argument('--mask_ext', default='.png',
  65. help='mask file extension')
  66. # optimizer
  67. parser.add_argument('--optimizer', default='SGD',
  68. choices=['Adam', 'SGD'],
  69. help='loss: ' +
  70. ' | '.join(['Adam', 'SGD']) +
  71. ' (default: Adam)')
  72. parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
  73. metavar='LR', help='initial learning rate')
  74. parser.add_argument('--momentum', default=0.9, type=float,
  75. help='momentum')
  76. parser.add_argument('--weight_decay', default=1e-4, type=float,
  77. help='weight decay')
  78. parser.add_argument('--nesterov', default=False, type=str2bool,
  79. help='nesterov')
  80. # scheduler
  81. parser.add_argument('--scheduler', default='CosineAnnealingLR',
  82. choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
  83. parser.add_argument('--min_lr', default=1e-5, type=float,
  84. help='minimum learning rate')
  85. parser.add_argument('--factor', default=0.1, type=float)
  86. parser.add_argument('--patience', default=2, type=int)
  87. parser.add_argument('--milestones', default='1,2', type=str)
  88. parser.add_argument('--gamma', default=2/3, type=float)
  89. parser.add_argument('--early_stopping', default=-1, type=int,
  90. metavar='N', help='early stopping (default: -1)')
  91. parser.add_argument('--num_workers', default=0, type=int)
  92. config = parser.parse_args()
  93. return config
  94. def train(config, train_loader, model, criterion, optimizer):
  95. avg_meters = {'loss': AverageMeter(),
  96. 'iou': AverageMeter()}
  97. model.train()
  98. pbar = tqdm(total=len(train_loader))
  99. for input, target, _ in train_loader:
  100. input = input.cuda()
  101. target = target.cuda()
  102. # compute output
  103. if config['deep_supervision']:
  104. outputs = model(input)
  105. loss = 0
  106. for output in outputs:
  107. loss += criterion(output, target)
  108. loss /= len(outputs)
  109. iou = iou_score(outputs[-1], target)
  110. else:
  111. output = model(input)
  112. loss = criterion(output, target)
  113. iou = iou_score(output, target)
  114. # compute gradient and do optimizing step
  115. optimizer.zero_grad()
  116. loss.backward()
  117. optimizer.step()
  118. avg_meters['loss'].update(loss.item(), input.size(0))
  119. avg_meters['iou'].update(iou, input.size(0))
  120. postfix = OrderedDict([
  121. ('loss', avg_meters['loss'].avg),
  122. ('iou', avg_meters['iou'].avg),
  123. ])
  124. pbar.set_postfix(postfix)
  125. pbar.update(1)
  126. pbar.close()
  127. return OrderedDict([('loss', avg_meters['loss'].avg),
  128. ('iou', avg_meters['iou'].avg)])
  129. def validate(config, val_loader, model, criterion):
  130. avg_meters = {'loss': AverageMeter(),
  131. 'iou': AverageMeter()}
  132. # switch to evaluate mode
  133. model.eval()
  134. with torch.no_grad():
  135. pbar = tqdm(total=len(val_loader))
  136. for input, target, _ in val_loader:
  137. input = input.cuda()
  138. target = target.cuda()
  139. # compute output
  140. if config['deep_supervision']:
  141. outputs = model(input)
  142. loss = 0
  143. for output in outputs:
  144. loss += criterion(output, target)
  145. loss /= len(outputs)
  146. iou = iou_score(outputs[-1], target)
  147. else:
  148. output = model(input)
  149. loss = criterion(output, target)
  150. iou = iou_score(output, target)
  151. avg_meters['loss'].update(loss.item(), input.size(0))
  152. avg_meters['iou'].update(iou, input.size(0))
  153. postfix = OrderedDict([
  154. ('loss', avg_meters['loss'].avg),
  155. ('iou', avg_meters['iou'].avg),
  156. ])
  157. pbar.set_postfix(postfix)
  158. pbar.update(1)
  159. pbar.close()
  160. return OrderedDict([('loss', avg_meters['loss'].avg),
  161. ('iou', avg_meters['iou'].avg)])
def main():
    """End-to-end training driver: config, loss, model, data, optimization loop."""
    config = vars(parse_args())

    # Derive a run name from dataset/arch when none was given;
    # wDS / woDS marks whether deep supervision is enabled.
    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    # Persist the full configuration next to the model checkpoints.
    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        # "WithLogits" means the raw output is passed through sigmoid before
        # the cross-entropy is computed.
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        # Look up the loss class (e.g. BCEDiceLoss) by name in losses.py.
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    # Only optimize parameters that require gradients.
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    # Learning-rate schedule ('ConstantLR' means no scheduler at all).
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
                                                   verbose=1, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code: image ids are file names without extension.
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    # Fixed random_state so the train/val split is reproducible across runs.
    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    # Data augmentation: requires the albumentations package.
    train_transform = Compose([
        # random 90-degree rotations
        albu.RandomRotate90(),
        # random horizontal/vertical flips
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=1),  # pick exactly one of the above, per its normalized probability
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    # Validation: deterministic resize + normalize only.
    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)  # drop the final batch when it cannot be filled completely
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    # Per-epoch metrics accumulated here and flushed to CSV every epoch.
    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
    ])

    best_iou = 0
    trigger = 0  # epochs since the last validation-IoU improvement
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))

        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])

        pd.DataFrame(log).to_csv('models/%s/log.csv' %
                                 config['name'], index=False)

        trigger += 1

        # Checkpoint whenever validation IoU improves.
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' %
                       config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()

四,模型结果:

 五:注意事项以及常见问题

安装增加模块albumentations,主要为数据增强模块,方便快捷

pip install albumentations

常见问题:

AttributeError: module ‘cv2’ has no attribute ‘gapi_wip_gst_GStreamerPipeline’

解决:opencv-python-headless和opencv-python的版本对应即可

pip install opencv-python-headless==4.2.0.32 -i https://pypi.tuna.tsinghua.edu.cn/simple

问题2

AttributeError: module ‘albumentations.augmentations.transforms’ has no attribute ‘RandomRotate90’

解决:直接导入import albumentations as albu

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/351112
推荐阅读
相关标签
  

闽ICP备14008679号