当前位置:   article > 正文

Deep Learning for Medical Image Segmentation: Tricks, Challenges and Future Directions 2D部分代码笔记-训练部分_config_vit.patches.grid

config_vit.patches.grid

目录

前言

一.训练整体架构(2DUnet/train.py)

 1.1  训练参数设置(2DUnet/options.py)

1.2  模型训练(2DUnet/networker_trainer.py)

1.2.1 设置GPU设备号、日志配置、设置随机种子

1.2.2 设置主干网络

1.2.3 设置损失函数

1.2.4 设置优化器和数据集

1.2.5 train

(1).AverageMeter类

(2).深监督(DeepS)

1.2.6 验证过程

1.2.7 运行过程

二、 后续

前言

        本篇博客将作为我研0到研3的学习经历的一个记录平台,也是我经验分享的第一步,也许对他人不会起到什么帮助,但是这将会成为我一步步走在科研路上的坚实记录。

人最宝贵的是生命。生命属于人只有一次。人的一生应当这样度过:当他回首往事的时候,不会因为碌碌无为、虚度年华而悔恨,也不会因为为人卑劣、生活庸俗而愧疚。---《钢铁是怎样炼成的》

       这篇博客将说明《Deep Learning for Medical Image Segmentation: Tricks, Challenges and Future Directions》这篇论文所提供的2DUnet训练部分代码的一些理解,包括整体架构,部分模块含义以及论文中提到的一些训练技巧在代码中的体现。

