当前位置:   article > 正文

【TF2.0-CNN】数据增强-训练Cats v Dogs模型_cnn提高训练效果

cnn提高训练效果

数据增强主要是通过以下方式获得更多的训练数据: 缩放、拉伸、旋转、剪切、翻转等。

本文将使用ImageDataGenerator进行数据增强。

【例1】未使用增强

  1. !wget --no-check-certificate \
  2. https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
  3. -O /tmp/cats_and_dogs_filtered.zip
  4. import os
  5. import zipfile
  6. import tensorflow as tf
  7. from tensorflow.keras.optimizers import RMSprop
  8. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  9. local_zip = '/tmp/cats_and_dogs_filtered.zip'
  10. zip_ref = zipfile.ZipFile(local_zip, 'r')
  11. zip_ref.extractall('/tmp')
  12. zip_ref.close()
  13. base_dir = '/tmp/cats_and_dogs_filtered'
  14. train_dir = os.path.join(base_dir, 'train')
  15. validation_dir = os.path.join(base_dir, 'validation')
  16. # Directory with our training cat pictures
  17. train_cats_dir = os.path.join(train_dir, 'cats')
  18. # Directory with our training dog pictures
  19. train_dogs_dir = os.path.join(train_dir, 'dogs')
  20. # Directory with our validation cat pictures
  21. validation_cats_dir = os.path.join(validation_dir, 'cats')
  22. # Directory with our validation dog pictures
  23. validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  24. model = tf.keras.models.Sequential([
  25. tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(150, 150, 3)),
  26. tf.keras.layers.MaxPooling2D(2, 2),
  27. tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
  28. tf.keras.layers.MaxPooling2D(2,2),
  29. tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
  30. tf.keras.layers.MaxPooling2D(2,2),
  31. tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
  32. tf.keras.layers.MaxPooling2D(2,2),
  33. tf.keras.layers.Flatten(),
  34. tf.keras.layers.Dense(512, activation='relu'),
  35. tf.keras.layers.Dense(1, activation='sigmoid')
  36. ])
  37. model.compile(loss='binary_crossentropy',
  38. optimizer=RMSprop(lr=1e-4),
  39. metrics=['acc'])
  40. # All images will be rescaled by 1./255
  41. train_datagen = ImageDataGenerator(rescale=1./255)
  42. test_datagen = ImageDataGenerator(rescale=1./255)
  43. # Flow training images in batches of 20 using train_datagen generator
  44. train_generator = train_datagen.flow_from_directory(
  45. train_dir, # This is the source directory for training images
  46. target_size=(150, 150), # All images will be resized to 150x150
  47. batch_size=20,
  48. # Since we use binary_crossentropy loss, we need binary labels
  49. class_mode='binary')
  50. # Flow validation images in batches of 20 using test_datagen generator
  51. validation_generator = test_datagen.flow_from_directory(
  52. validation_dir,
  53. target_size=(150, 150),
  54. batch_size=20,
  55. class_mode='binary')
  56. history = model.fit_generator(
  57. train_generator,
  58. steps_per_epoch=100, # 2000 images = batch_size * steps
  59. epochs=100,
  60. validation_data=validation_generator,
  61. validation_steps=50, # 1000 images = batch_size * steps
  62. verbose=2)
  63. import matplotlib.pyplot as plt
  64. acc = history.history['acc']
  65. val_acc = history.history['val_acc']
  66. loss = history.history['loss']
  67. val_loss = history.history['val_loss']
  68. epochs = range(len(acc))
  69. plt.plot(epochs, acc, 'bo', label='Training accuracy')
  70. plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
  71. plt.title('Training and validation accuracy')
  72. plt.figure()
  73. plt.plot(epochs, loss, 'bo', label='Training Loss')
  74. plt.plot(epochs, val_loss, 'b', label='Validation Loss')
  75. plt.title('Training and validation loss')
  76. plt.legend()
  77. plt.show()

【运行结果】

【解析】

训练集中的准确率接近100%,但验证集中的准确率只有70%,说明严重过拟合(overfitting)了。

解决的办法是通过数据增强获得更多的训练数据。

