赞
踩
glob.glob
)tf.reshape(test_image_label,[1])
)def get_image_path():
train_image_path=glob.glob('./data/dc_2000/train/*/*.jpg')
test_image_path=glob.glob('./data/dc_2000/test/*/*.jpg')
train_image_label=[1 if elem.split('\\')[1]=='dog' else 0 for elem in train_image_path]
test_image_label=[1 if elem.split('\\')[1]=='dog' else 0 for elem in train_image_path]
train_image_label=tf.reshape(train_image_label,[1])
test_image_label=tf.reshape(test_image_label,[1])
tf.io.read_file(path)
)tf.image.decode_jpeg(image,channels=3)
注意通道数量)tf.image.resize(image,[256,256])
)tf.image.cast(image,tf.float32)
)image=image/255
)注:tf.image.convert_image_dtype
函数会将图片格式转化为float32
,并执行归一化,如果原数据类型是float32
,则不会进行数据归一化的操作
def load_process_file(filepath):
image = tf.io.read_file(filepath)
image = tf.image.decode_jpeg(image,channels=3)
image = tf.image.resize(image,[256,256])
image=tf.cast(image,tf.float32)/255
# image = tf.image.convert_image_dtype #次函数会将图片格式转化为float32,并执行归一化,如果原数据类型是float32,则不会进行数据归一化的操作
tf.data.Dataset.from_tensor_slices((图片路径列表,标签列表))
生成dataset
dataset
执行图像预处理函数 (dataset.map(load_process_file,num_parallel_calls=tf.data.experimental.AUTOTUNE)
)dataset.shuffle(1000)
)datase.batch(32)
)prefetch(tf.data.experimental.AUTOTUNE)
增加图片读取速度train_ds=tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
train_ds.map(load_process_file,num_parallel_calls=tf.data.experimental.AUTOTUNE) #使用多线程,线程数自适应
test_ds = tf.data.Dataset.from_tensor_slices((test_image_path, test_image_label))
test_ds.map(load_process_file, num_parallel_calls=tf.data.experimental.AUTOTUNE) # 使用多线程,线程数自适应
BATCH_SIZE=32
train_count=len(train_image_path)
test_count=len(test_image_path)
train_ds=train_ds.repeat().shuffle(train_count).batch(BATCH_SIZE)
test_ds=test_ds.batch(BATCH_SIZE)
train_ds=train_ds.prefetch(tf.data.experimental.AUTOTUNE)
test_ds=test_ds.prefetch(tf.data.experimental.AUTOTUNE)
imgs,labels =next(iter(train_ds))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。