当前位置:   article > 正文

快速理解【TensorFlow2.0】ImageDataGenerator使用

imagedatagenerator

【TensorFlow2.0】ImageDataGenerator使用

一、作用:

ImageDataGenerator()是keras.preprocessing.image模块中的图片生成器,同时也可以在batch中对数据进行增强,扩充数据集大小,增强模型的泛化能力。比如进行旋转,变形,归一化等等。还可以自动为训练数据生成标签。
总结起来就是三个点:
(1)图片生成器,负责生成一个批次一个批次的图片,以生成器的形式给模型训练;
(2)对每一个批次的训练图片,适时地进行数据增强处理(data augmentation);
(3)自动为训练数据生成标签

二、ImageDataGenerator简单介绍:

ImageDataGenerator()参数
featurewise_center: Boolean. 对输入的图片每个通道减去每个通道对应均值。它针对的是数据集dataset,
samplewise_center: Boolan. 每张图片减去样本均值, 使得每个样本均值为0。它针对的是单个输入图片的本身
featurewise_std_normalization(): Boolean()
samplewise_std_normalization(): Boolean()
zca_epsilon(): Default 12-6
zca_whitening: Boolean. 去除样本之间的相关性
rotation_range(): 旋转范围
width_shift_range(): 水平平移范围
height_shift_range(): 垂直平移范围
shear_range(): float, 透视变换的范围
zoom_range(): 缩放范围
fill_mode: 填充模式, constant, nearest, reflect
cval: fill_mode == 'constant'的时候填充值
horizontal_flip(): 水平反转
vertical_flip(): 垂直翻转
preprocessing_function(): user提供的处理函数
data_format(): channels_first或者channels_last
validation_split(): 多少数据用于验证集
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

ImageDataGenerator()方法:

apply_transform(x, transform_parameters):根据参数对x进行变换
fit(x, augment=False, rounds=1, seed=None): 将生成器用于数据x,从数据x中获得样本的统计参数, 只有featurewise_center, featurewise_std_normalization或者zca_whitening为True才需要
flow(x, y=None, batch_size=32, shuffle=True, sample_weight=None, seed=None, save_to_dir=None, save_prefix='', save_format='png', subset=None) ):按batch_size大小从x,y生成增强数据
flow_from_directory()从路径生成增强数据,和flow方法相比最大的优点在于不用一次将所有的数据读入内存当中,这样减小内存压力,这样不会发生OOM,血的教训。
get_random_transform(img_shape, seed=None): 返回包含随机图像变换参数的字典
random_transform(x, seed=None): 进行随机图像变换, 通过设置seed可以达到同步变换。
standardize(x): 对x进行归一化
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

三、ImageDataGenerator方法介绍以及处理流程

1、方法介绍

ImageDataGenerator的方法包括以下七个,其中常用的flow_from_directory()、fit()以及flow()

apply_transform(x, transform_parameters):
fit(x, augment=False, rounds=1, seed=None):
flow(x, y=None, batch_size=32, shuffle=True, sample_weight=None, seed=None, save_to_dir=None, save_prefix='', save_format='png', subset=None) )
flow_from_directory()
get_random_transform(img_shape, seed=None):
random_transform(x, seed=None)
standardize(x): 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

1、一般的对图像的处理流程(四步走)
第一步:数据集的划分,得到x_train,y_train,x_test,y_test;

x_train, y_train), (x_test, y_test) = mnist.load_data()
 
x_train = np.expand_dims(x_train, axis = 3)
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)
  • 1
  • 2
  • 3
  • 4
  • 5

第二步:构造ImageDataGenerator对象,其中要进行某一些操作是通过在构造函数中的参数
指定的,datagen = ImageDataGenerator(…),例如:

#对数据进行预处理,注意这里不是一次性要将所有的数据进行处理完,而是在后面的代码中进行逐批次处理
train_data = image.ImageDataGenerator(
     featurewise_center=True,
     featurewise_std_normalization=True
     rescale=1./255,  # 对图片的每个像素值均乘上这个放缩因子,把像素值放缩到0和1之间有利于模型的收敛
     shear_range=0.1, # 浮点数,剪切强度(逆时针方向的剪切变换角度)
     zoom_range=0.1, # 随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]
     width_shift_range=0.1, # 浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
     height_shift_range=0.1, # 浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度
     horizontal_flip=True,# 布尔值,进行随机水平翻转
     vertical_flip=True, # 布尔值,进行随机竖直翻转
     validation_split=0.1# 在 0 和 1 之间浮动。用作验证集的训练数据的比例
         )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

