当前位置:   article > 正文

语义分割模型的优化_语义分割 边缘优化

语义分割 边缘优化

语义分割模型的优化

当发现验证集指标和训练集指标相差较大时,主要可以检查这些原因:

  1. 数据集类别芜杂、数据量不够,先检查数据集和数据迭代器的质量。
  2. 如果训练集较快拟合,则模型过于庞大,降低了鲁棒性,可以降低batchsize或减少层数、卷积核数量。
  3. 若训练集指标也不正常,将学习率调整到[1e-2, 1e-6],各测试一遍。若loss曲线仍有问题,调整损失函数和激活函数。再不行,换个网络和显卡试试吧。

常用优化技巧

  1. 数据增强
    使用 albumentations,随机旋转、镜像、模糊、色彩映射、噪音、曝光、翻转、色差平移。
    albumentations

  2. 模型的投票
    融合多个模型进行集成,采用多个模型对各个像素点进行投票,对于每张图片在不同scale(不同放大比例、旋转参数)下的结果融合。

  3. TTA (Test-Time Augmentation) ,即测试时的数据增

  4. 强。
    TTA
    TTA
    mask_xor = (mask^bg)&mask

  5. 联合损失函数BCE+Dice+Focal+lovasz_softmax
    将不同损失函数以一定权重混合,获得较好的鲁棒性。
    联合损失函数
    LovaszSoftmax

  6. CRF后处理
    SegNet做语义分割时通常在末端加入CRF模块做后处理,旨在进一步精修边缘的分割结果。
    CRF
    pydensecrf
    CRF-semantic-segmentation

关于学习率的优化

  1. Warm up 预热
    采用小学习率预训练。当训练时出现训练指标上升、验证集指标不动的奇怪现象时,此方法极其有效。预训练后再加大LR。
  2. 按val_miou减小
    使用keras.callbacks.ReduceLROnPlateau,根据val_miou的变化来动态调整lr。
  3. 余弦退火
    CosineAnnealing

损失函数和激活函数、通用评估指标

损失函数

损失函数选用keras.losses.binary_crossentropy

激活函数

单通道输出(二分类,[batch_size, 512, 512 ,1]),选用Sigmoid。
多通道输出(多分类,[batch_size, 512, 512 , N]),选用Softmax,其中第一层是背景层。

通用评估指标 IOU 代码

from tensorflow.keras import backend as K
def Iou_score(y_true, y_pred):
    '''总体的IOU'''
    smooth = 1e-5
    threhold = 0.5
    # score calculation
    y_pred = K.greater(y_pred, threhold)
    y_pred = K.cast(y_pred, K.floatx())
    intersection = K.sum(y_true * y_pred, axis=[0,1,2])
    '''
    这里y_pred为四维,[16,512,512,2],
    axis=[0,1,2]时输出的intersection是每个类别的得分(准确的数量,[20, 30])。
    [16,512,512,1]时候也通用。
    '''
    union = K.sum(y_true + y_pred, axis=[0,1,2]) - intersection
    return (intersection + smooth) / (union + smooth)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

实现数据增强的语义分割数据迭代器代码

from tensorflow.keras.utils import Sequence
import random, os, gc, cv2
import numpy as np
seed = 295
random.seed(seed)
import os
import cv2 as cv
import albumentations as A
 # pip install albumentations -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

