Data Augmentation in Keras

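Deep networks need large amounts of training data to perform well. To build a strong image classifier from very little training data, image augmentation is usually needed to improve the network's performance. Image augmentation artificially generates training images by applying individual transformations, or combinations of them, such as random rotation, shifting, shearing, and flipping.

This article covers:

  • Generating augmented images with ImageDataGenerator in Keras
  • Generating custom augmented images with contrast stretching, histogram equalization, and adaptive histogram equalization
  • Training a convolutional neural network on the CIFAR-10 dataset with image augmentation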

1. Import the required libraries

from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
from skimage import exposure, color
from skimage.transform import resize

import keras
from keras import backend as K
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
# Allocate GPU memory on demand instead of reserving it all up front
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

2. Generating augmented images with Keras ImageDataGenerator

(1) Define a function that generates augmented images.

def imgGen(img, zca=False, rotation=0., w_shift=0., h_shift=0., shear=0., zoom=0., h_flip=False, v_flip=False,  preprocess_fcn=None, batch_size=9):
    datagen = ImageDataGenerator(
            zca_whitening=zca,
            rotation_range=rotation,
            width_shift_range=w_shift,
            height_shift_range=h_shift,
            shear_range=shear,
            zoom_range=zoom,
            fill_mode='nearest',
            horizontal_flip=h_flip,
            vertical_flip=v_flip,
            preprocessing_function=preprocess_fcn,
            data_format=K.image_data_format())
    
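    # fit() computes the data statistics required by ZCA whitening
    datagen.fit(img)

    # Plot augmented images in a 3x3 grid until batch_size images are shown
    # (the grid holds at most 9 images, so batch_size should not exceed 9)
    i = 0
    for img_batch in datagen.flow(img, batch_size=batch_size, shuffle=False):
        for aug_img in img_batch:
            if i >= batch_size:
                break
            plt.subplot(3, 3, i + 1)
            plt.imshow(aug_img)
            i += 1
        if i >= batch_size:
            break
    plt.show()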

(2) Call the function above to perform data augmentation.

# Show the original image
img = imread("img/cat.jpg")
plt.imshow(img)
plt.show()

# Scale pixel values to [0, 1] and add a batch dimension for augmentation
img = img.astype('float32')
img /= 255
h_dim, w_dim, num_channel = img.shape
img = img.reshape(1, h_dim, w_dim, num_channel)
print(img.shape)

# Augment: random rotation up to 30 degrees, vertical shift up to half the height
imgGen(img, rotation=30, h_shift=0.5)
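As a rough illustration of what `rotation=30` and `h_shift=0.5` ask for, the sketch below draws a few random transform parameters in plain NumPy. It assumes the documented meaning of `rotation_range` (a range in degrees) and `height_shift_range` (a fraction of the image height when below 1, pixels otherwise), and it assumes uniform sampling over a symmetric range. The helper `sample_transform` and the image height of 200 are hypothetical and are not part of Keras.

```python
import numpy as np

def sample_transform(rotation_range, height_shift_range, img_height, rng):
    """Sample one rotation angle (degrees) and one vertical shift (pixels)."""
    theta = rng.uniform(-rotation_range, rotation_range)
    shift = rng.uniform(-height_shift_range, height_shift_range)
    # Values below 1 are treated as a fraction of the image height.
    if height_shift_range < 1:
        shift *= img_height
    return theta, shift

rng = np.random.default_rng(42)
samples = [sample_transform(30, 0.5, 200, rng) for _ in range(5)]
for theta, shift in samples:
    print(f"rotate {theta:+.1f} deg, shift {shift:+.1f} px")
```

With these settings an image can move by up to half its height, which is a fairly aggressive shift; the area it leaves empty is filled according to `fill_mode`.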

3. Generating augmented data with contrast adjustment

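Besides the standard augmentation techniques offered by the Keras ImageDataGenerator class, we can also generate augmented images with custom functions. For example, you may want to adjust image contrast with contrast stretching, a simple image processing technique that enhances contrast by rescaling ("stretching") the range of intensity values in the image to a desired range.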
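The rescaling step can be sketched in a few lines of NumPy. The function name `stretch_contrast` and the synthetic low-contrast image are assumptions made for illustration; the 2nd and 98th percentiles are one common choice of input range, not the only one.

```python
import numpy as np

def stretch_contrast(img, low_pct=2, high_pct=98):
    """Linearly map the [low_pct, high_pct] percentile range onto [0, 1]."""
    lo, hi = np.percentile(img, (low_pct, high_pct))
    if hi <= lo:
        # A perfectly flat image has no range to stretch.
        return np.zeros_like(img, dtype=float)
    # Values outside the percentile range are clipped to 0 or 1.
    return np.clip((img - lo) / (hi - lo), 0.0, 1.0)

rng = np.random.default_rng(0)
low_contrast = 0.4 + 0.2 * rng.random((32, 32))  # values confined to [0.4, 0.6)
stretched = stretch_contrast(low_contrast)
print(low_contrast.min(), low_contrast.max())
print(stretched.min(), stretched.max())  # 0.0 1.0
```

Using percentiles rather than the raw minimum and maximum keeps a handful of extreme pixels from dictating the output range.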

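Histogram equalization is another image processing technique that uses the image histogram to increase overall contrast. The equalized image has a linear cumulative distribution function. The method needs no parameters, but it sometimes produces unnatural-looking images.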
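A minimal sketch of the idea, assuming a single-channel float image with values in [0, 1]: each pixel is replaced by the value of the cumulative distribution function (CDF) at its intensity, which spreads the intensities out roughly uniformly. The function name `equalize` and the synthetic dark image are illustrative assumptions.

```python
import numpy as np

def equalize(img, nbins=256):
    """Equalize a float image in [0, 1] by mapping pixels through the CDF."""
    hist, bin_edges = np.histogram(img.ravel(), bins=nbins, range=(0.0, 1.0))
    cdf = hist.cumsum() / hist.sum()
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # Look up the CDF value at each pixel's intensity.
    return np.interp(img.ravel(), centers, cdf).reshape(img.shape)

rng = np.random.default_rng(0)
dark = rng.random((64, 64)) ** 2  # skewed toward dark values
flat = equalize(dark)
print(round(float(dark.mean()), 2), round(float(flat.mean()), 2))
```

After equalization the mean intensity moves toward the middle of the range, because the output values are approximately uniformly distributed.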

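Another approach is adaptive histogram equalization (AHE). Instead of adjusting global contrast with a single histogram, as ordinary histogram equalization does, AHE computes several histograms, each corresponding to a different region of the image, and uses them to improve local contrast. However, AHE tends to over-amplify noise in relatively homogeneous regions of an image.

Contrast limited adaptive histogram equalization (CLAHE) was proposed to prevent this noise amplification. In essence, it limits the contrast enhancement of AHE by clipping the histogram at a predefined value before computing the cumulative distribution function.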
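The clipping step can be sketched on a single histogram. This is a simplified illustration only: the function `clipped_cdf`, the toy histogram, and the interpretation of `clip_limit` as a fraction of the total pixel count are assumptions made for this sketch, and a full CLAHE implementation additionally works per image tile and interpolates between tiles.

```python
import numpy as np

def clipped_cdf(hist, clip_limit):
    """Clip a histogram at clip_limit (a fraction of the total count),
    spread the clipped excess evenly over all bins, and return the CDF."""
    hist = hist.astype(float)
    limit = clip_limit * hist.sum()
    excess = np.maximum(hist - limit, 0).sum()
    clipped = np.minimum(hist, limit) + excess / hist.size
    return clipped.cumsum() / clipped.sum()

# A nearly uniform region: most pixels fall into one intensity bin.
hist = np.array([1, 2, 80, 10, 4, 2, 1, 0])
plain = hist.cumsum() / hist.sum()
limited = clipped_cdf(hist, clip_limit=0.3)
print(np.round(plain, 2))
print(np.round(limited, 2))
```

The steepest step of the CDF determines how strongly neighbouring intensities are pulled apart. Clipping lowers that step, so small variations (noise) in a homogeneous region are amplified less.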

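To implement a custom preprocessing function for image augmentation in Keras, we first define the function and then pass it to ImageDataGenerator as the preprocessing_function argument.

(1) Define the contrast adjustment functions.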

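# Contrast stretching: rescale the 2nd-98th percentile range to the full range
def contrast_stretching(img):
    p2, p98 = np.percentile(img, (2, 98))
    img_rescale = exposure.rescale_intensity(img, in_range=(p2, p98))
    return img_rescale

# Histogram equalization
def HE(img):
    img_eq = exposure.equalize_hist(img)
    return img_eq

# Adaptive histogram equalization (equalize_adapthist implements CLAHE;
# clip_limit sets the clipping threshold)
def AHE(img):
    img_adapteq = exposure.equalize_adapthist(img, clip_limit=0.03)
    return img_adapteq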

(2) Call the functions above to perform data augmentation.

imgGen(img, rotation=30, h_shift=0.5, preprocess_fcn = contrast_stretching)
imgGen(img, rotation=30, h_shift=0.5, preprocess_fcn = HE)
imgGen(img, rotation=30, h_shift=0.5, preprocess_fcn = AHE)

[Figures: original image; contrast stretching; histogram equalization; adaptive histogram equalization]

4. Training on the CIFAR-10 dataset

(1) Load and preprocess the CIFAR-10 image data.

batch_size = 256
num_classes = 2
epochs = 10

# Input image dimensions
img_rows, img_cols = 32, 32

# Load the data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)

# Keep only cats [=3] and dogs [=5]
train_picks = np.ravel(np.logical_or(y_train==3,y_train==5))  
test_picks = np.ravel(np.logical_or(y_test==3,y_test==5))     

y_train = np.array(y_train[train_picks]==5,dtype=int)
y_test = np.array(y_test[test_picks]==5,dtype=int)

x_train = x_train[train_picks]
x_test = x_test[test_picks]

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 3, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 3, img_rows, img_cols)
    input_shape = (3, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 3)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 3)
    input_shape = (img_rows, img_cols, 3)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# One-hot encode the labels
y_train = keras.utils.to_categorical(np.ravel(y_train), num_classes)
y_test = keras.utils.to_categorical(np.ravel(y_test), num_classes)

# x_train shape: (50000, 32, 32, 3)
# x_train shape: (10000, 32, 32, 3)
# 10000 train samples
# 2000 test samples
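The filtering and relabelling steps above can be checked on a toy label array. The sketch below uses made-up labels with the same `(N, 1)` shape that `cifar10.load_data()` returns, and it uses `np.eye` as a stand-in for `keras.utils.to_categorical` so that it runs with NumPy alone.

```python
import numpy as np

CAT, DOG = 3, 5  # CIFAR-10 class indices used in the code above

# Toy labels shaped (N, 1); x holds sample indices in place of images.
y = np.array([[3], [0], [5], [9], [5], [3]])
x = np.arange(len(y))

# Boolean mask selecting cats and dogs, flattened to shape (N,).
picks = np.ravel(np.logical_or(y == CAT, y == DOG))

# Relabel: cat -> 0, dog -> 1.
binary = np.array(y[picks] == DOG, dtype=int).ravel()

# One-hot encode by indexing rows of a 2x2 identity matrix.
one_hot = np.eye(2, dtype="float32")[binary]

print(x[picks])  # [0 2 4 5]
print(binary)    # [0 1 1 0]
print(one_hot)
```

The same mask must be applied to the images and to the labels so that they stay aligned, which is why the original code computes `train_picks` once and reuses it.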

(2) Build the network.

model = Sequential()
model.add(Conv2D(4, kernel_size=(3, 3),activation='relu',input_shape=input_shape))
model.add(Conv2D(8, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(16, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
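To get a feel for how small this network is, the layer output shapes and parameter counts can be worked out by hand. The sketch below assumes a channels-last input of 32x32x3 and the Keras defaults for Conv2D (stride 1, 'valid' padding); the helper functions `conv2d` and `dense` are written for this illustration and are not Keras APIs.

```python
def conv2d(shape, filters, k=3):
    """Output shape and parameter count of a k x k 'valid' convolution."""
    h, w, c = shape
    out = (h - k + 1, w - k + 1, filters)
    params = k * k * c * filters + filters  # weights + biases
    return out, params

def dense(n_in, units):
    """Output size and parameter count of a fully connected layer."""
    return units, n_in * units + units

shape = (32, 32, 3)
total = 0
shape, p = conv2d(shape, 4); total += p            # (30, 30, 4), 112 params
shape, p = conv2d(shape, 8); total += p            # (28, 28, 8), 296 params
shape = (shape[0] // 2, shape[1] // 2, shape[2])   # 2x2 max pooling
flat = shape[0] * shape[1] * shape[2]              # Flatten
units, p = dense(flat, 16); total += p             # 25104 params
units, p = dense(units, 2); total += p             # 34 params
print(shape, flat, total)  # (14, 14, 8) 1568 25546
```

Dropout layers add no parameters, so under these assumptions the model has about 25.5 thousand trainable parameters, almost all of them in the first Dense layer.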

(3) Train the model with data augmentation.

augmentation = True

if augmentation:
    datagen = ImageDataGenerator(
            rotation_range=0,
            width_shift_range=0,
            height_shift_range=0,
            shear_range=0,
            zoom_range=0,
            horizontal_flip=True,
            fill_mode='nearest',
#             preprocessing_function = contrast_stretching,
#             preprocessing_function = HE,
            preprocessing_function = AHE)

    datagen.fit(x_train)
    
    print("Running augmented training now, with augmentation")
    # fit_generator is deprecated since TensorFlow 2.1, where model.fit accepts the generator directly
    history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=x_train.shape[0] // batch_size,
                    epochs=epochs,
                    validation_data=(x_test, y_test))
else:
    print("Running regular training, no augmentation")
    history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))
print(history.history.keys())
# Running augmented training now, with augmentation
# Epoch 1/10
# 39/39 [==============================] - 31s 785ms/step - loss: 0.7410 - accuracy: 0.4965 - val_loss: 0.6971 - val_accuracy: 0.4925
# Epoch 2/10
# 39/39 [==============================] - 31s 804ms/step - loss: 0.7357 - accuracy: 0.5021 - val_loss: 0.6969 - val_accuracy: 0.4970
# Epoch 3/10
# 39/39 [==============================] - 31s 807ms/step - loss: 0.7324 - accuracy: 0.5018 - val_loss: 0.6967 - val_accuracy: 0.4975
# Epoch 4/10
# 39/39 [==============================] - 31s 803ms/step - loss: 0.7334 - accuracy: 0.4933 - val_loss: 0.6965 - val_accuracy: 0.4950
# Epoch 5/10
# 39/39 [==============================] - 32s 811ms/step - loss: 0.7319 - accuracy: 0.5026 - val_loss: 0.6963 - val_accuracy: 0.4955
# Epoch 6/10
# 39/39 [==============================] - 31s 799ms/step - loss: 0.7309 - accuracy: 0.4995 - val_loss: 0.6962 - val_accuracy: 0.4970
# Epoch 7/10
# 39/39 [==============================] - 31s 807ms/step - loss: 0.7331 - accuracy: 0.4944 - val_loss: 0.6961 - val_accuracy: 0.4970
# Epoch 8/10
# 39/39 [==============================] - 32s 809ms/step - loss: 0.7357 - accuracy: 0.4845 - val_loss: 0.6960 - val_accuracy: 0.4975
# Epoch 9/10
# 39/39 [==============================] - 31s 807ms/step - loss: 0.7284 - accuracy: 0.5025 - val_loss: 0.6959 - val_accuracy: 0.4955
# Epoch 10/10
# 39/39 [==============================] - 32s 814ms/step - loss: 0.7226 - accuracy: 0.5028 - val_loss: 0.6958 - val_accuracy: 0.4945
# dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
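In the training call above, `validation_data` is passed as plain arrays, so the generator's `preprocessing_function` is not applied to the validation images. If the same preprocessing is wanted on both sides, one possible approach is to run the function over the validation array beforehand. The sketch below is only an illustration: `apply_per_image` is a hypothetical helper, `stretch` is a NumPy stand-in for the skimage-based functions defined earlier, and the random array stands in for `x_test`.

```python
import numpy as np

def apply_per_image(images, fn):
    """Apply fn to each (H, W, C) image in a batch and restack the results."""
    return np.stack([fn(img) for img in images]).astype("float32")

def stretch(img):
    """Percentile-based contrast stretching (stand-in preprocessing)."""
    lo, hi = np.percentile(img, (2, 98))
    return np.clip((img - lo) / max(hi - lo, 1e-8), 0.0, 1.0)

rng = np.random.default_rng(0)
x_val = rng.random((4, 32, 32, 3)).astype("float32")  # placeholder for x_test
x_val_pre = apply_per_image(x_val, stretch)
print(x_val_pre.shape, x_val_pre.dtype)  # (4, 32, 32, 3) float32
```

The preprocessed array could then be passed as the validation images, so that training and validation inputs follow the same intensity distribution.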

(4) Visualize the training results.

plt.plot(history.epoch,history.history['val_accuracy'],'-o',label='validation')
plt.plot(history.epoch,history.history['accuracy'],'-o',label='training')

plt.legend(loc=0)
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.grid(True)
plt.show()
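Alongside the plot, the same history can be summarized numerically. The sketch below copies the accuracy values from the training log shown earlier into plain lists (in a live session they would come from `history.history`) and reports the best validation epoch.

```python
# Values copied from the training log above, epochs 1 to 10.
val_accuracy = [0.4925, 0.4970, 0.4975, 0.4950, 0.4955,
                0.4970, 0.4970, 0.4975, 0.4955, 0.4945]
train_accuracy = [0.4965, 0.5021, 0.5018, 0.4933, 0.5026,
                  0.4995, 0.4944, 0.4845, 0.5025, 0.5028]

# Index of the first maximum, converted to a 1-based epoch number.
best_epoch = max(range(len(val_accuracy)), key=val_accuracy.__getitem__) + 1
print(best_epoch, val_accuracy[best_epoch - 1])  # 3 0.4975

# Largest distance of any recorded accuracy from 0.5.
spread = max(abs(a - 0.5) for a in val_accuracy + train_accuracy)
print(round(spread, 4))  # 0.0155
```

In this run every recorded accuracy stays within about two percentage points of 0.5, which is the level expected from guessing on two equally sized classes, so the curves in the plot are essentially flat.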

This content was contributed by a community member. When reposting, please credit the source: https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/148860