第三步:对样本数据进行data augmentation处理,通过fit方法。

train_data .fit(x_train) # 使用实时数据增益的批数据对模型进行拟合
这一步并不是必须的.当ImageDataGenerator构造函数中需要使用以下四个参数时,
featurewise_center
samplewise_center
featurewise_std_normalization
samplewise_std_normaliza
才需要使用fit方法,因为需要从fit方法中得到原始图形的统计信息,
比如均值、方差等等,否则是不需要改步骤的。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

第四步:进行训练,通过flow方法或flow_from_directory()

#使用.flow方法构造Iterator
data_iter = train_data.flow(x_train, y_train,
                        batch_size=8,
                        save_to_dir="save_data"
                         )  #返回的是一个“生成器对象”
 

#对目标目录下的数据进行数据扩增,
data_iter = train_data.flow_from_directory(base_path,
        target_size=(300,300),#所有的图像将被调整到的target_size尺寸。
        batch_size=8, 每一次处理8张图片
        class_mode="categorical", #对类型进行热编码:"categorical",返回one-hot 编码标签。
        subset="training", # 数据子集 ("training" 或 "validation")
        save_prefix="sjl_",将augmentation之后的图片添加sjl_前缀
        seed=0)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

值得注意的是:class_mode,
如果class_mode==“binary”,则输出的类别是0,1,2,来分别表示三个类,
如果是class_mode==“'categorical”,则这三个类别会用独热编码的方式来表示,即【0,0,1】、【0,1,0】、【1,0,0】.

四、ImageDataGenerator实例

import matplotlib.pyplot as  plt
import glob
from PIL import Image
from keras.preprocessing import image

path = './dogs_cats_sample/'
gen_path = './dogs_cats_gen/'


def print_result(path):
    name_list = glob.glob(path)
    fig = plt.figure()
    for i in range(9):
        img = Image.open(name_list[i])
        # add_subplot(331) 参数一:子图总行数,参数二:子图总列数,参数三:子图位置
        sub_img = fig.add_subplot(331 + i)
        sub_img.imshow(img)
    plt.show()
    return fig


# 打印图片列表
name_list = glob.glob(path + '*/*')
print(name_list)

# 打印图片
fig = print_result(path + '*/*')

# 保存图片
fig.savefig(gen_path + '/original_0.png', dpi=200, papertype='a5')

# 原图
datagen = image.ImageDataGenerator(
            rescale=1./255,  # 对图片的每个像素值均乘上这个放缩因子,把像素值放缩到0和1之间有利于模型的收敛
            shear_range=0.1, # 浮点数,剪切强度(逆时针方向的剪切变换角度)
            zoom_range=0.1, # 随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]
            width_shift_range=0.1, # 浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
            height_shift_range=0.1, # 浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度
            horizontal_flip=True,# 布尔值,进行随机水平翻转
            vertical_flip=True, # 布尔值,进行随机竖直翻转
            validation_split=0.1
            )
gen_data = datagen.flow_from_directory(
            path,
            batch_size=1, 
            shuffle=False, 
            save_to_dir=gen_path,
            save_prefix='dog_gen',#对处理后图片设置前缀。 
            target_size=(224, 224)
            )
for i in range(9):
    gen_data.next()

fig = print_result(gen_path + '/*')
fig.savefig(gen_path + '/original_1.png', dpi=200, papertype='a5')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

ImageDataGenerator里面的路径结构需要注意
在这里插入图片描述
原图
在这里插入图片描述
数据增强后图
在这里插入图片描述

参考博客:
https://blog.csdn.net/qq_27825451/article/details/90172030
https://blog.csdn.net/qq_27825451/article/details/90056896
https://www.jianshu.com/p/d23b5994db64
https://blog.csdn.net/u012193416/article/details/79368855

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/123891
推荐阅读
相关标签
  

闽ICP备14008679号