
【TensorFlow】Training a Keras model + Keras's data generator ImageDataGenerator + the MNIST dataset as JPG images + comparing flow and flow_from_directory


1. Dataset

The MNIST dataset stored as JPG images (placed under the database1 folder):
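Both input_data_list and flow_from_directory (used later) treat every subfolder of database1 as one class. The folder layout should therefore be database1/<class>/<image>.jpg. Below is a minimal sketch for checking that layout before training. The helper name count_images_per_class is hypothetical, and it only counts .jpg/.jpeg files:

```python
from collections import Counter
from pathlib import Path


def count_images_per_class(root, exts=(".jpg", ".jpeg")):
    """Count image files in each immediate subfolder of root.

    Assumes the layout root/<class_name>/<image>.jpg; files with other
    extensions and files directly under root are ignored.
    """
    counts = Counter()
    for path in Path(root).glob("*/*"):
        if path.is_file() and path.suffix.lower() in exts:
            counts[path.parent.name] += 1
    # sort by folder name so the output is stable
    return dict(sorted(counts.items()))
```

Calling count_images_per_class('../database1/') should return one entry per digit folder. A missing or nearly empty class is then visible immediately.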

 

2. Training a Keras model (MobileNet) with Keras's data generator ImageDataGenerator

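The image files are read with tf.io.read_file and decoded with tf.image.decode_jpeg. The results are then stored as NumPy arrays of float image data and integer labels, written np.array(list(float,int)) below.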

This loads the entire image dataset into memory, which uses a lot of memory!
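To put a number on it: a 128×128×3 float32 image takes 128·128·3·4 bytes. The sketch below assumes the folder holds all 70,000 MNIST digits, stored as 28×28 grayscale JPGs. That count is an assumption; the post does not say how many images database1 contains. The function name dataset_bytes is hypothetical.

```python
def dataset_bytes(num_images, height=128, width=128, channels=3, bytes_per_value=4):
    """Bytes needed to hold num_images decoded images in one dense array."""
    return num_images * height * width * channels * bytes_per_value


# One resized RGB float32 image, as produced by the preprocessing in this post
print(dataset_bytes(1))  # 196608 bytes (192 KiB)

# All 70,000 MNIST digits after resizing to 128x128x3 float32
print(f"{dataset_bytes(70_000) / 2**30:.1f} GiB")  # 12.8 GiB

# The same digits as raw 28x28 grayscale uint8, for comparison
print(f"{dataset_bytes(70_000, 28, 28, 1, 1) / 2**20:.1f} MiB")  # 52.3 MiB
```

Most of the cost comes from resizing to 128×128, expanding to 3 channels and storing float32, not from the dataset itself.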

From there, the Keras model can be trained in two ways:

  1. Pass the np.array(list(float,int)) data directly to model.fit
  2. Wrap the np.array(list(float,int)) data in an ImageDataGenerator, then train with model.fit_generator

The difference between them:

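With method 1, data augmentation or other image processing can only happen before the data is saved as np.array(list(float,int)). After that the result is fixed, and the images cannot be processed further during training.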

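With method 2, ImageDataGenerator applies random data augmentation or image processing on the fly during training, so each epoch sees differently transformed images.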
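The difference between the two methods can be sketched without Keras. This is a minimal NumPy sketch: the function names are hypothetical, and a random horizontal flip stands in for any augmentation. Method 1 augments once and freezes the result. Method 2 draws new random transforms every time a batch is produced, which is conceptually what ImageDataGenerator.flow() does.

```python
import numpy as np


def random_flip(batch, rng):
    """Flip each image horizontally with probability 0.5 (NHWC layout)."""
    flip = rng.random(len(batch)) < 0.5
    out = batch.copy()
    out[flip] = out[flip][:, :, ::-1, :]  # axis 2 is width in NHWC
    return out


def fixed_augmentation(images, rng):
    # Method 1: augment once before training; every epoch then sees the same pixels.
    return random_flip(images, rng)


def on_the_fly_batches(images, labels, batch_size, rng):
    # Method 2: a fresh random augmentation each time a batch is generated.
    for start in range(0, len(images), batch_size):
        x = images[start:start + batch_size]
        yield random_flip(x, rng), labels[start:start + batch_size]
```

Calling on_the_fly_batches once per epoch gives new flips every time. The output of fixed_augmentation never changes after it is computed.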

 

The training process is as follows:

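  1. Read the image files with tf.io.read_file, decode them with tf.image.decode_jpeg, and store them as np.array(list(float,int))
  2. Wrap the np.array(list(float,int)) data in an ImageDataGenerator (this step is optional)
  3. Build the Keras network model (MobileNet): from tensorflow.keras.applications import MobileNet (the original import path tensorflow_core.python.keras.applications.mobilenet only existed in early TF 2.x releases)
  4. Train with Keras: model.fit or model.fit_generator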
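The scripts below compute steps_per_epoch = int(nums_for_training / batch_size), which rounds down. An endless Python generator gives Keras no way to tell where an epoch ends, so steps_per_epoch must be supplied. As far as I know, the iterators returned by flow() report their own length as the rounded-up batch count, so there it is optional. A minimal sketch of the two roundings (both function names are hypothetical):

```python
def steps_floor(num_samples, batch_size):
    # Same as int(num_samples / batch_size) for positive integers:
    # the last, smaller batch is skipped.
    return num_samples // batch_size


def steps_ceil(num_samples, batch_size):
    # Also counts the final partial batch, so every sample is used once per epoch.
    return (num_samples + batch_size - 1) // batch_size


print(steps_floor(1000, 32), steps_ceil(1000, 32))  # 31 32
```

With 1000 samples and batch_size=32, rounding down uses 31 × 32 = 992 samples per epoch and leaves 8 out. The two values only agree when the sample count is a multiple of the batch size.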

 

1) Training with model.fit directly on np.array(list(float,int))

Full code:

train_keras_from_nparraydata.py

```python
import gc

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNet

from my_input_data import input_data_list

im_w = 128
im_h = 128
# im_channels = 3
train_image_list, train_label_list, test_image_list, test_label_list, label_names = input_data_list('../database1/')
classes = len(label_names)
nums_for_training = len(train_image_list)
batch_size = 32
epochs = 10
print(nums_for_training)


def train_preprocess_image(path):
    image = tf.io.read_file(path)  # read the raw JPEG bytes
    image = tf.image.decode_jpeg(image, channels=3)  # force 3 channels (MNIST digits are grayscale)
    image = tf.image.resize(image, [im_h, im_w])  # size is [height, width]; result is float32 in [0, 255]
    # randomly adjust brightness
    image = tf.image.random_brightness(image, max_delta=30)
    # randomly adjust contrast
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    # randomly adjust hue (no visible effect on gray images, where R == G == B)
    image = tf.image.random_hue(image, max_delta=0.3)
    # randomly adjust saturation (also no effect on gray images)
    image = tf.image.random_saturation(image, lower=0.2, upper=1.8)
    image = tf.clip_by_value(image, 0.0, 255.0)  # brightness/contrast can push values out of range
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def test_preprocess_image(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [im_h, im_w])
    image = tf.cast(image, dtype=tf.float32) / 255.0  # no random augmentation at test time
    return image


def preprocess_label(label):
    label = tf.cast(label, dtype=tf.int32)
    label = tf.one_hot(label, depth=classes)  # one-hot, as categorical_crossentropy expects
    return label


train_image = []
train_label = []
for image, label in zip(train_image_list, train_label_list):
    train_image.append(train_preprocess_image(image))
    train_label.append(preprocess_label(label))
# In TF 2.x these are EagerTensors, so np.array turns them into numeric arrays.
train_images = np.array(train_image)
train_labels = np.array(train_label)
print(train_images.shape)

test_image = []
test_label = []
for image, label in zip(test_image_list, test_label_list):
    test_image.append(test_preprocess_image(image))  # was train_preprocess_image: test data must not be augmented
    test_label.append(preprocess_label(label))
test_images = np.array(test_image)
test_labels = np.array(test_label)

model = MobileNet(input_shape=(im_h, im_w, 3), weights=None, include_top=True, classes=classes)
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# With in-memory arrays, batch_size alone defines an epoch; steps_per_epoch is not needed.
model.fit(train_images, train_labels, batch_size=batch_size, epochs=epochs)
cost = model.evaluate(test_images, test_labels)
print('test loss, test accuracy: ', cost)
del train_image, train_label, train_images, train_labels, test_image, test_label, test_images, test_labels
gc.collect()
```

A note on the difference between TensorFlow 1.x and TensorFlow 2.x (checked with a simple experiment):

In TensorFlow 1.x, tf.image functions return <class 'tensorflow.python.framework.ops.Tensor'>, i.e. a tf.Tensor. Printing one with print gives: Tensor("resize_910/Squeeze:0", shape=(128, 128, 3), dtype=float32)

In TensorFlow 2.x, tf.image functions return <class 'tensorflow.python.framework.ops.EagerTensor'>, i.e. an EagerTensor. Printing one gives: tf.Tensor(arraydata, shape=(128, 128, 3), dtype=float32), where arraydata stands for the actual matrix values.

The comparison shows that an EagerTensor carries its data, so np.array, list or .numpy() can turn it into usable float values. A tf.Tensor only yields its float values after sess.run, which is one reason TensorFlow 1.x is harder to debug.
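Below is a minimal sketch of the difference, assuming TensorFlow 2.x. Inside a tf.function, TensorFlow traces a graph, so tensors there behave like TF 1.x tensors: printing them shows no values. Outside, results are EagerTensors that convert directly to NumPy. The printed tensor names may differ between TF versions, and the name graph_resize is only for illustration.

```python
import numpy as np
import tensorflow as tf

# Assumes TensorFlow 2.x, where eager execution is on by default.
eager_image = tf.image.resize(tf.zeros([28, 28, 3]), [128, 128])
print(type(eager_image))          # an EagerTensor: values are already computed
print(tf.executing_eagerly())     # True in TF 2.x

as_numpy = eager_image.numpy()    # explicit conversion to np.ndarray
as_array = np.array(eager_image)  # same data through the __array__ protocol
print(as_numpy.shape, as_numpy.dtype)  # (128, 128, 3) float32


@tf.function
def graph_resize(x):
    y = tf.image.resize(x, [128, 128])
    # Python print runs while the function is traced into a graph, so it shows
    # a symbolic Tensor without values, much like TF 1.x before sess.run.
    print("inside tf.function:", y)
    return y


result = graph_resize(tf.zeros([28, 28, 3]))  # the returned value is eager again
```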

At my level I can't say which is better, so I'll just use it this way for now~

 

2) Wrapping np.array(list(float,int)) in an ImageDataGenerator, then training with model.fit_generator


Full code:

train_keras_from_kerasgenerator.py

```python
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 29 20:36:16 2021
@author: Leon_PC
"""
import gc

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import MobileNet

from my_input_data import input_data_list

im_w = 128
im_h = 128
# im_channels = 3
train_image_list, train_label_list, test_image_list, test_label_list, label_names = input_data_list('../database1/')
classes = len(label_names)
nums_for_training = len(train_image_list)
batch_size = 32
steps_per_epoch = int(nums_for_training / batch_size)
epochs = 10
print(nums_for_training)
print(steps_per_epoch)


def train_preprocess_image(path):
    image = tf.io.read_file(path)  # read the raw JPEG bytes
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [im_h, im_w])  # float32 in [0, 255]
    # The tf.image.random_* augmentation and the /255.0 cast from script 1 are
    # disabled here: ImageDataGenerator below does augmentation and rescaling.
    return image


def test_preprocess_image(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [im_h, im_w])
    return image


def preprocess_label(label):
    label = tf.cast(label, dtype=tf.int32)
    label = tf.one_hot(label, depth=classes)  # labels are already one-hot, no to_categorical needed
    return label


train_image = []
train_label = []
for image, label in zip(train_image_list, train_label_list):
    train_image.append(train_preprocess_image(image))
    train_label.append(preprocess_label(label))
train_images = np.array(train_image)
train_labels = np.array(train_label)
print(train_images.shape)

test_image = []
test_label = []
for image, label in zip(test_image_list, test_label_list):
    test_image.append(test_preprocess_image(image))  # was train_preprocess_image
    test_label.append(preprocess_label(label))
test_images = np.array(test_image)
test_labels = np.array(test_label)

# Random shear/zoom/flip are drawn anew for every batch during training.
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)  # note: a mirrored digit is not always a valid digit
test_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
# .fit() is only required for featurewise_center, featurewise_std_normalization
# or zca_whitening; none are used here, so these calls can be skipped.
# train_datagen.fit(train_images)
# test_datagen.fit(test_images)

model = MobileNet(input_shape=(im_h, im_w, 3), weights=None, include_top=True, classes=classes)
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# model.fit_generator is deprecated since TF 2.1; model.fit accepts the iterator directly.
model.fit(train_datagen.flow(train_images, train_labels, batch_size=batch_size),
          steps_per_epoch=steps_per_epoch, epochs=epochs)
cost = model.evaluate(test_datagen.flow(test_images, test_labels, batch_size=batch_size, shuffle=False))
print('test loss, test accuracy: ', cost)
del train_image, train_label, train_images, train_labels, test_image, test_label, test_images, test_labels
gc.collect()
```
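In script 2, the /255.0 cast is commented out because rescale=1./255 in ImageDataGenerator now does the scaling. The NumPy class below illustrates what flow() contributes. MiniFlow is a simplified, hypothetical class: it reshuffles the sample order on every pass, builds batches lazily and rescales them. It is only a sketch of the idea, not the Keras implementation, and it leaves out the random transforms.

```python
import numpy as np


class MiniFlow:
    """Hypothetical, simplified stand-in for ImageDataGenerator.flow()."""

    def __init__(self, x, y, batch_size=32, rescale=1.0 / 255, shuffle=True, seed=None):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.rescale = rescale
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        # batches per pass, counting the final partial batch
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Endless, like a training generator: one pass after another.
        while True:
            n = len(self.x)
            order = self.rng.permutation(n) if self.shuffle else np.arange(n)
            for start in range(0, n, self.batch_size):
                idx = order[start:start + self.batch_size]
                # rescaling happens here, batch by batch, not on the whole array up front
                yield self.x[idx].astype(np.float32) * self.rescale, self.y[idx]
```

Because batches are produced lazily, only the current batch is held in rescaled form. The full training array still has to fit in memory, which is the limitation flow_from_directory removes.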

 

3) Supporting code: my_input_data.py

my_input_data.py

```python
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 30 15:08:38 2021
@author: Leon_PC
"""
import os
import pathlib
import random


def input_data_list(file_dir, train_ratio=4 / 5):
    # data_path = pathlib.Path('./database1/')
    data_path = pathlib.Path(file_dir)
    print(type(data_path))  # <class 'pathlib.WindowsPath'>
    all_image_paths = list(data_path.glob('*/*'))  # every file one level below a class folder
    print(type(data_path.glob('*/*')))  # <class 'generator'>
    # all_image_paths = [str(path) for path in all_image_paths]  # list of relative paths of all images
    all_image_paths = [os.path.abspath(path) for path in all_image_paths]  # list of absolute paths of all images
    random.shuffle(all_image_paths)  # shuffle before splitting
    image_count = len(all_image_paths)
    print('image_count: ', image_count)
    # class names are the subfolder names, in sorted order
    label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
    label_to_index = dict((name, index) for index, name in enumerate(label_names))
    # the label of an image is the index of its parent folder's name
    all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
    print(label_names)
    nums_for_training = int(len(all_image_paths) * train_ratio)
    train_image_list = list(all_image_paths[0:nums_for_training])
    train_label_list = list(all_image_labels[0:nums_for_training])
    test_image_list = list(all_image_paths[nums_for_training:])
    test_label_list = list(all_image_labels[nums_for_training:])
    return train_image_list, train_label_list, test_image_list, test_label_list, label_names


if __name__ == '__main__':
    input_data_list('../database1/')  # was input_data(...), which is not defined
```

 

4) Comparing flow and flow_from_directory

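ImageDataGenerator.flow_from_directory reads data directly from folders that are already split into train and test. It returns a data iterator (a DirectoryIterator) that yields batches, so the images never have to be loaded into one in-memory array.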

The data here is still the database1 dataset. I did not split it into train and test, and simply passed '../database1/' to flow_from_directory.

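This method suits datasets that are split into train and test folders ahead of time. Usually, though, train and test data sit together and are split randomly by a ratio. It is still quite convenient for beginners, provided you first write a Python batch script that automatically splits the data into train and test folders.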
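The batch script mentioned above could look roughly like this. It is a sketch under assumptions that are my own, not from the original post: the function name split_train_test, the output folder names train/ and test/, and the per-class (stratified) split. By contrast, my_input_data.py shuffles all files together. Keep dst outside src, otherwise a second run would pick up the new train/test folders as extra classes.

```python
import random
import shutil
from pathlib import Path


def split_train_test(src, dst, train_ratio=0.8, seed=0):
    """Copy src/<class>/<file> into dst/train/<class>/ and dst/test/<class>/."""
    rng = random.Random(seed)  # fixed seed makes the split reproducible
    src, dst = Path(src), Path(dst)
    for class_dir in sorted(p for p in src.iterdir() if p.is_dir()):
        files = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(files)
        n_train = int(len(files) * train_ratio)  # split each class separately
        for subset, chunk in (("train", files[:n_train]), ("test", files[n_train:])):
            out_dir = dst / subset / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in chunk:
                shutil.copy2(f, out_dir / f.name)
```

After split_train_test('../database1/', '../database1_split/'), two separate iterators could be created. The training one would use flow_from_directory('../database1_split/train', ...) and the test one flow_from_directory('../database1_split/test', ...).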

Official example: https://tensorflow.google.cn/versions/r2.1/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator

Full code:

train_keras_from_kerasgenerator_fromdirectory.py

```python
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 29 20:36:16 2021
@author: Leon_PC
"""
from tensorflow import keras
from tensorflow.keras.applications import MobileNet

im_w = 128
im_h = 128
# im_channels = 3
batch_size = 32
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
image_path = '../database1/'
# Every subfolder of image_path is treated as one class.
train_data_gen = train_datagen.flow_from_directory(directory=image_path,
                                                   batch_size=batch_size,
                                                   shuffle=True,  # shuffle the data
                                                   target_size=(im_h, im_w),
                                                   class_mode='categorical')  # one-hot labels
print(train_data_gen)  # <keras_preprocessing.image.directory_iterator.DirectoryIterator object at 0x00000204624B8550>
classes = len(train_data_gen.class_indices)  # inferred from the folders instead of hard-coding 10
nums_for_training = train_data_gen.n
steps_per_epoch = int(nums_for_training / batch_size)
epochs = 10
print(nums_for_training)
print(steps_per_epoch)
model = MobileNet(input_shape=(im_h, im_w, 3), weights=None, include_top=True, classes=classes)
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# model.fit_generator is deprecated since TF 2.1; model.fit takes the DirectoryIterator directly.
model.fit(train_data_gen, steps_per_epoch=steps_per_epoch, epochs=epochs, max_queue_size=1, workers=1)
```
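flow_from_directory infers the classes from the subfolder names. According to the Keras documentation, they are mapped to label indices in alphanumeric order and exposed through the class_indices attribute. The sketch below uses class_indices as a hypothetical helper name. It builds the same kind of mapping as label_to_index in my_input_data.py, which is why the two pipelines agree on labels for folders 0 to 9. It also shows a pitfall: with string sorting, a folder named '10' would come before '2'.

```python
from pathlib import Path


def class_indices(directory):
    """Map each subfolder name to an integer label, in sorted (string) order."""
    names = sorted(p.name for p in Path(directory).iterdir() if p.is_dir())
    return {name: i for i, name in enumerate(names)}
```

To double-check a real run, compare print(train_data_gen.class_indices) with the label_names list printed by input_data_list.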


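Disclaimer: this content was contributed by a user and does not represent the position of 【wpsshop博客】. Copyright belongs to the original author, and the site assumes no legal liability. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/盐析白兔/article/detail/123884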