当前位置:   article > 正文

学习TensrFlow 2 的随笔(四)Dataset.map()的使用_dataset map

dataset map

上一篇初步说了Dataset中的一些问题,这里还要记录一下Dataset.map()中的一些特别容易出问题的东西。学习TensrFlow 2 的随笔(三)tf.data.Dataset

  • 1、怎么在映射函数中进行tf.Tensor的类型转换?
    在映射函数里,一般都要求进行TensorFlow的操作。此时如果想将Tensor的类型进行转换,比如想将bool类型转为float。这样就可以直接将掩膜转为可以计算的Tensor.
    mask=tf.image.convert_image_dtype(mask,tf.float32)
    一般进行数据类型的转换,这个函数可以进行。但是它只可以进行以下类型间的转换:bfloat16, half, float, double, uint8, int8, uint16, int16, int32, uint32, uint64, int64, complex64, complex128。可以发现没有bool,所以不可以。
    mask=tf.cast(mask,tf.float32)
    这个函数才可以实现转换,因为它有bool类型。可将将True转为1,False转为0输出。
  • 2、映射函数中就是需要进行python的逻辑计算该怎么办?
    答案:使用tf.py_function()函数
    比如想将图片进行任意角度旋转,这个在tf.image没法进行。那就采用python函数来进行。例子:
    首先定义一个图像旋转函数,这个函数是纯python的,跟TensorFlow运算没有任何关系。
import scipy.ndimage as ndimage
def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
  • 1
  • 2
  • 3
  • 4

然后,利用tf.py_function()把这个函数包起来,放在一个映射函数里。特别注意的是使用这个函数时要明确表示出返回的形状shapes和类型types。原因和上一篇里说的一样,tf.Graph需要边缘。一定要对返回进行shape定义,就是set_shape(tf.TensorShape([None,None,3]))这一步一定要有。

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
  • 1
  • 2
  • 3
  • 4
  • 5

最后将映射函数放入到map中:

rot_ds = images_ds.map(tf_random_rotate_image)
for image, label in rot_ds.take(2):
  show(image, label)
  • 1
  • 2
  • 3

参考:
1.tf.data
2.tf.py_function

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/249761
推荐阅读
  

闽ICP备14008679号