赞
踩
TensorFlow 1.x 的读取方式：
TensorFlow 1.12 及以上版本的读取方式（最好使用 1.13.1 或 2.x）：
https://blog.csdn.net/Black_Friend/article/details/104529859
import pathlib
import random

import tensorflow as tf

# Root folder of the image dataset. Each sub-folder is one class and holds that
# class's images, i.e. ./database1/<class_name>/<image_file>.
data_path = pathlib.Path('./database1/')
print(type(data_path))  # e.g. <class 'pathlib.WindowsPath'> on Windows, PosixPath elsewhere
if not data_path.is_dir():
    # Fail loudly instead of silently building an empty dataset.
    raise FileNotFoundError('dataset folder not found: {}'.format(data_path))

print(type(data_path.glob('*/*')))  # <class 'generator'>
# glob('*/*') yields entries exactly two levels deep; keep only regular files so
# nested sub-directories are never handed to tf.io.read_file.
all_image_paths = [str(path) for path in data_path.glob('*/*') if path.is_file()]  # all image paths
random.shuffle(all_image_paths)  # shuffle once so file order does not follow class order

image_count = len(all_image_paths)
print('image_count: ', image_count)

# Class names are the sub-folder names, sorted so the name -> index mapping is
# deterministic across runs.
label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
print('label_names: ', label_names)
label_to_index = {name: index for index, name in enumerate(label_names)}
print('label_to_index: ', label_to_index)
# Each image's label is the index of its parent folder.
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]

# The one-hot depth must equal the number of class folders. A fixed depth (e.g. 10)
# silently turns every label >= depth into an all-zero vector.
num_classes = len(label_names)

db_train = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))


def load_and_preprocess_from_path_label(path, label):
    """Load one image from disk and one-hot encode its label.

    Args:
        path: scalar string tensor holding the image file path.
        label: scalar integer tensor holding the class index.

    Returns:
        A ``(image, label)`` pair: a float32 image with 3 channels scaled to
        [0, 1], and a one-hot label vector of length ``num_classes``.
    """
    image = tf.io.read_file(path)  # raw file bytes
    # NOTE(review): assumes every file is a JPEG; use tf.image.decode_image for mixed formats.
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, dtype=tf.float32) / 255.0  # normalize to [0, 1]
    # batch() requires every image to have the same shape. If the files differ in
    # size, resize here, e.g. image = tf.image.resize(image, [28, 28]).

    label = tf.cast(label, dtype=tf.int32)
    label = tf.one_hot(label, depth=num_classes)
    return image, label


# tf.data transformations return a NEW dataset and leave the receiver unchanged,
# so the result must be re-bound; calling them as bare statements does nothing.
db_train = (
    db_train.shuffle(1000)
    .map(load_and_preprocess_from_path_label)
    .batch(64)
    .repeat(2)
)
print(type(db_train))
# `element_spec` exists in TF >= 1.15 / 2.x; `output_shapes` exists only in TF 1.x.
if hasattr(db_train, 'element_spec'):
    print(db_train.element_spec)
else:
    print(db_train.output_shapes)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。