一.训练整体架构(2DUnet/train.py)

     不同于一些习惯于将模型训练过程和数据预处理、设置网络、设置训练参数等等写在一起的代码风格,这篇论文的coder将基本上所有关于网络训练部分的设置全部模块化处理(测试时也是相同),可以让人很清晰的按照train.py所提供的训练顺序来了解医疗影像分割训练的整体过程,也方便后续人们进行所需代码模块的调用。

    下面为train.py的具体代码:

  1. import sys
  2. sys.path.append('../') # 把路径添加到系统路径中去,防止路径报错
  3. import network_trainer
  4. from options import Options
  5. def main():
  6. '''opt实际上为存放模型、训练、测试、图片变形的一系列操作变量的大字典'''
  7. opt = Options(isTrain=True) # 获取基本变量
  8. opt.parse() # 解析环境变量
  9. opt.save_options() # 保存环境变量
  10. trainer = network_trainer.NetworkTrainer(opt) # 网络训练部分
  11. trainer.set_GPU_device() # 设置使用哪几个GPU来进行训练、测试工作
  12. trainer.set_logging() # 设置日志
  13. trainer.set_randomseed() # 给训练设置随机种子 2022
  14. trainer.set_network() # 设置训练的主干网络
  15. trainer.set_loss() # 设置损失函数 初始为交叉熵
  16. trainer.set_optimizer() # Adam+余弦学习率更新
  17. trainer.set_dataloader() # 加载数据集
  18. trainer.run()
  19. if __name__ == "__main__":
  20. main()

     从代码来看,整体网络训练的过程如下:

    设置环境、训练变量(argparse库实现)---- 将设置的变量保存为字典,之后保存到.txt文件中方便查看----设置训练时的一些配置(GPU、日志、seed等)----设置主干网络----设置损失函数、优化器等等----加载数据集----训练

   下面也将按照此流程进行说明

 1.1  训练参数设置(2DUnet/options.py)

    optipn.py主要设置训练时的一些参数,也包括将这些配置写入.txt文件的代码

  1. import os
  2. import argparse
  3. from NetworkTrainer.dataloaders.get_transform import get_transform
  4. class Options:
  5. def __init__(self, isTrain):
  6. self.isTrain = isTrain # 判断是否训练
  7. self.model = dict() # 模型属性
  8. self.train = dict() # 训练部分的字典
  9. self.test = dict() # 测试部分的字典
  10. self.transform = dict() # 三种模型状态下 图片变形的方式
  11. self.post = dict() # 数据后处理字典
  12. def parse(self):
  13. """ Parse the options, replace the default value if there is a new input ---网络训练时的一些默认参数"""
  14. parser = argparse.ArgumentParser(description='')
  15. parser.add_argument('--dataset', type=str, default='isic2018', help='isic2018 or conic') # 数据集类型
  16. parser.add_argument('--task', type=str, default='debug', help='') # 调试任务
  17. parser.add_argument('--fold', type=int, default=0, help='0-4, five fold cross validation') # 交叉验证文件夹
  18. parser.add_argument('--name', type=str, default='res50', help='res34, res50, res101, res152') # 残差模型名称
  19. parser.add_argument('--pretrained', type=bool, default=False, help='True or False') # 是否采用预训练模型
  20. parser.add_argument('--in-c', type=int, default=3, help='input channel') # 输入图片的维度
  21. parser.add_argument('--input-size', type=list, default=[256,256], help='input size of the image') # 输入图片的尺寸
  22. parser.add_argument('--train-gan-aug', type=bool, default=False, help='if use the augmente samples generated by GAN') # 是否使用GAN模型数据增强
  23. parser.add_argument('--train-train-epochs', type=int, default=200, help='number of training epochs') # 训练周期
  24. parser.add_argument('--train-batch-size', type=int, default=32, help='batch size')
  25. parser.add_argument('--train-checkpoint-freq', type=int, default=500, help='epoch to save checkpoints')
  26. parser.add_argument('--train-lr', type=float, default=3e-4, help='initial learning rate')
  27. parser.add_argument('--train-weight-decay', type=float, default=1e-5, help='weight decay')
  28. parser.add_argument('--train-workers', type=int, default=16, help='number of workers to load images') # 训练时cpu个数
  29. parser.add_argument('--train-gpus', type=list, default=[0, ], help='select gpu devices')
  30. parser.add_argument('--train-start-epoch', type=int, default=0, help='start epoch')
  31. parser.add_argument('--train-checkpoint', type=str, default='', help='checkpoint')
  32. parser.add_argument('--train-seed', type=int, default=2022, help='bn or in')
  33. parser.add_argument('--train-loss', type=str, default='ce', help='loss function, e.g., ce, dice, focal, ohem, tversky, wce') # loss
  34. parser.add_argument('--train-deeps', type=bool, default=False, help='if use deep supervision')
  35. parser.add_argument('--test-model-path', type=str, default=None, help='model path to test')
  36. parser.add_argument('--test-test-epoch', type=int, default=0, help='the checkpoint to test')
  37. parser.add_argument('--test-gpus', type=list, default=[0, ], help='select gpu devices')
  38. parser.add_argument('--test-save-flag', type=bool, default=False, help='if save the predicted results')
  39. parser.add_argument('--test-batch-size', type=int, default=4, help='batch size')
  40. parser.add_argument('--test-flip', type=bool, default=False, help='Test Time Augmentation with flipping')
  41. parser.add_argument('--test-rotate', type=bool, default=False, help='Test Time Augmentation with rotation') # TTA选项
  42. parser.add_argument('--post-abl', type=bool, default=False, help='True or False, post processing') # 后处理选项
  43. parser.add_argument('--post-rsa', type=bool, default=False, help='True or False, post processing')
  44. args = parser.parse_args()
  45. self.dataset = args.dataset
  46. self.task = args.task
  47. self.fold = args.fold
  48. self.root_dir = 'C:\\Users\\Desktop\\MedISeg-main\\isic2018' # 数据集目录
  49. self.result_dir = os.path.expanduser("~") + f'/Experiment/isic-2018/{self.dataset}/'
  50. self.model['name'] = args.name
  51. self.model['pretrained'] = args.pretrained
  52. self.model['in_c'] = args.in_c
  53. self.model['input_size'] = args.input_size
  54. '''train部分属性'''
  55. # --- training params --- #
  56. self.train['save_dir'] = '{:s}/{:s}/{:s}/fold_{:d}'.format(self.result_dir, self.task, self.model['name'], self.fold) # path to save results
  57. self.train['train_epochs'] = args.train_train_epochs
  58. self.train['batch_size'] = args.train_batch_size
  59. self.train['checkpoint_freq'] = args.train_checkpoint_freq
  60. self.train['lr'] = args.train_lr
  61. self.train['weight_decay'] = args.train_weight_decay
  62. self.train['workers'] = args.train_workers
  63. self.train['gpus'] = args.train_gpus
  64. self.train['seed'] = args.train_seed
  65. self.train['loss'] = args.train_loss
  66. self.train['deeps'] = args.train_deeps
  67. self.train['gan_aug'] = args.train_gan_aug
  68. # --- resume training --- #
  69. self.train['start_epoch'] = args.train_start_epoch
  70. self.train['checkpoint'] = args.train_checkpoint
  71. # --- test parameters --- #
  72. '''test部分属性'''
  73. self.test['test_epoch'] = args.test_test_epoch
  74. self.test['gpus'] = args.test_gpus
  75. self.test['save_flag'] = args.test_save_flag
  76. self.test['batch_size'] = args.test_batch_size
  77. self.test['flip'] = args.test_flip
  78. self.test['rotate'] = args.test_rotate
  79. self.test['save_dir'] = '{:s}/test_results'.format(self.train['save_dir'])
  80. if not args.test_model_path:
  81. self.test['model_path'] = '{:s}/checkpoints/checkpoint_{:d}.pth.tar'.format(self.train['save_dir'], self.test['test_epoch'])
  82. # --- post processing --- #
  83. '''图片后处理部分'''
  84. self.post['abl'] = args.post_abl
  85. self.post['rsa'] = args.post_rsa
  86. # define data transforms for training
  87. self.transform['train'] = get_transform(self, 'train')
  88. self.transform['val'] = get_transform(self, 'val')
  89. self.transform['test'] = get_transform(self, 'test')
  90. def save_options(self):
  91. if not os.path.exists(self.train['save_dir']):
  92. '''建立测试结果和检查点的文件夹'''
  93. os.makedirs(self.train['save_dir'], exist_ok=True)
  94. os.makedirs(os.path.join(self.train['save_dir'], 'test_results'), exist_ok=True)
  95. os.makedirs(os.path.join(self.train['save_dir'], 'checkpoints'), exist_ok=True)
  96. if self.isTrain:
  97. filename = '{:s}/train_options.txt'.format(self.train['save_dir'])
  98. else:
  99. filename = '{:s}/test_options.txt'.format(self.test['save_dir'])
  100. file = open(filename, 'w')
  101. groups = ['model', 'test', 'post', 'transform']
  102. file.write("# ---------- Options ---------- #")
  103. file.write('\ndataset: {:s}\n'.format(self.dataset))
  104. file.write('isTrain: {}\n'.format(self.isTrain))
  105. '''获取类中self.xxx的属性值,查看对象内部所有属性名和属性值组成的字典,该部分是将groups = ['model', 'test', 'post', 'transform']的几种状态的变量值写到文件中去'''
  106. for group, options in self.__dict__.items(): # 11个属性
  107. if group not in groups:
  108. continue
  109. file.write('\n\n-------- {:s} --------\n'.format(group))
  110. if group == 'transform':
  111. for name, val in options.items():
  112. if (self.isTrain and name != 'test') or (not self.isTrain and name == 'test'):
  113. file.write("{:s}:\n".format(name))
  114. for t_val in val:
  115. file.write("\t{:s}\n".format(t_val.__class__.__name__))
  116. else:
  117. for name, val in options.items():
  118. file.write("{:s} = {:s}\n".format(name, repr(val)))
  119. file.close()

    这一部分代码主要设置了“model”、“train”、“test”、“post”部分的参数,方便后续调用,这部分操作在每项代码都会有体现,即设置环境变量、储存各部分需要的初始参数等,在此不作详细说明