class SequenceData(Sequence):

    def __init__(self, images_dir_X, images_dir_Y, img_size=(256,256),imgOnly=False,
                 batch_size=1, classes=None, imgListTxt=None, isOneHot=True, dataEnhancement=True):
        with open(imgListTxt) as f:
            # 为了对应样本和标签 读入的文件名不带后缀
            self.datas = list(f.readlines())
        self.images_dir_X = images_dir_X
        self.images_dir_Y = images_dir_Y
        self.batch_size = batch_size
        self.L = len(self.datas)
        self.img_size = img_size
        self.index = random.sample(range(self.L), self.L)
        self.classes=classes
        self.isOneHot=isOneHot
        self.imgOnly=imgOnly
        
        self.dataEnhancement=dataEnhancement
        prob = 0.4
        self.transform = A.Compose([
            # A.RandomCrop(width=256, height=256),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(0, 0.2),
                                       contrast_limit=(0, 0.2), p=prob),
            A.Rotate(limit=30, interpolation=cv2.INTER_CUBIC, border_mode=4, p=prob),
            A.RandomGamma(gamma_limit=(80, 120), eps=1e-07, p=prob),
            A.MotionBlur(blur_limit=5, p=prob),
            A.IAASharpen(p=prob),
            A.IAAPerspective(p=prob),
            A.GaussNoise(var_limit=(10.0, 50.0), p=prob),
            
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=110,
                                          val_shift_limit=10, p=prob),
            A.RGBShift(r_shift_limit=5, g_shift_limit=5, b_shift_limit=5, p=prob),
            A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, interpolation=2, p=prob),
            A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=2, p=prob),
        ])
       
        
    # 返回长度,通过len(<你的实例>)调用
    def __len__(self):
        return int(np.ceil(len(self.datas) / self.batch_size))
    
    # 通过索引获取a[0],a[1]这种
    def __getitem__(self, idx):
        
        batch_indexs = self.index[idx:(idx+self.batch_size)]
        batch_datas = [self.datas[k] for k in batch_indexs]
        
        images = self.load_image_from_directory(
            images_dir=self.images_dir_X, img_size=self.img_size,
            suffix='.jpg', imgList=batch_datas)
        if self.imgOnly:
            return images
        
        labels = self.load_image_from_directory(
            images_dir=self.images_dir_Y, img_size=self.img_size,
            isLabel=True, isOneHot=self.isOneHot, classes=self.classes, suffix='.png',
            imgList=batch_datas)
        
        if self.dataEnhancement:
            images,labels = self.dataEnhancing(images,labels)
            
        
        return images,labels

    def load_image_from_directory(self, images_dir, img_size, isLabel=False,
                                  classes=None,isOneHot=True, suffix='.png',
                                  dtype=np.float32, imgList=None):
        """
        从数据集目录中加载图像数组
        :param images_dir: 数据集目录,原图或标签,该文件夹下直接是图像,三通道,24位深度
        :param img_size: 网络要求的图像大小
        :param isLabel: 是否为标签, 测试集的标签由于只是显示作用,不参与训练,因此保持默认False即可
        :param classes: 类,与isLabel相匹配,isLabel=True时,必须赋值
        :param suffix: 图像后缀
        :param dtype: 图像数据类型
        :param imgList: 图像文件名列表
        :return: 返回图像数组,[图像个数,高,宽,通道数]
        """
        images_path = []
        for fname in imgList:
            # if fname.endswith(suffix) and not fname.startswith('.'):
            # 数组填充,将图像绝对路径添加至数组images_path末尾
            images_path.append(os.path.join(images_dir, fname[:-1]+suffix))
        images_path = sorted(images_path)  # 按顺序整理
        # print(len(images_path))
        images = []  # 创建空数组
        for i, path in enumerate(images_path):
            img = cv.imdecode(np.fromfile(path, dtype=np.uint8),
                              cv.IMREAD_COLOR)  # 可为中文路径读取图像
            # img = cv.imread(path)  # 读取图像, 无中文路径
            if img.shape[:2] is not img_size:
                # img = cv.resize(img, dsize=img_size,
                #                 interpolation=cv.INTER_CUBIC)  # 不满足条件重置图像尺寸
                img = self.resize_img_keep_ratio(img=img, target_size=img_size)
            img = img[:, :, ::-1]  # 交换图像通道
            if isLabel:
                if isOneHot:
                    # 创建空数组用于储存图像数组
                    newImg = np.zeros(img.shape[:2] + (len(classes),), dtype=dtype)
                    for j, value in enumerate(classes.values()):
                        newImg[np.bitwise_and(np.bitwise_and(
                            img[:, :, 0] == value[0], img[:, :, 1] == value[1]),
                            img[:, :, 2] == value[2]), j] = 1  # 把标签转为one-hot形式
                    img = newImg
                else:
                    img = img[:,:,0] / 255.0
                    img[img>=0.5] = 1.0
                    img[img<0.5] = 0.0
                    img = np.expand_dims(img, axis=-1)  # 扩充维度
            images.append(img)
            # images.append(self.resize_img_keep_ratio(img=img, target_size=img_size))
        images = np.array(images, dtype=dtype)  # 改变数组类型
        if not isLabel:
            images = images/255.0
        # print(images.shape)
        return images
    def resize_img_keep_ratio(self, img=None,img_name=None,target_size=(256,256)):
            '''
            1.resize图片,先计算最长边的resize的比例,然后按照该比例resize。
            2.计算四个边需要padding的像素宽度,然后padding
            '''
            if img is None:
                img = cv2.imread(img_name)
            old_size= img.shape[0:2]
            ratio = min(float(target_size[i])/(old_size[i]) for i in range(len(old_size)))
            new_size = tuple([int(i*ratio) for i in old_size])
            img = cv2.resize(img,(new_size[1], new_size[0]),interpolation=cv2.INTER_CUBIC)  #注意插值算法
            pad_w = target_size[1] - new_size[1]
            pad_h = target_size[0] - new_size[0]
            top,bottom = pad_h//2, pad_h-(pad_h//2)
            left,right = pad_w//2, pad_w -(pad_w//2)
            img_new = cv2.copyMakeBorder(img,top,bottom,left,right,cv2.BORDER_CONSTANT,None,(0,0,0))
            # return cv2.cvtColor(img_new, cv2.COLOR_BGR2RGB)
            return img_new
    def dataEnhancing(self,images,labels,dtype=np.float32):
        '''数据增强'''
        transformed_images,transformed_masks=[],[]
        # for image,mask in zip(images,labels):
        # print(images.shape[0])
        for i in range(images.shape[0]):
            image=np.asarray(images[i,...]*255.0).astype(np.uint8)
            mask=labels[i,...]
            if(mask.shape[-1]==1):
                mask=mask.reshape(mask.shape[:-1])
            transformed = self.transform(image=image, mask=mask)
            outImg = np.asarray(transformed["image"]).astype(dtype)/255.0
            transformed_images.append(transformed["image"])
            transformed_masks.append(transformed["mask"])
        transformed_images=np.asarray(transformed_images,dtype=dtype)
        transformed_masks=np.asarray(transformed_masks,dtype=dtype)
        # print(transformed_images.shape,transformed_masks.shape)
        return transformed_images,transformed_masks
   


