赞
踩
上一篇初步说了Dataset中的一些问题,这里还要记录一下Dataset.map()中的一些特别容易出问题的东西。学习TensrFlow 2 的随笔(三)tf.data.Dataset
mask=tf.image.convert_image_dtype(mask,tf.float32)
mask=tf.cast(mask,tf.float32)
tf.py_function()
函数import scipy.ndimage as ndimage
def random_rotate_image(image):
image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
return image
然后,利用tf.py_function()
把这个函数包起来,放在一个映射函数里。特别注意的是使用这个函数时要明确表示出返回的形状shapes和类型types。原因和上一篇里说的一样,tf.Graph
需要边缘。一定要对返回进行shape定义,就是set_shape(tf.TensorShape([None,None,3]))
这一步一定要有。
def tf_random_rotate_image(image, label):
im_shape = image.shape
[image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
image.set_shape(im_shape)
return image, label
最后将映射函数放入到map中:
rot_ds = images_ds.map(tf_random_rotate_image)
for image, label in rot_ds.take(2):
show(image, label)
参考:
1.tf.data
2.tf.py_function
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。