1.2  模型训练(2DUnet/networker_trainer.py)

   这部分是训练的主干,它将模型的训练过程分别模块化,因此在此也会逐步分析每个模块 

1.2.1 设置GPU设备号、日志配置、设置随机种子

  1. class NetworkTrainer:
  2. def __init__(self, opt):
  3. self.opt = opt
  4. self.criterion = CELoss() # 评价函数初始为CEloss(交叉熵)
  5. def set_GPU_device(self):
  6. os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in self.opt.train['gpus']) # 设置在哪个GPU上跑
  7. def set_logging(self):
  8. self.logger, self.logger_results = setup_logging(self.opt) # 设置日志
  9. def set_randomseed(self):
  10. num = self.opt.train['seed'] # 获得初始种子 2022
  11. random.seed(num) # 设置随机种子
  12. os.environ['PYTHONHASHSEED'] = str(num) # 加到系统环境变量中
  13. np.random.seed(num) # 作用同random.seed(num)
  14. # 给CPU、当前GPU、所有GPU设置随机数种子
  15. torch.manual_seed(num)
  16. torch.cuda.manual_seed(num)
  17. torch.cuda.manual_seed_all(num)

   setup_logging是定义在2DUnet/utils.py中的,主要是利用logging库来设置日志参数等,代码中的随机数种子设置为2022(有点数学竞赛选填总会有一道x的2022次方题的味道了),但实际上效果比较好的种子一般会设为1337、0等。(其余的我也不太知道了)

setup_logging代码如下:

  1. '''记录模型整个过程的日志'''
  2. def setup_logging(opt):
  3. mode = 'a' if opt.train['checkpoint'] else 'w'
  4. # create logger for training information
  5. logger = logging.getLogger(
  6. 'train_logger') # 应当通过模块级别的函数 logging.getLogger(name) 。多次使用相同的名字调用 getLogger() 会一直返回相同的 Logger 对象的引用】
  7. logger.setLevel(logging.DEBUG) # 日志等级小于 debug会被忽略。严重性为 level 或更高的日志消息将由该记录器的任何一个或多个处理器发出
  8. # create console handler and file handler
  9. # console_handler = logging.StreamHandler()
  10. console_handler = RichHandler(show_level=False, show_time=False, show_path=False)
  11. console_handler.setLevel(logging.INFO) # 记录info级别的日志
  12. file_handler = logging.FileHandler('{:s}/train_log.txt'.format(opt.train['save_dir']), mode=mode)
  13. file_handler.setLevel(logging.DEBUG) # 新建日志文件 记录DEBUG级别以上的日志
  14. # create formatter
  15. # formatter = logging.Formatter('%(asctime)s\t%(message)s', datefmt='%m-%d %I:%M')
  16. '''建立格式化对象formatter 将消息包括在日志记录调用中'''
  17. formatter = logging.Formatter('%(message)s')
  18. # add formatter to handlers
  19. console_handler.setFormatter(formatter)
  20. file_handler.setFormatter(formatter)
  21. # add handlers to logger
  22. logger.addHandler(console_handler)
  23. logger.addHandler(file_handler)
  24. # create logger for epoch results
  25. logger_results = logging.getLogger('results')
  26. logger_results.setLevel(logging.DEBUG)
  27. file_handler2 = logging.FileHandler('{:s}/epoch_results.txt'.format(opt.train['save_dir']), mode=mode)
  28. file_handler2.setFormatter(logging.Formatter('%(message)s'))
  29. logger_results.addHandler(file_handler2)
  30. logger.info('***** Training starts *****')
  31. logger.info('save directory: {:s}'.format(opt.train['save_dir']))
  32. if mode == 'w':
  33. logger_results.info('epoch\ttrain_loss\ttrain_loss_vor\ttrain_loss_cluster\ttrain_loss_repel')
  34. return logger, logger_results

     设置日志是为了查看训练代码运行中出现的问题,处于调试方便一般都会设立此模块