【例2】使用数据增强的猫狗模型

  1. '''
  2. !wget - -no - check - certificate \
  3. https: // storage.googleapis.com / mledu - datasets / cats_and_dogs_filtered.zip \
  4. - O / tmp / cats_and_dogs_filtered.zip
  5. '''
  6. import os
  7. import zipfile
  8. import tensorflow as tf
  9. from tensorflow.keras.optimizers import RMSprop
  10. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  11. local_zip = '/tmp/cats_and_dogs_filtered.zip'
  12. zip_ref = zipfile.ZipFile(local_zip, 'r')
  13. zip_ref.extractall('/tmp')
  14. zip_ref.close()
  15. base_dir = '/tmp/cats_and_dogs_filtered'
  16. train_dir = os.path.join(base_dir, 'train')
  17. validation_dir = os.path.join(base_dir, 'validation')
  18. # Directory with our training cat pictures
  19. train_cats_dir = os.path.join(train_dir, 'cats')
  20. # Directory with our training dog pictures
  21. train_dogs_dir = os.path.join(train_dir, 'dogs')
  22. # Directory with our validation cat pictures
  23. validation_cats_dir = os.path.join(validation_dir, 'cats')
  24. # Directory with our validation dog pictures
  25. validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  26. model = tf.keras.models.Sequential([
  27. tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
  28. tf.keras.layers.MaxPooling2D(2, 2),
  29. tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
  30. tf.keras.layers.MaxPooling2D(2, 2),
  31. tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
  32. tf.keras.layers.MaxPooling2D(2, 2),
  33. tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
  34. tf.keras.layers.MaxPooling2D(2, 2),
  35. tf.keras.layers.Flatten(),
  36. tf.keras.layers.Dense(512, activation='relu'),
  37. tf.keras.layers.Dense(1, activation='sigmoid')
  38. ])
  39. model.compile(loss='binary_crossentropy',
  40. optimizer=RMSprop(lr=1e-4),
  41. metrics=['acc'])
  42. # This code has changed. Now instead of the ImageGenerator just rescaling
  43. # the image, we also rotate and do other operations
  44. # Updated to do image augmentation
  45. train_datagen = ImageDataGenerator(
  46. rescale=1. / 255,
  47. rotation_range=40,
  48. width_shift_range=0.2,
  49. height_shift_range=0.2,
  50. shear_range=0.2,
  51. zoom_range=0.2,
  52. horizontal_flip=True,
  53. fill_mode='nearest')
  54. test_datagen = ImageDataGenerator(rescale=1. / 255)
  55. # Flow training images in batches of 20 using train_datagen generator
  56. train_generator = train_datagen.flow_from_directory(
  57. train_dir, # This is the source directory for training images
  58. target_size=(150, 150), # All images will be resized to 150x150
  59. batch_size=20,
  60. # Since we use binary_crossentropy loss, we need binary labels
  61. class_mode='binary')
  62. # Flow validation images in batches of 20 using test_datagen generator
  63. validation_generator = test_datagen.flow_from_directory(
  64. validation_dir,
  65. target_size=(150, 150),
  66. batch_size=20,
  67. class_mode='binary')
  68. history = model.fit_generator(
  69. train_generator,
  70. steps_per_epoch=100, # 2000 images = batch_size * steps
  71. epochs=100,
  72. validation_data=validation_generator,
  73. validation_steps=50, # 1000 images = batch_size * steps
  74. verbose=2)
  75. import matplotlib.pyplot as plt
  76. acc = history.history['acc']
  77. val_acc = history.history['val_acc']
  78. loss = history.history['loss']
  79. val_loss = history.history['val_loss']
  80. epochs = range(len(acc))
  81. plt.plot(epochs, acc, 'bo', label='Training accuracy')
  82. plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
  83. plt.title('Training and validation accuracy')
  84. plt.figure()
  85. plt.plot(epochs, loss, 'bo', label='Training Loss')
  86. plt.plot(epochs, val_loss, 'b', label='Validation Loss')
  87. plt.title('Training and validation loss')
  88. plt.legend()
  89. plt.show()

与例1的区别是ImageDataGenerator的参数:

  1. train_datagen = ImageDataGenerator(
  2. rescale=1./255,
  3. rotation_range=40,
  4. width_shift_range=0.2,
  5. height_shift_range=0.2,
  6. shear_range=0.2,
  7. zoom_range=0.2,
  8. horizontal_flip=True,
  9. fill_mode='nearest')

其中:

  •     rescale:像素值重缩放(归一化,如乘以1/255)
  •     rotation_range: 随机旋转的角度范围
  •     width_shift_range: 水平平移范围
  •     height_shift_range:垂直平移范围
  •     shear_range: 剪切变换(错切)范围
  •     zoom_range:随机缩放范围
  •     horizontal_flip:水平翻转
  •     fill_mode:填充模式(变换产生的新像素如何填充)

【运行结果】

【解析】

训练集与验证集的准确率相差不大,效果比例1有明显改善。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/283634?site
推荐阅读
相关标签
  

闽ICP备14008679号