当前位置:   article > 正文

TensorFlow读取自定义数据集(python版本)_基于tensorflow如何读取自己的图片数据集python

基于tensorflow如何读取自己的图片数据集python

以anime数据集为例:

  1. import multiprocessing
  2. import tensorflow as tf
  3. def batch_dataset(dataset,
  4. batch_size,
  5. drop_remainder=True,
  6. n_prefetch_batch=1,
  7. filter_fn=None,
  8. map_fn=None,
  9. n_map_threads=None,
  10. filter_after_map=False,
  11. shuffle=True,
  12. shuffle_buffer_size=None,
  13. repeat=None):
  14. # set defaults
  15. if n_map_threads is None:
  16. n_map_threads = multiprocessing.cpu_count()
  17. if shuffle and shuffle_buffer_size is None:
  18. shuffle_buffer_size = max(batch_size * 128, 2048) # set the maximum buffer size as 2048
  19. # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
  20. if shuffle:
  21. dataset = dataset.shuffle(shuffle_buffer_size)
  22. if not filter_after_map:
  23. if filter_fn:
  24. dataset = dataset.filter(filter_fn)
  25. if map_fn:
  26. dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
  27. else: # [*] this is slower
  28. if map_fn:
  29. dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
  30. if filter_fn:
  31. dataset = dataset.filter(filter_fn)
  32. dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
  33. dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
  34. return dataset
  35. def memory_data_batch_dataset(memory_data,
  36. batch_size,
  37. drop_remainder=True,
  38. n_prefetch_batch=1,
  39. filter_fn=None,
  40. map_fn=None,
  41. n_map_threads=None,
  42. filter_after_map=False,
  43. shuffle=True,
  44. shuffle_buffer_size=None,
  45. repeat=None):
  46. """Batch dataset of memory data.
  47. Parameters
  48. ----------
  49. memory_data : nested structure of tensors/ndarrays/lists
  50. """
  51. dataset = tf.data.Dataset.from_tensor_slices(memory_data) # 将路径转换为tensor类型
  52. dataset = batch_dataset(dataset,
  53. batch_size,
  54. drop_remainder=drop_remainder,
  55. n_prefetch_batch=n_prefetch_batch,
  56. filter_fn=filter_fn,
  57. map_fn=map_fn,
  58. n_map_threads=n_map_threads,
  59. filter_after_map=filter_after_map,
  60. shuffle=shuffle,
  61. shuffle_buffer_size=shuffle_buffer_size,
  62. repeat=repeat)
  63. return dataset
  64. def disk_image_batch_dataset(img_paths,
  65. batch_size,
  66. labels=None,
  67. drop_remainder=True,
  68. n_prefetch_batch=1,
  69. filter_fn=None,
  70. map_fn=None,
  71. n_map_threads=None,
  72. filter_after_map=False,
  73. shuffle=True,
  74. shuffle_buffer_size=None,
  75. repeat=None):
  76. """Batch dataset of disk image for PNG and JPEG.
  77. Parameters
  78. ----------
  79. img_paths : 1d-tensor/ndarray/list of str
  80. labels : nested structure of tensors/ndarrays/lists
  81. """
  82. if labels is None: # 此时图片数据都还没有读进内存
  83. memory_data = img_paths
  84. else:
  85. memory_data = (img_paths, labels)
  86. import tensorflow_io as tfio
  87. def parse_fn(path, *label): # 将图片数据读进内存
  88. img = tf.io.read_file(path)
  89. img = tf.image.decode_jpeg(img, channels=3) # fix channels to 3
  90. # 读取医学图像dicom个数的数据,使用的api是tfio.image.decode_dicom_image()
  91. # 需要先使用 img = image_bytes = tf.io.read_file('xx.dcm')将dicom数据读进内存
  92. # img = tfio.image.decode_dicom_image()
  93. return (img,) + label
  94. if map_fn: # fuse `map_fn` and `parse_fn`
  95. def map_fn_(*args):
  96. return map_fn(*parse_fn(*args))
  97. else:
  98. map_fn_ = parse_fn
  99. dataset = memory_data_batch_dataset(memory_data,
  100. batch_size,
  101. drop_remainder=drop_remainder,
  102. n_prefetch_batch=n_prefetch_batch,
  103. filter_fn=filter_fn,
  104. map_fn=map_fn_,
  105. n_map_threads=n_map_threads,
  106. filter_after_map=filter_after_map,
  107. shuffle=shuffle,
  108. shuffle_buffer_size=shuffle_buffer_size,
  109. repeat=repeat)
  110. return dataset
  111. # 加载自定义数据集进TensorFlow的主要函数,drop_reminder参数是当数据集大小不能整除batch_size时是否丢掉余数部分
  112. def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
  113. # @tf.function
  114. def _map_fn(img): # 对图片数据进行归一化处理
  115. img = tf.image.resize(img, [resize, resize])
  116. # img = tf.image.random_crop(img,[resize, resize])
  117. # img = tf.image.random_flip_left_right(img)
  118. # img = tf.image.random_flip_up_down(img)
  119. img = tf.clip_by_value(img, 0, 255)
  120. img = img / 127.5 - 1 # -1~1
  121. return img
  122. dataset = disk_image_batch_dataset(img_paths,
  123. batch_size,
  124. drop_remainder=drop_remainder,
  125. map_fn=_map_fn,
  126. shuffle=shuffle,
  127. repeat=repeat)
  128. img_shape = (resize, resize, 3)
  129. len_dataset = len(img_paths) // batch_size
  130. return dataset, img_shape, len_dataset
  131. '''
  132. 说下自己对代码的理解:
  133. 将图片路径转化为tensor,map函数中的第一个参数func函数负责将图片读进内存并将图片数据归一化,此处这个func函数的调用使用的是
  134. 回调函数机制。数据集的批量大小以及drop_remainder均是通过dataset.batch这个api来实现和处理的。
  135. 粗浅的理解不知正确与否,若有大佬知道,恳请指点
  136. '''

 

这个代码出自龙良曲老师的《深度学习与TensorFlow入门实战》GAN实战-3,不过现在B站已经将这个视频下架了(所以填转载都没有链接了,只能厚颜无耻的写成原创了),只能去某盘找了

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/123806
推荐阅读
相关标签
  

闽ICP备14008679号