1.2.2 设置主干网络

    代码中使用的网络类型很多,如:densenet,resnet,vit,resunet等,除此之外还有是否使用深监督等一系列设置,通过字典中名字来分别设置主干网络

  1. def set_network(self):
  2. if 'res' in self.opt.model['name']:
  3. self.net = ResUNet(net=self.opt.model['name'], seg_classes=2, colour_classes=3, pretrained=self.opt.model['pretrained']) # Res50, 2D 二分类,channel=3
  4. if self.opt.train['deeps']:
  5. self.net = ResUNet_ds(net=self.opt.model['name'], seg_classes=2, colour_classes=3, pretrained=self.opt.model['pretrained'])
  6. elif 'dense' in self.opt.model['name']:
  7. self.net = DenseUNet(net=self.opt.model['name'], seg_classes=2)
  8. elif 'trans' in self.opt.model['name']:
  9. config_vit = CONFIGS_ViT_seg[self.opt.model['name']]
  10. config_vit.n_classes = 2
  11. config_vit.n_skip = 4
  12. if self.opt.model['name'].find('R50') != -1:
  13. config_vit.patches.grid = (int(self.opt.model['input_size'][0] / 16), int(self.opt.model['input_size'][1] / 16))
  14. self.net = ViT_seg(config_vit, img_size=self.opt.model['input_size'][0], num_classes=config_vit.n_classes).cuda()
  15. else:
  16. self.net = UNet(3, 2, 2) # 默认主干网络
  17. self.net = torch.nn.DataParallel(self.net) # 使用多个GPU加速训练
  18. self.net = self.net.cuda()

1.2.3 设置损失函数

    论文中提及了4种损失函数计算方法,即Dice,Focalloss,tverskyloss和ohemloss,代码中加入了WCEloss等,此部分代码如下:

  1. def set_loss(self):
  2. # set loss function
  3. if self.opt.train['loss'] == 'ce':
  4. self.criterion = CELoss()
  5. elif self.opt.train['loss'] == 'dice':
  6. self.criterion = DiceLoss()
  7. elif self.opt.train['loss'] == 'focal':
  8. self.criterion = FocalLoss(apply_nonlin=torch.nn.Softmax(dim=1))
  9. elif self.opt.train['loss'] == 'tversky':
  10. self.criterion = TverskyLoss()
  11. elif self.opt.train['loss'] == 'ohem':
  12. self.criterion = OHEMLoss()
  13. elif self.opt.train['loss'] == 'wce':
  14. self.criterion = CELoss(weight=torch.tensor([0.2, 0.8]))

