赞
踩
通过一系列的
随即翻转
随机旋转
随机剪裁
随机亮度调节
随机对比度调节
随机色相调节
随机饱和度调节
随机高斯噪声
让数据集变得更强大!
class Parse(object): ''' callable的类,返回一个进行数据解析和数据增强的仿函数 使用示例: def make_dataset(self,batch_size,output_shape,argumentation=False,buffer_size=4000,epochs=None,shuffle=True): filename=[self.tfr_path] parse=Parse(self.img_shape,output_shape,argumentation,num_classes=self.num_classes) dataset=tf.data.TFRecordDataset(filename) dataset=dataset.prefetch(tf.contrib.data.AUTOTUNE) dataset=dataset.shuffle(buffer_size=buffer_size,seed=int(time())) dataset=dataset.repeat(count=None dataset=dataset.map(parse) dataset=dataset.batch(batch_size=batch_size) dataset=dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0")) return dataset ''' def __init__(self,raw_shape,out_shape,argumentation,num_classes): """ 返回一个Parse类的对象 raw_shape:TFRecord文件中example的图像大小 out_shape:随机剪裁后的图像大小 argumentation:Bool变量,如果为0就只解析出图像并裁剪,不进行任何数据增强 num_classes:类别总数(包括背景),用于确定one hot的维度 """ self.__argumantation=argumentation self.raw_shape=raw_shape self.out_shape=out_shape self.num_classes=num_classes def argumentation(self,image,labels): """ 单独对标签进行数据增强 输入标签的one hot编码后的张量,输出和原图大小同样的张量 """ image=tf.cast(image,tf.float32) image,labels=self.random_crop_flip_rotate(image,labels) image=tf.image.random_brightness(image,max_delta=0.4) #随机亮度调节 image=tf.image.random_contrast(image,lower=0.7,upper=1.3)#随机对比度 image=tf.image.random_hue(image,max_delta=0.3)#随机色相 image=tf.image.random_saturation(image,lower=0.8,upper=1.3)#随机饱和度 image=tf.cast(image,tf.float32) image=image+tf.truncated_normal(stddev=4,mean=2,shape=image.shape.as_list(),seed=int(time()))#加上高斯噪声 image=tf.clip_by_value(image,0.0,255.0) return image,labels def random_rotate(self,input_image, min_angle=-np.pi/2,max_angle=np.pi/2): ''' TensorFlow对图像进行随机旋转 :param input_image: 图像输入 :param min_angle: 最小旋转角度 :param max_angle: 最大旋转角度 :return: 旋转后的图像 ''' distorted_image = tf.expand_dims(input_image, 0) random_angles = tf.random.uniform(shape=(tf.shape(distorted_image)[0],), minval = min_angle , maxval = max_angle) distorted_image = tf.contrib.image.transform( distorted_image, tf.contrib.image.angles_to_projective_transforms( random_angles, tf.cast(tf.shape(distorted_image)[1], tf.float32), tf.cast(tf.shape(distorted_image)[2], tf.float32) )) rotate_image = tf.squeeze(distorted_image, [0]) return rotate_image def random_crop_flip_rotate(self,image1,image2):#图片和对应标签同步进行翻转,旋转和裁剪 image=tf.concat([image1,image2],axis=-1) channel=image.shape.as_list()[-1] shape=self.out_shape+[channel] print(shape) image=self.random_rotate(image) image=tf.image.random_crop(image,shape) image=tf.image.random_flip_left_right(image) image1=tf.slice(image,[0,0,0],self.out_shape+[3]) image2=tf.slice(image,[0,0,3],self.out_shape+[-1]) return image1,image2 def __call__(self,tensor): """ make it callable inputs labels 分别为图片和对应标签 应当具有[w,h,3] 和[w,h] 的形状 """ feature=tf.parse_single_example(tensor,features={ "inputs":tf.FixedLenFeature([],tf.string), "labels":tf.FixedLenFeature([],tf.string) }) inputs=tf.decode_raw(feature["inputs"],tf.uint8) inputs=tf.reshape(inputs,self.raw_shape+[3]) labels=tf.decode_raw(feature["labels"],tf.uint8) labels=tf.reshape(labels,self.raw_shape) labels=tf.one_hot(labels,self.num_classes) if self.__argumantation: inputs,labels=self.argumentation(inputs,labels) else: inputs=tf.image.resize_image_with_crop_or_pad(inputs,self.out_shape[0],self.out_shape[1]) labels=tf.image.resize_image_with_crop_or_pad(labels,self.out_shape[0],self.out_shape[1]) inputs=tf.image.per_image_standardization(inputs) #标准化 return inputs,labels
#UPD:
初步排查了一下是random_size里面tf.image.resize_images的问题
解决起来好办
就是resize过后强行把边框几个像素丢掉
def random_size(self,image,minratio=0.5,maxratio=2.0,prob=0.5):
height,width=image.shape.as_list()[:2]
min_height=height*minratio
max_height=height*maxratio
min_width=width*minratio
max_width=width*maxratio
height=tf.random_uniform(shape=[],minval=min_height,maxval=max_height)
height=tf.cast(height,tf.int32)
width=tf.random_uniform(shape=[],minval=min_width,maxval=max_width)
width=tf.cast(width,tf.int32)
_prob=tf.random_uniform(shape=[],minval=0.0,maxval=1.0)
r_image=tf.image.resize_images(image,[height+4,width+4],method=2)
r_image=tf.image.resize_image_with_crop_or_pad(r_image,self.out_shape[0],self.out_shape[1])
return tf.cond(_prob<prob,lambda:r_image,lambda:image)
虽然损失了边框部分
但是问题不是很大,反而不去除的话会干扰训练过程
网络表示并不能知道为什么图片中有个边界框…
最终效果:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。