# from tensorflow.keras.preprocessing.image import array_to_img
# def showImg(frame):
#     array_to_img(frame).show()
# classes = dict(
#         [('background', [0, 0, 0]), ('object', [255, 255, 255])]) 
# trainSDG = SequenceData(r".\ourData\combine\images\1",
#                 r"ourData\combine\labels\1",
#             img_size=(512,512), classes=classes, batch_size=1,
#             imgListTxt=r".\ourData\combine\sets\train.txt", dataEnhancement=False)
# xx=trainSDG.__getitem__(9)
# x0=np.array(xx[0])[0,...]*255.0
# x1=np.array(xx[1])[0,...]
# showImg(np.expand_dims(x1[:,:,0],axis=-1))
# showImg(np.expand_dims(x1[:,:,1],axis=-1))
# showImg(x0)

# from tensorflow.keras.preprocessing.image import array_to_img
# def showImg(frame):
#     array_to_img(frame).show()
# classes = dict(
#         [('background', [0, 0, 0]), ('object', [255, 255, 255])]) 
# trainSDG = SequenceData(r".\ourData\combine\images\1",
#                 r"ourData\combine\labels\1", isOneHot=False,
#             img_size=(512,512), classes=classes, batch_size=2,
#             imgListTxt=r".\ourData\combine\sets\train.txt", dataEnhancement=False)
# xx=trainSDG.__getitem__(5)
# x0=np.array(xx[0])[0,...]*255.0
# x1=np.array(xx[1])[0,...]
# showImg(x1[:,:])
# showImg(x0)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/234624
推荐阅读
相关标签
  

闽ICP备14008679号