其中,几种损失函数都定义在了loss_imbalance.py中,各种损失函数定义代码如下:

  1. """
  2. In ISIC dataset, the label shape is (b, x, y)
  3. In Kitti dataset, the label shape is (b, 1, x, y, z)
  4. """
  5. import ctypes
  6. import torch
  7. import torch.nn as nn
  8. import numpy as np
  9. class CELoss(nn.Module):
  10. def __init__(self, weight=None, reduction='mean'):
  11. self.weight = weight
  12. self.reduction = reduction
  13. def __call__(self, y_pred, y_true):
  14. y_true = y_true.long()
  15. if self.weight is not None:
  16. self.weight = self.weight.to(y_pred.device)
  17. if len(y_true.shape) == 5:
  18. y_true = y_true[:, 0, ...] # 约等于y_pred[:,0,...] 降维取每维第一个元素
  19. loss = nn.CrossEntropyLoss(weight=self.weight, reduction=self.reduction)
  20. return loss(y_pred, y_true)
  21. '''视之为二分类的Diceloss,但代码流程为普通DiceLoss'''
  22. class DiceLoss(nn.Module):
  23. def __init__(self, smooth=1e-8):
  24. super(DiceLoss, self).__init__()
  25. self.smooth = smooth # 极小量保证除式分母不为0
  26. def forward(self, y_pred, y_true):
  27. # first convert y_true to one-hot format
  28. axis = identify_axis(y_pred.shape) # axis=[2,3,4]
  29. y_pred = nn.Softmax(dim=1)(y_pred) # 沿1维softmax 降为(a,b,c,d)
  30. tp, fp, fn, _ = get_tp_fp_fn_tn(y_pred, y_true, axis) # 获取计算DiceLoss的tp,fp,fn
  31. intersection = 2 * tp + self.smooth
  32. union = 2 * tp + fp + fn + self.smooth
  33. dice = 1 - (intersection / union) #
  34. return dice.mean()
  35. # taken from https://github.com/JunMa11/SegLoss/blob/master/test/nnUNetV2/loss_functions/focal_loss.py
  36. class FocalLoss(nn.Module):
  37. """
  38. copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
  39. This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  40. 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
  41. Focal_Loss= -1*alpha*(1-pt)*log(pt)
  42. :param num_class:
  43. :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  44. :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
  45. focus on hard misclassified example
  46. :param smooth: (float,double) smooth value when cross entropy
  47. :param balance_index: (int) balance class index, should be specific when alpha is float
  48. :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  49. """
  50. # 支持多分类和二分类
  51. def __init__(self, apply_nonlin=None, alpha=0.25, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
  52. super(FocalLoss, self).__init__()
  53. self.apply_nonlin = apply_nonlin #
  54. self.alpha = alpha # 公式中的alpha
  55. self.gamma = gamma # 可调节因子 公式中的次方数
  56. self.balance_index = balance_index
  57. self.smooth = smooth
  58. self.size_average = size_average # loss是否平均
  59. if self.smooth is not None:
  60. if self.smooth < 0 or self.smooth > 1.0:
  61. raise ValueError('smooth value should be in [0,1]')
  62. def forward(self, logit, target):
  63. if self.apply_nonlin is not None:
  64. logit = self.apply_nonlin(logit)
  65. num_class = logit.shape[1] # 3个channel
  66. '''这一部分的操作为将logits降维 (voxels,channels)'''
  67. if logit.dim() > 2:
  68. # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
  69. logit = logit.view(logit.size(0), logit.size(1), -1) # logits从(1,3,5,5,5) 变为(1,3,125)
  70. logit = logit.permute(0, 2, 1).contiguous() # N,C,m--->N,m,C
  71. logit = logit.view(-1, logit.size(-1)) # view(-1)变成了一行数据,也就是说不管原来是什么维度的张量,经过view操作之后,\
  72. # 行优先的顺序变成了一行数据\
  73. # 变为(125,3)
  74. target = torch.squeeze(target, 1)
  75. target = target.view(-1, 1) # 变为1列
  76. alpha = self.alpha
  77. if alpha is None: # alpha没有则随机取
  78. alpha = torch.ones(num_class, 1)
  79. elif isinstance(alpha, (list, np.ndarray)): # alpha是否为list或array
  80. assert len(alpha) == num_class
  81. alpha = torch.FloatTensor(alpha).view(num_class, 1)
  82. alpha = alpha / alpha.sum() # 归一化
  83. elif isinstance(alpha, float):
  84. alpha = torch.ones(num_class, 1) # 1初始化
  85. alpha = alpha * (1 - self.alpha)
  86. alpha[self.balance_index] = self.alpha # alpha[0]
  87. else:
  88. raise TypeError('Not support alpha type')
  89. '''统一device'''
  90. if alpha.device != logit.device:
  91. alpha = alpha.to(logit.device)
  92. idx = target.cpu().long()
  93. one_hot_key = torch.FloatTensor(target.size(0), num_class) # 随机创建(125,3)的FloatTensor
  94. one_hot_key = one_hot_key.zero_() # 把自己填0
  95. one_hot_key = one_hot_key.scatter_(1, idx, 1) # torch.Tensor.scatter_(dim, index, src) → Tensor
  96. '''scatter_具体算法 建立onehot和target的映射
  97. self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
  98. self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
  99. self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2'''
  100. if one_hot_key.device != logit.device:
  101. one_hot_key = one_hot_key.to(logit.device)
  102. '''这一步主要是为了防止log运算出现 log0 的情况'''
  103. if self.smooth:
  104. one_hot_key = torch.clamp(
  105. one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) # 把onehot编码限制到(min,max)范围 无限接近于(0,1) 避免下面log运算报错
  106. pt = (one_hot_key * logit).sum(1) + self.smooth # 计算pt false sample prob*0+true sample prob*1+极小量=pt、
  107. # pt越大 越好学习 loss中占的比重小
  108. logpt = pt.log()
  109. gamma = self.gamma
  110. alpha = alpha[idx] # idx : (125,1)
  111. alpha = torch.squeeze(alpha) # 得到(125,)
  112. '''公式实现'''
  113. mul = torch.pow((1 - pt), gamma) # (125,)
  114. loss = -1 * alpha* mul * logpt
  115. '''batch内平均loss'''
  116. if self.size_average:
  117. loss = loss.mean()
  118. else:
  119. loss = loss.sum()
  120. return loss
  121. '''T loss为Diceloss的改进版本 强调fp fn权重 alpha beta 其余同Diceloss'''
  122. class TverskyLoss(nn.Module):
  123. def __init__(self, alpha=0.3, beta=0.7, eps=1e-7): # 0.3 0.7一般效果比较好
  124. super(TverskyLoss, self).__init__()
  125. self.alpha = alpha
  126. self.beta = beta
  127. self.eps = eps
  128. def forward(self, y_pred, y_true):
  129. axis = identify_axis(y_pred.shape)
  130. y_pred = nn.Softmax(dim=1)(y_pred)
  131. y_true = to_onehot(y_pred, y_true)
  132. y_pred = torch.clamp(y_pred, self.eps, 1. - self.eps)
  133. tp, fp, fn, _ = get_tp_fp_fn_tn(y_pred, y_true, axis)
  134. tversky = (tp + self.eps) / (tp + self.eps + self.alpha * fn + self.beta * fp)
  135. return (y_pred.shape[1] - tversky.sum()) / y_pred.shape[1] # 0.6590
  136. # return (1-tversky).mean() # 0.6707
  137. class OHEMLoss(nn.CrossEntropyLoss):
  138. """
  139. Network has to have NO LINEARITY!
  140. """
  141. def __init__(self, weight=None, ignore_index=-100, k=0.7):
  142. super(OHEMLoss, self).__init__()
  143. self.k = k
  144. self.weight = weight
  145. self.ignore_index = ignore_index
  146. def forward(self, y_pred, y_true):
  147. res = CELoss(reduction='none')(y_pred, y_true) # 算CEloss
  148. num_voxels = np.prod(res.shape, dtype=np.int64) # 算体素
  149. res, _ = torch.topk(res.view((-1,)), int(num_voxels * self.k), sorted=False) # 排序取前k个损失最大的pixel\
  150. # 该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号
  151. return res.mean() # 最后,求这些 hard example 的损失的均值作为最终损失
  152. def to_onehot(y_pred, y_true):
  153. shp_x = y_pred.shape # tensor(1,3,5,5,5)
  154. shp_y = y_true.shape
  155. with torch.no_grad():
  156. "predict & target batch size don't match"
  157. if len(shp_x) != len(shp_y):
  158. y_true = y_true.view((shp_y[0], 1, *shp_y[1:])) # tensor(1,1,5,5,5)
  159. if all([i == j for i, j in zip(y_pred.shape, y_true.shape)]): # 预测的分类张量形式和onehot形式一样
  160. # if this is the case then gt is probably already a one hot encoding
  161. y_onehot = y_true # 认定为已经转变为onehot编码
  162. else:
  163. y_true = y_true.long() # 数据类型变为LongTensor
  164. y_onehot = torch.zeros(shp_x, device=y_pred.device)
  165. y_onehot.scatter_(1, y_true, 1) # scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中
  166. return y_onehot
  167. '''Diceloss源码对应部分'''
  168. # def make_one_hot(input, num_classes):
  169. # """Convert class index tensor to one hot encoding tensor.
  170. # Args:
  171. # input: A tensor of shape [N, 1, *]
  172. # num_classes: An int of number of class
  173. # Returns:
  174. # A tensor of shape [N, num_classes, *]
  175. # """
  176. # shape = np.array(input.shape)
  177. # shape[1] = num_classes
  178. # shape = tuple(shape)
  179. # result = torch.zeros(shape)
  180. # result = result.scatter_(1, input.cpu(), 1)
  181. #
  182. # return result
  183. def get_tp_fp_fn_tn(net_output, gt, axes=None, square=False):
  184. """
  185. net_output must be (b, c, x, y(, z)))
  186. gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
  187. if mask is provided it must have shape (b, 1, x, y(, z)))
  188. :param net_output:
  189. :param gt:
  190. :return:
  191. """
  192. if axes is None:
  193. axes = tuple(range(2, len(net_output.size()))) # (2,3,4) 3D # shape=tupe(shape)
  194. y_onehot = to_onehot(net_output, gt) # 转变为one_hot编码
  195. tp = net_output * y_onehot # 概率值乘标签图==TP
  196. fp = net_output * (1 - y_onehot)
  197. fn = (1 - net_output) * y_onehot
  198. tn = (1 - net_output) * (1 - y_onehot)
  199. if square:
  200. tp = tp ** 2
  201. fp = fp ** 2
  202. fn = fn ** 2
  203. tn = tn ** 2
  204. if len(axes) > 0:
  205. tp = sum_tensor(tp, axes, keepdim=False)
  206. fp = sum_tensor(fp, axes, keepdim=False)
  207. fn = sum_tensor(fn, axes, keepdim=False)
  208. tn = sum_tensor(tn, axes, keepdim=False)
  209. return tp, fp, fn, tn
  210. '''对张量求和'''
  211. def sum_tensor(inp, axes, keepdim=False):
  212. axes = np.unique(axes).astype(int)
  213. if keepdim:
  214. for ax in axes:
  215. inp = inp.sum(int(ax), keepdim=True)
  216. else:
  217. for ax in sorted(axes, reverse=True):
  218. inp = inp.sum(int(ax)) # 沿着4、3、2维累计求和
  219. return inp
  220. def identify_axis(shape):
  221. """
  222. Helper function to enable loss function to be flexibly used for
  223. both 2D or 3D image segmentation - source: https://github.com/frankkramer-lab/MIScnn
  224. """
  225. # Three dimensional
  226. if len(shape) == 5:
  227. return [2, 3, 4]
  228. # Two dimensional
  229. elif len(shape) == 4:
  230. return [2, 3]
  231. # Exception - Unknown
  232. else:
  233. raise ValueError('Metric: Shape of tensor is neither 2D or 3D.')

    因其模块化的特性,因此可以十分容易将这些损失函数应用/继承到其他代码当中去,这也是这篇论文的写作初衷。

1.2.4 设置优化器和数据集

  代码中使用的是Adam优化器+余弦学习率更新策略,而载入的数据集需要提前将ISIC2018的.png格式文件转换为.npy文件,因此在载入数据集前需要一定的预处理。

  1. def set_optimizer(self):
  2. self.optimizer = torch.optim.Adam(self.net.parameters(), self.opt.train['lr'], betas=(0.9, 0.99), weight_decay=self.opt.train['weight_decay']) # Adam
  3. self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.opt.train['train_epochs']) # 余弦学习率更新算法
  4. '''获取载入数据集'''
  5. def set_dataloader(self):
  6. self.train_set = DataFolder(root_dir=self.opt.root_dir, phase='train', fold=self.opt.fold, gan_aug=self.opt.train['gan_aug'], data_transform=A.Compose(self.opt.transform['train']))
  7. self.val_set = DataFolder(root_dir=self.opt.root_dir, phase='val', data_transform=A.Compose(self.opt.transform['val']), fold=self.opt.fold)
  8. self.train_loader = DataLoader(self.train_set, batch_size=self.opt.train['batch_size'], shuffle=True, num_workers=self.opt.train['workers'])
  9. self.val_loader = DataLoader(self.val_set, batch_size=self.opt.train['batch_size'], shuffle=False, drop_last=False, num_workers=self.opt.train['workers'])

