当前位置:   article > 正文

在语义分割中数据增强的正确打开方式(tf.data)_self.random_crop

self.random_crop

通过一系列的
随即翻转
随机旋转
随机剪裁
随机亮度调节
随机对比度调节
随机色相调节
随机饱和度调节
随机高斯噪声
让数据集变得更强大!

class Parse(object):
	'''
	 callable的类,返回一个进行数据解析和数据增强的仿函数
		使用示例:
			def make_dataset(self,batch_size,output_shape,argumentation=False,buffer_size=4000,epochs=None,shuffle=True):
		        filename=[self.tfr_path]
			    parse=Parse(self.img_shape,output_shape,argumentation,num_classes=self.num_classes)
			    dataset=tf.data.TFRecordDataset(filename)
			    dataset=dataset.prefetch(tf.contrib.data.AUTOTUNE)
			    dataset=dataset.shuffle(buffer_size=buffer_size,seed=int(time()))
		        dataset=dataset.repeat(count=None
		        dataset=dataset.map(parse)
		        dataset=dataset.batch(batch_size=batch_size)
		        dataset=dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
		        return dataset
    '''
    def __init__(self,raw_shape,out_shape,argumentation,num_classes):
        """
            返回一个Parse类的对象
            raw_shape:TFRecord文件中example的图像大小
            out_shape:随机剪裁后的图像大小
            argumentation:Bool变量,如果为0就只解析出图像并裁剪,不进行任何数据增强
            num_classes:类别总数(包括背景),用于确定one hot的维度
        """
        self.__argumantation=argumentation
        self.raw_shape=raw_shape
        self.out_shape=out_shape
        self.num_classes=num_classes
    def argumentation(self,image,labels):
        """
           单独对标签进行数据增强
           输入标签的one hot编码后的张量,输出和原图大小同样的张量
        """
        image=tf.cast(image,tf.float32)
        image,labels=self.random_crop_flip_rotate(image,labels)
        image=tf.image.random_brightness(image,max_delta=0.4)  #随机亮度调节
        image=tf.image.random_contrast(image,lower=0.7,upper=1.3)#随机对比度
        image=tf.image.random_hue(image,max_delta=0.3)#随机色相
        image=tf.image.random_saturation(image,lower=0.8,upper=1.3)#随机饱和度
        image=tf.cast(image,tf.float32)
        image=image+tf.truncated_normal(stddev=4,mean=2,shape=image.shape.as_list(),seed=int(time()))#加上高斯噪声
        image=tf.clip_by_value(image,0.0,255.0)
        return image,labels
    
    def random_rotate(self,input_image, min_angle=-np.pi/2,max_angle=np.pi/2):
        '''
        TensorFlow对图像进行随机旋转
        :param input_image: 图像输入
        :param min_angle: 最小旋转角度
        :param max_angle: 最大旋转角度
        :return: 旋转后的图像
        '''
        distorted_image = tf.expand_dims(input_image, 0)
        random_angles = tf.random.uniform(shape=(tf.shape(distorted_image)[0],), minval = min_angle , maxval = max_angle)
        distorted_image = tf.contrib.image.transform(
            distorted_image,
            tf.contrib.image.angles_to_projective_transforms(
                random_angles, tf.cast(tf.shape(distorted_image)[1], tf.float32), tf.cast(tf.shape(distorted_image)[2], tf.float32)
            ))
        rotate_image = tf.squeeze(distorted_image, [0])
        return rotate_image
    def random_crop_flip_rotate(self,image1,image2):#图片和对应标签同步进行翻转,旋转和裁剪
        image=tf.concat([image1,image2],axis=-1)
        channel=image.shape.as_list()[-1]
        shape=self.out_shape+[channel]
        print(shape)
        image=self.random_rotate(image)
        image=tf.image.random_crop(image,shape)
        image=tf.image.random_flip_left_right(image)
        image1=tf.slice(image,[0,0,0],self.out_shape+[3])
        image2=tf.slice(image,[0,0,3],self.out_shape+[-1])
        return image1,image2
    def __call__(self,tensor):
        """
            make it callable
            inputs labels 分别为图片和对应标签
            应当具有[w,h,3] 和[w,h] 的形状
        """
        feature=tf.parse_single_example(tensor,features={
            "inputs":tf.FixedLenFeature([],tf.string),
            "labels":tf.FixedLenFeature([],tf.string)
        })
        inputs=tf.decode_raw(feature["inputs"],tf.uint8)
        inputs=tf.reshape(inputs,self.raw_shape+[3])
        labels=tf.decode_raw(feature["labels"],tf.uint8)
        labels=tf.reshape(labels,self.raw_shape)
        labels=tf.one_hot(labels,self.num_classes)
        if self.__argumantation:
            inputs,labels=self.argumentation(inputs,labels)
        else:
            inputs=tf.image.resize_image_with_crop_or_pad(inputs,self.out_shape[0],self.out_shape[1])
            labels=tf.image.resize_image_with_crop_or_pad(labels,self.out_shape[0],self.out_shape[1])
        inputs=tf.image.per_image_standardization(inputs) #标准化
        return inputs,labels
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94

#UPD:
在这里插入图片描述
初步排查了一下是random_size里面tf.image.resize_images的问题
解决起来好办
就是resize过后强行把边框几个像素丢掉

    def random_size(self,image,minratio=0.5,maxratio=2.0,prob=0.5):
        height,width=image.shape.as_list()[:2]
        min_height=height*minratio
        max_height=height*maxratio
        min_width=width*minratio
        max_width=width*maxratio
        height=tf.random_uniform(shape=[],minval=min_height,maxval=max_height)
        height=tf.cast(height,tf.int32)
        width=tf.random_uniform(shape=[],minval=min_width,maxval=max_width)
        width=tf.cast(width,tf.int32)
        _prob=tf.random_uniform(shape=[],minval=0.0,maxval=1.0)
        r_image=tf.image.resize_images(image,[height+4,width+4],method=2)
        r_image=tf.image.resize_image_with_crop_or_pad(r_image,self.out_shape[0],self.out_shape[1])
        return tf.cond(_prob<prob,lambda:r_image,lambda:image)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

虽然损失了边框部分
但是问题不是很大,反而不去除的话会干扰训练过程
网络表示并不能知道为什么图片中有个边界框…

最终效果:
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/92591
推荐阅读
相关标签
  

闽ICP备14008679号