我使用的png to npy 代码

  1. import numpy as np
  2. import imageio
  3. import os
  4. os.chdir('C:\\Users\\Desktop\\MedISeg-main\\isic2018\\mask') # 切换python工作路径到你要操作的图片文件夹,mri_2d_test为我的图片文件夹
  5. a = np.ones((2, 2848, 4288)) # 利用np.ones()函数生成一个三维数组,当然也可用np.zeros,此数组的每个元素a[i]保存一张图片
  6. i = 1
  7. for filename in os.listdir(r"C:\\Users\\Desktop\\MedISeg-main\\isic2018\\mask"): # 使用os.listdir()获取该文件夹下每一张图片的名字
  8. im = imageio.imread(filename)
  9. a[i] = im
  10. i = i + 1
  11. i
  12. if (i == 2): #
  13. break
  14. np.save('C:\\Users\\Desktop\\MedISeg-main\\isic2018\\NumpyData', a)

1.2.5 train

    训练过程其实是一个标准的普通深度学习过程(最基础的训练过程代码在注释),作者对代码的改变为加入了深监督选项,如果在选择主干网络时选择了带有深监督ResUnet,则需要在此选择加入深监督的训练过程,代码如下:

  1. def train(self):
  2. self.net.train() # 开始训练
  3. '''在train函数中采用自定义的AverageMeter类来管理一些变量的更新。在初始化的时候就调用的重置方法reset。当调用该类对象的update方法的时候就会进行变量更新,当要读取某个变量的时候,可以通过对象.属性的方式来读取,
  4. 本质上是对所有batch取平均?/'''
  5. losses = AverageMeter()
  6. for i_batch, sampled_batch in enumerate(self.train_loader): # 过dataloader产生 序号+训练图片(图片+标签) 在这里为 批次序号+采样过的图片序列和标签
  7. volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] # 分imglist+label
  8. volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() # 使用cuda
  9. outputs = self.net(volume_batch) # 输出y_pred结果
  10. '''倘若不使用深度监督 正常计算loss即可'''
  11. if not self.opt.train['deeps']:
  12. loss = self.criterion(outputs, label_batch)
  13. else:
  14. # compute loss for each deep layer, i.e., x0, x1, x2, x3
  15. '''------------------------------------深度监督loss计算部分---------------------------------- '''
  16. '''倘若采用深监督 上面的output应为含有4个大元素的tensor 每个元素之间为h,w参数下采样2倍的关系'''
  17. gts = []
  18. loss = 0.
  19. for i in range(4): # res50 算4个层的loss来判断图片是否良好
  20. gt = label_batch.float().cuda().view(label_batch.shape[0], 1, label_batch.shape[1], label_batch.shape[2]) # [c N h w]
  21. h, w = gt.shape[2] // (2 ** i), gt.shape[3] // (2 ** i) # h,w下采样2倍的原因:resunet网络中有下采样操作 因此各阶段对比的输出为上一阶段的2倍下采样,对应的label也需要下采样来匹配
  22. gt = F.interpolate(gt, size=[h, w], mode='bilinear', align_corners=False) # 降采样
  23. gt = gt.long().squeeze(1) # 降维 [C H W]
  24. gts.append(gt)
  25. loss_list = compute_loss_list(self.criterion, outputs, gts) # 计算4个层分别的loss
  26. for iloss in loss_list:
  27. loss += iloss # 算总loss
  28. self.optimizer.zero_grad() # 清空梯度,方便下面梯度积累
  29. loss.backward() # 反向传播计算梯度
  30. self.optimizer.step() # 根据累计的梯度更新网络参数
  31. losses.update(loss.item(), volume_batch.size(0)) # loss.item() 降低计算 更新参数 计算定量值
  32. return losses.avg # 返回平均loss
  33. '''下面为一个传统batch训练过程:和上面的过程并无二至'''
  34. # self.net.train() # 开始训练
  35. # losses = AverageMeter()
  36. # for i, (images, target) in enumerate(train_loader):
  37. # # 1. input output
  38. # images = images.cuda(non_blocking=True)
  39. # target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True) # 单阶段的训练
  40. # outputs = model(images)
  41. # loss = criterion(outputs, target)
  42. #
  43. # # 2.1 loss regularization
  44. # loss = loss / accumulation_steps # 不一定会有
  45. # # 2.2 back propagation
  46. # loss.backward() # 反向传播计算梯度
  47. # # 3. update parameters of net
  48. # if ((i + 1) % accumulation_steps) == 0:
  49. # # optimizer the net
  50. # optimizer.step() # update parameters of net
  51. # optimizer.zero_grad() # reset gradient

代码中需要注意的一些部分

(1).AverageMeter类

    这个类是深度学习中来管理一些变量的更新,由类中的update方法来实现,具体代码如下:

  1. class AverageMeter(object):
  2. """Computes and stores the average and current value"""
  3. def __init__(self):
  4. self.reset()
  5. '''损失函数初始化(置零)'''
  6. def reset(self):
  7. self.val = 0
  8. self.avg = 0
  9. self.sum = 0
  10. self.count = 0
  11. '''更新参数,计算平均'''
  12. def update(self, val, n=1):
  13. self.val = val
  14. self.sum += val * n
  15. self.count += n
  16. self.avg = self.sum / self.count

    代码中设置这个类主要是对所有batch取平均,对医疗影像分割任务来说,这并不是一个好的策略,应该对一个batch内的所有图片loss取平均更好

(2).深监督(DeepS)

    深监督是为一种辅助学习策略,通过中间层某阶段的处理与label的loss计算来判断图的好坏,为后续的模型处理过程的优劣提供判断基础,示例过程如下:

在这里插入图片描述

     由于网络中有最大池化操作,代码中网络将label降采样对应到网络各阶段输出,这也可以帮助我们更好地了解到哪些训练样本是bad example。

1.2.6 验证过程

  1. def val(self):
  2. '''a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;
  3. b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练起到防止网络过拟合的问题。'''
  4. self.net.eval() # 开启验证模式
  5. val_losses = AverageMeter() # 载入参数
  6. '''关闭梯度计算'''
  7. with torch.no_grad():
  8. for i_batch, sampled_batch in enumerate(self.val_loader):
  9. volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
  10. volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
  11. outputs = self.net(volume_batch)
  12. '''深监督 输出有4层 outputs[0]为最后的输出'''
  13. if self.opt.train['deeps']:
  14. outputs = outputs[0]
  15. '''算Diceloss'''
  16. val_loss = DiceLoss()(outputs, label_batch)
  17. '''更新参数'''
  18. val_losses.update(val_loss.item(), outputs.size(0))
  19. return val_losses.avg

1.2.7 运行过程

   训练过程运行代码如下:

  1. def run(self):
  2. num_epoch = self.opt.train['train_epochs']
  3. self.logger.info("=> Initial learning rate: {:g}".format(self.opt.train['lr']))
  4. self.logger.info("=> Batch size: {:d}".format(self.opt.train['batch_size']))
  5. self.logger.info("=> Number of training iterations: {:d} * {:d}".format(num_epoch, int(len(self.train_loader))))
  6. self.logger.info("=> Training epochs: {:d}".format(self.opt.train['train_epochs']))
  7. dataprocess = tqdm(range(self.opt.train['start_epoch'], num_epoch)) # 进度条
  8. best_val_loss = 100.0
  9. for epoch in dataprocess:
  10. '''记录状态'''
  11. state = {'epoch': epoch + 1, 'state_dict': self.net.state_dict(), 'optimizer': self.optimizer.state_dict()}
  12. '''训练batch+计算loss'''
  13. train_loss = self.train()
  14. '''验证+计算loss'''
  15. val_loss = self.val()
  16. self.scheduler.step() # 更新学习率 epoch更新1次
  17. self.logger_results.info('{:d}\t{:.4f}\t{:.4f}'.format(epoch+1, train_loss, val_loss))
  18. if val_loss < best_val_loss:
  19. best_val_loss = val_loss
  20. save_bestcheckpoint(state, self.opt.train['save_dir'])
  21. print(f'save best checkpoint at epoch {epoch}')
  22. if (epoch > self.opt.train['train_epochs'] / 2.) and (epoch % self.opt.train['checkpoint_freq'] == 0):
  23. save_checkpoint(state, epoch, self.opt.train['save_dir'], True)
  24. logging.info("training finished")

二、 后续

    这篇论文的目的就是总结医疗影像分割中常见的训练、测试策略,所以其开源代码中各部分模块都可以直接应用到其余类似范畴的代码中,在之后将继续说明2DUnet的测试部分代码,由于本人仅为刚入门的新手,如有纰漏,希望各位能在评论区不吝赐教

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/83934
推荐阅读
相关标签
  

闽ICP备14008679号