当前位置:   article > 正文

Keras CIFAR-10图像分类 ResNet 篇_keras入门(五)搭建resnet对cifar-10进行图像分类

keras入门(五)搭建resnet对cifar-10进行图像分类

Keras CIFAR-10图像分类 ResNet 篇

除了用pytorch可以进行图像分类之外,我们也可以利用tensorflow来进行图像分类,其中利用tensorflow的后端keras更是尤为简单,接下来我们就利用keras对CIFAR10数据集进行分类。

keras介绍

在这里插入图片描述

keras是python深度学习中常用的一个学习框架,它有着极其强大的功能,基本能用于常用的各个模型。

keras具有的特性

1、相同的代码可以在cpu和gpu上切换;
2、在模型定义上,可以用函数式API,也可以用Sequential类;
3、支持任意网络架构,如多输入多输出;
4、能够使用卷积网络、循环网络及其组合。

keras与后端引擎

Keras 是一个模型级的库,在开发中只用做高层次的操作,不处于张量计算,微积分计算等低级操作。但是keras最终处理数据时数据都是以张量形式呈现,不处理张量操作的keras是如何解决张量运算的呢?

keras依赖于专门处理张量的后端引擎,关于张量运算方面都是通过后端引擎完成的。这也就是为什么下载keras时需要下载TensorFlow 或者Theano的原因。而TensorFlow 、Theano、以及CNTK都属于处理数值张量的后端引擎。
在这里插入图片描述

keras设计原则

  • 用户友好:Keras是为人类而不是天顶星人设计的API。用户的使用体验始终是我们考虑的首要和中心内容。Keras遵循减少认知困难的最佳实践:Keras提供一致而简洁的API, 能够极大减少一般应用下用户的工作量,同时,Keras提供清晰和具有实践意义的bug反馈。
  • 模块性:模型可理解为一个层的序列或数据的运算图,完全可配置的模块可以用最少的代价自由组合在一起。具体而言,网络层、损失函数、优化器、初始化策略、激活函数、正则化方法都是独立的模块,你可以使用它们来构建自己的模型。
  • 易扩展性:添加新模块超级容易,只需要仿照现有的模块编写新的类或函数即可。创建新模块的便利性使得Keras更适合于先进的研究工作。
  • 与Python协作:Keras没有单独的模型配置文件类型(作为对比,caffe有),模型由python代码描述,使其更紧凑和更易debug,并提供了扩展的便利性。

在这里插入图片描述

安装keras

安装也是很简单的,我们直接安装keras即可,如果需要tensorflow,就还需要安装tensorflow

pip install keras
  • 1

导入库

import keras
from keras.models import Sequential
from keras.datasets import cifar10
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense, Activation
from keras.optimizers import adam_v2
from keras.utils.vis_utils import plot_model
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import matplotlib
matplotlib.style.use('ggplot')
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

控制GPU显存(可选)

这个是tensorflow来控制选择的GPU,因为存在多卡的时候可以指定GPU,其次还可以控制GPU的显存

这段语句就是动态显存,动态分配显存

config.gpu_options.allow_growth = True
  • 1

这段语句就是说明,我们使用的最大显存不能超过50%

config.gpu_options.per_process_gpu_memory_fraction = 0.5
  • 1
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 忽略低级别的警告
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# The GPU id to use, usually either "0" or "1"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
config = tf.compat.v1.ConfigProto()
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

加载 CIFAR-10 数据集

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( arplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。

与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:

  • CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
  • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
  • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。

在这里插入图片描述

num_classes = 10  # 有多少个类别
  • 1
(x_train, y_train), (x_val, y_val) = cifar10.load_data()
  • 1
print("训练集的维度大小:",x_train.shape)
print("验证集的维度大小:",x_val.shape)
  • 1
  • 2
训练集的维度大小: (50000, 32, 32, 3)
验证集的维度大小: (10000, 32, 32, 3)
  • 1
  • 2

可视化数据

class_names = ['airplane','automobile','bird','cat','deer',
               'dog','frog','horse','ship','truck']
fig = plt.figure(figsize=(20,5))
for i in range(num_classes):
    ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
    idx = np.where(y_train[:]==i)[0] # 取得类别样本
    features_idx = x_train[idx,::] # 取得图片
    img_num = np.random.randint(features_idx.shape[0]) # 随机挑选图片
    im = features_idx[img_num,::]
    ax.set_title(class_names[i])
    plt.imshow(im)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

在这里插入图片描述

数据预处理

x_train = x_train.astype('float32')/255
x_val = x_val.astype('float32')/255
  • 1
  • 2
# 将向量转化为二分类矩阵,也就是one-hot编码
y_train = to_categorical(y_train, num_classes)
y_val = to_categorical(y_val, num_classes)
  • 1
  • 2
  • 3
output_dir = './output'  # 输出目录
if os.path.exists(output_dir) is False:
    os.mkdir(output_dir)
#     shutil.rmtree(output_dir)
#     print('%s文件夹已存在,但是没关系,我们删掉了' % output_dir)
#     os.mkdir(output_dir)
    print('%s已创建' % output_dir)
print('%s文件夹已存在' % output_dir)
model_name = 'resnet'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
./output已创建
./output文件夹已存在
  • 1
  • 2

ResNet网络

当大家还在惊叹 GoogLeNet 的 inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。

如果想详细了解并查看论文,可以看我的另一篇博客【论文泛读】 ResNet:深度残差网络

下图是ResNet18层模型的结构简图

在这里插入图片描述

还有ResNet-34模型
在这里插入图片描述

在ResNet网络中有如下几个亮点:

(1)提出residual结构(残差结构),并搭建超深的网络结构(突破1000层)

(2)使用Batch Normalization加速训练(丢弃dropout)

在ResNet网络提出之前,传统的卷积神经网络都是通过将一系列卷积层与下采样层进行堆叠得到的。但是当堆叠到一定网络深度时,就会出现两个问题。

(1)梯度消失或梯度爆炸。

(2)退化问题(degradation problem)。

残差结构

在ResNet论文中说通过数据的预处理以及在网络中使用BN(Batch Normalization)层能够解决梯度消失或者梯度爆炸问题,residual结构(残差结构)来减轻退化问题。此时拟合目标就变为F(x),F(x)就是残差
在这里插入图片描述

这里有一个点是很重要的,对于我们的第二个layer,它是没有relu激活函数的,他需要与x相加最后再进行激活函数relu

ResNet18/34 的Residual结构

我们先对ResNet18/34的残差结构进行一个分析。如下图所示,该残差结构的主分支是由两层3x3的卷积层组成,而残差结构右侧的连接线是shortcut分支也称捷径分支(注意为了让主分支上的输出矩阵能够与我们捷径分支上的输出矩阵进行相加,必须保证这两个输出特征矩阵有相同的shape)。我们会发现有一些虚线结构,论文中表述为用1x1的卷积进行降维,下图给出了详细的残差结构。

在这里插入图片描述

ResNet50/101/152的Bottleneck结构

接着我们再来分析下针对ResNet50/101/152的残差结构,如下图所示。在该残差结构当中,主分支使用了三个卷积层,第一个是1x1的卷积层用来压缩channel维度,第二个是3x3的卷积层,第三个是1x1的卷积层用来还原channel维度(注意主分支上第一层卷积层和第二次卷积层所使用的卷积核个数是相同的,第三次是第一层的4倍),这种又叫做bottleneck模型
在这里插入图片描述

ResNet网络结构配置

这是在ImageNet数据集中更深的残差网络的模型,这里面给出了残差结构给出了主分支上卷积核的大小与卷积核个数,表中的xN表示将该残差结构重复N次。

在这里插入图片描述

对于我们ResNet18/34/50/101/152,表中conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层残差结构都是虚线残差结构。因为这一系列残差结构的第一层都有调整输入特征矩阵shape的使命(将特征矩阵的高和宽缩减为原来的一半,将深度channel调整成下一层残差结构所需要的channel)

  • ResNet-50:我们用3层瓶颈块替换34层网络中的每一个2层块,得到了一个50层ResNe。我们使用1x1卷积核来增加维度。该模型有38亿FLOP
  • ResNet-101/152:我们通过使用更多的3层瓶颈块来构建101层和152层ResNets。值得注意的是,尽管深度显著增加,但152层ResNet(113亿FLOP)仍然比VGG-16/19网络(153/196亿FLOP)具有更低的复杂度。
input_shape = (32,32,3)
  • 1
from keras.layers import BatchNormalization, AveragePooling2D, Input
from keras.models import Model
from keras.regularizers import l2
from keras import layers
def conv2d_bn(x, filters, kernel_size, weight_decay=.0, strides=(1, 1)):
    layer = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same',
                   use_bias=False,
                   kernel_regularizer=l2(weight_decay)
                   )(x)
    layer = BatchNormalization()(layer)
    return layer


def conv2d_bn_relu(x, filters, kernel_size, weight_decay=.0, strides=(1, 1)):
    layer = conv2d_bn(x, filters, kernel_size, weight_decay, strides)
    layer = Activation('relu')(layer)
    return layer

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
def ResidualBlock(x, filters, kernel_size, weight_decay, downsample=True):
    if downsample:
        # residual_x = conv2d_bn_relu(x, filters, kernel_size=1, strides=2)
        residual_x = conv2d_bn(x, filters, kernel_size=1, strides=2)
        stride = 2
    else:
        residual_x = x
        stride = 1
    residual = conv2d_bn_relu(x,
                              filters=filters,
                              kernel_size=kernel_size,
                              weight_decay=weight_decay,
                              strides=stride,
                              )
    residual = conv2d_bn(residual,
                         filters=filters,
                         kernel_size=kernel_size,
                         weight_decay=weight_decay,
                         strides=1,
                         )
    out = layers.add([residual_x, residual])
    out = Activation('relu')(out)
    return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
def ResNet18(classes, input_shape, weight_decay=1e-4):
    input = Input(shape=input_shape)
    x = input
    # x = conv2d_bn_relu(x, filters=64, kernel_size=(7, 7), weight_decay=weight_decay, strides=(2, 2))
    # x = MaxPool2D(pool_size=(3, 3), strides=(2, 2),  padding='same')(x)
    x = conv2d_bn_relu(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, strides=(1, 1))

    # # conv 2
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # # conv 3
    x = ResidualBlock(x, filters=128, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    x = ResidualBlock(x, filters=128, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # # conv 4
    x = ResidualBlock(x, filters=256, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    x = ResidualBlock(x, filters=256, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # # conv 5
    x = ResidualBlock(x, filters=512, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    x = ResidualBlock(x, filters=512, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    x = AveragePooling2D(pool_size=(4, 4), padding='valid')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax')(x)
    model = Model(input, x, name='ResNet18')
    return model


def ResNetForCIFAR10(classes, name, input_shape, block_layers_num, weight_decay):
    input = Input(shape=input_shape)
    x = input
    x = conv2d_bn_relu(x, filters=16, kernel_size=(3, 3), weight_decay=weight_decay, strides=(1, 1))

    # # conv 2
    for i in range(block_layers_num):
        x = ResidualBlock(x, filters=16, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # # conv 3
    x = ResidualBlock(x, filters=32, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    for i in range(block_layers_num - 1):
        x = ResidualBlock(x, filters=32, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # # conv 4
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    for i in range(block_layers_num - 1):
        x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    x = AveragePooling2D(pool_size=(8, 8), padding='valid')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax')(x)
    model = Model(input, x, name=name)
    return model


def ResNet20ForCIFAR10(classes, input_shape, weight_decay):
    return ResNetForCIFAR10(classes, 'resnet20', input_shape, 3, weight_decay)


def ResNet32ForCIFAR10(classes, input_shape, weight_decay):
    return ResNetForCIFAR10(classes, 'resnet32', input_shape, 5, weight_decay)


def ResNet56ForCIFAR10(classes, input_shape, weight_decay):
    return ResNetForCIFAR10(classes, 'resnet56', input_shape, 9, weight_decay)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
input_shape = (32,32,3)
  • 1
weight_decay = 1e-4
model = ResNet32ForCIFAR10(input_shape=(32, 32, 3), classes=num_classes, weight_decay=weight_decay)
model.summary()
  • 1
  • 2
  • 3
Model: "resnet32"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 32, 32, 16)   432         ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 16)  64          ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 32, 32, 16)   0           ['batch_normalization[0][0]']    
                                                                                                  
 conv2d_1 (Conv2D)              (None, 32, 32, 16)   2304        ['activation[0][0]']             
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_1[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_1 (Activation)      (None, 32, 32, 16)   0           ['batch_normalization_1[0][0]']  
                                                                                                  
 conv2d_2 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_1[0][0]']           
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_2[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add (Add)                      (None, 32, 32, 16)   0           ['activation[0][0]',             
                                                                  'batch_normalization_2[0][0]']  
                                                                                                  
 activation_2 (Activation)      (None, 32, 32, 16)   0           ['add[0][0]']                    
                                                                                                  
 conv2d_3 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_2[0][0]']           
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_3[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_3 (Activation)      (None, 32, 32, 16)   0           ['batch_normalization_3[0][0]']  
                                                                                                  
 conv2d_4 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_3[0][0]']           
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_4[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add_1 (Add)                    (None, 32, 32, 16)   0           ['activation_2[0][0]',           
                                                                  'batch_normalization_4[0][0]']  
                                                                                                  
 activation_4 (Activation)      (None, 32, 32, 16)   0           ['add_1[0][0]']                  
                                                                                                  
 conv2d_5 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_4[0][0]']           
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_5[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_5 (Activation)      (None, 32, 32, 16)   0           ['batch_normalization_5[0][0]']  
                                                                                                  
 conv2d_6 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_5[0][0]']           
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_6[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add_2 (Add)                    (None, 32, 32, 16)   0           ['activation_4[0][0]',           
                                                                  'batch_normalization_6[0][0]']  
                                                                                                  
 activation_6 (Activation)      (None, 32, 32, 16)   0           ['add_2[0][0]']                  
                                                                                                  
 conv2d_7 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_6[0][0]']           
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_7[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_7 (Activation)      (None, 32, 32, 16)   0           ['batch_normalization_7[0][0]']  
                                                                                                  
 conv2d_8 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_7[0][0]']           
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_8[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add_3 (Add)                    (None, 32, 32, 16)   0           ['activation_6[0][0]',           
                                                                  'batch_normalization_8[0][0]']  
                                                                                                  
 activation_8 (Activation)      (None, 32, 32, 16)   0           ['add_3[0][0]']                  
                                                                                                  
 conv2d_9 (Conv2D)              (None, 32, 32, 16)   2304        ['activation_8[0][0]']           
                                                                                                  
 batch_normalization_9 (BatchNo  (None, 32, 32, 16)  64          ['conv2d_9[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_9 (Activation)      (None, 32, 32, 16)   0           ['batch_normalization_9[0][0]']  
                                                                                                  
 conv2d_10 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_9[0][0]']           
                                                                                                  
 batch_normalization_10 (BatchN  (None, 32, 32, 16)  64          ['conv2d_10[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_4 (Add)                    (None, 32, 32, 16)   0           ['activation_8[0][0]',           
                                                                  'batch_normalization_10[0][0]'] 
                                                                                                  
 activation_10 (Activation)     (None, 32, 32, 16)   0           ['add_4[0][0]']                  
                                                                                                  
 conv2d_12 (Conv2D)             (None, 16, 16, 32)   4608        ['activation_10[0][0]']          
                                                                                                  
 batch_normalization_12 (BatchN  (None, 16, 16, 32)  128         ['conv2d_12[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_11 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_12[0][0]'] 
                                                                                                  
 conv2d_11 (Conv2D)             (None, 16, 16, 32)   512         ['activation_10[0][0]']          
                                                                                                  
 conv2d_13 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_11[0][0]']          
                                                                                                  
 batch_normalization_11 (BatchN  (None, 16, 16, 32)  128         ['conv2d_11[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_13 (BatchN  (None, 16, 16, 32)  128         ['conv2d_13[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_5 (Add)                    (None, 16, 16, 32)   0           ['batch_normalization_11[0][0]', 
                                                                  'batch_normalization_13[0][0]'] 
                                                                                                  
 activation_12 (Activation)     (None, 16, 16, 32)   0           ['add_5[0][0]']                  
                                                                                                  
 conv2d_14 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_12[0][0]']          
                                                                                                  
 batch_normalization_14 (BatchN  (None, 16, 16, 32)  128         ['conv2d_14[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_13 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_14[0][0]'] 
                                                                                                  
 conv2d_15 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_13[0][0]']          
                                                                                                  
 batch_normalization_15 (BatchN  (None, 16, 16, 32)  128         ['conv2d_15[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_6 (Add)                    (None, 16, 16, 32)   0           ['activation_12[0][0]',          
                                                                  'batch_normalization_15[0][0]'] 
                                                                                                  
 activation_14 (Activation)     (None, 16, 16, 32)   0           ['add_6[0][0]']                  
                                                                                                  
 conv2d_16 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_14[0][0]']          
                                                                                                  
 batch_normalization_16 (BatchN  (None, 16, 16, 32)  128         ['conv2d_16[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_15 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_16[0][0]'] 
                                                                                                  
 conv2d_17 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_15[0][0]']          
                                                                                                  
 batch_normalization_17 (BatchN  (None, 16, 16, 32)  128         ['conv2d_17[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_7 (Add)                    (None, 16, 16, 32)   0           ['activation_14[0][0]',          
                                                                  'batch_normalization_17[0][0]'] 
                                                                                                  
 activation_16 (Activation)     (None, 16, 16, 32)   0           ['add_7[0][0]']                  
                                                                                                  
 conv2d_18 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_16[0][0]']          
                                                                                                  
 batch_normalization_18 (BatchN  (None, 16, 16, 32)  128         ['conv2d_18[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_17 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_18[0][0]'] 
                                                                                                  
 conv2d_19 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_17[0][0]']          
                                                                                                  
 batch_normalization_19 (BatchN  (None, 16, 16, 32)  128         ['conv2d_19[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_8 (Add)                    (None, 16, 16, 32)   0           ['activation_16[0][0]',          
                                                                  'batch_normalization_19[0][0]'] 
                                                                                                  
 activation_18 (Activation)     (None, 16, 16, 32)   0           ['add_8[0][0]']                  
                                                                                                  
 conv2d_20 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_18[0][0]']          
                                                                                                  
 batch_normalization_20 (BatchN  (None, 16, 16, 32)  128         ['conv2d_20[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_19 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_20[0][0]'] 
                                                                                                  
 conv2d_21 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_19[0][0]']          
                                                                                                  
 batch_normalization_21 (BatchN  (None, 16, 16, 32)  128         ['conv2d_21[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_9 (Add)                    (None, 16, 16, 32)   0           ['activation_18[0][0]',          
                                                                  'batch_normalization_21[0][0]'] 
                                                                                                  
 activation_20 (Activation)     (None, 16, 16, 32)   0           ['add_9[0][0]']                  
                                                                                                  
 conv2d_23 (Conv2D)             (None, 8, 8, 64)     18432       ['activation_20[0][0]']          
                                                                                                  
 batch_normalization_23 (BatchN  (None, 8, 8, 64)    256         ['conv2d_23[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_21 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_23[0][0]'] 
                                                                                                  
 conv2d_22 (Conv2D)             (None, 8, 8, 64)     2048        ['activation_20[0][0]']          
                                                                                                  
 conv2d_24 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_21[0][0]']          
                                                                                                  
 batch_normalization_22 (BatchN  (None, 8, 8, 64)    256         ['conv2d_22[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_24 (BatchN  (None, 8, 8, 64)    256         ['conv2d_24[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_10 (Add)                   (None, 8, 8, 64)     0           ['batch_normalization_22[0][0]', 
                                                                  'batch_normalization_24[0][0]'] 
                                                                                                  
 activation_22 (Activation)     (None, 8, 8, 64)     0           ['add_10[0][0]']                 
                                                                                                  
 conv2d_25 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_22[0][0]']          
                                                                                                  
 batch_normalization_25 (BatchN  (None, 8, 8, 64)    256         ['conv2d_25[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_23 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_25[0][0]'] 
                                                                                                  
 conv2d_26 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_23[0][0]']          
                                                                                                  
 batch_normalization_26 (BatchN  (None, 8, 8, 64)    256         ['conv2d_26[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_11 (Add)                   (None, 8, 8, 64)     0           ['activation_22[0][0]',          
                                                                  'batch_normalization_26[0][0]'] 
                                                                                                  
 activation_24 (Activation)     (None, 8, 8, 64)     0           ['add_11[0][0]']                 
                                                                                                  
 conv2d_27 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_24[0][0]']          
                                                                                                  
 batch_normalization_27 (BatchN  (None, 8, 8, 64)    256         ['conv2d_27[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_25 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_27[0][0]'] 
                                                                                                  
 conv2d_28 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_25[0][0]']          
                                                                                                  
 batch_normalization_28 (BatchN  (None, 8, 8, 64)    256         ['conv2d_28[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_12 (Add)                   (None, 8, 8, 64)     0           ['activation_24[0][0]',          
                                                                  'batch_normalization_28[0][0]'] 
                                                                                                  
 activation_26 (Activation)     (None, 8, 8, 64)     0           ['add_12[0][0]']                 
                                                                                                  
 conv2d_29 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_26[0][0]']          
                                                                                                  
 batch_normalization_29 (BatchN  (None, 8, 8, 64)    256         ['conv2d_29[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_27 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_29[0][0]'] 
                                                                                                  
 conv2d_30 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_27[0][0]']          
                                                                                                  
 batch_normalization_30 (BatchN  (None, 8, 8, 64)    256         ['conv2d_30[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_13 (Add)                   (None, 8, 8, 64)     0           ['activation_26[0][0]',          
                                                                  'batch_normalization_30[0][0]'] 
                                                                                                  
 activation_28 (Activation)     (None, 8, 8, 64)     0           ['add_13[0][0]']                 
                                                                                                  
 conv2d_31 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_28[0][0]']          
                                                                                                  
 batch_normalization_31 (BatchN  (None, 8, 8, 64)    256         ['conv2d_31[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_29 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_31[0][0]'] 
                                                                                                  
 conv2d_32 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_29[0][0]']          
                                                                                                  
 batch_normalization_32 (BatchN  (None, 8, 8, 64)    256         ['conv2d_32[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_14 (Add)                   (None, 8, 8, 64)     0           ['activation_28[0][0]',          
                                                                  'batch_normalization_32[0][0]'] 
                                                                                                  
 activation_30 (Activation)     (None, 8, 8, 64)     0           ['add_14[0][0]']                 
                                                                                                  
 average_pooling2d (AveragePool  (None, 1, 1, 64)    0           ['activation_30[0][0]']          
 ing2D)                                                                                           
                                                                                                  
 flatten (Flatten)              (None, 64)           0           ['average_pooling2d[0][0]']      
                                                                                                  
 dense (Dense)                  (None, 10)           650         ['flatten[0][0]']                
                                                                                                  
==================================================================================================
Total params: 469,370
Trainable params: 466,906
Non-trainable params: 2,464
__________________________________________________________________________________________________
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
model_img = output_dir + '/cifar10_%s.png'%(model_name)  # 模型结构图保存路径
plot_model(model, to_file=model_img, show_shapes=True) # 模型结构保存为一张图片
print('%s已保存' % model_img)
  • 1
  • 2
  • 3
./output/cifar10_resnet.png已保存
  • 1

在这里插入图片描述

开始训练模型

首先我们可以设置我们的迭代次数和batch_size

epochs = 20  # 迭代次数
batch_size = 128  # 批大小
  • 1
  • 2

这一部分是设置在训练的时候的一些参数

  • 首先保存最好的模型,先定义我们的model path
  • 设置save_best_only=True,也就是代表只保存一遍
  • monitor='val_loss’代表的是监视val_loss,着重观察val_loss,只选取最小的val_loss的模型进行保存,当然这个我们也可以换成val_acc也是可以的
checkpoint = ModelCheckpoint(output_dir + '/best_%s_simple.h5'%model_name,  # model filename
                             monitor='val_loss', # quantity to monitor
                             verbose=0, # verbosity - 0 or 1
                             save_best_only= True, # The latest best model will not be overwritten
                             mode='auto') # The decision to overwrite model is made 
                                          # automatically depending on the quantity to monitor 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

接下来我们就可以定义我们的优化器和损失函数了,keras很简单,并且定义我们需要计算的metrics为准确率即可

adam = adam_v2.Adam(lr = 0.001)
model.compile(loss = 'categorical_crossentropy', optimizer = adam, metrics = ['accuracy'])
  • 1
  • 2

最后我们使用内置的fit函数,并且加上我们所需要的超参数,就可以完成我们的训练了。

history = model.fit(x_train,y_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_data=(x_val,y_val),
                    shuffle=True,
                    callbacks=[checkpoint])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
Epoch 1/20
391/391 [==============================] - 25s 48ms/step - loss: 1.5136 - accuracy: 0.4872 - val_loss: 2.3308 - val_accuracy: 0.3384
Epoch 2/20
391/391 [==============================] - 18s 46ms/step - loss: 1.0520 - accuracy: 0.6656 - val_loss: 1.3796 - val_accuracy: 0.5948
Epoch 3/20
391/391 [==============================] - 18s 46ms/step - loss: 0.8567 - accuracy: 0.7406 - val_loss: 1.0479 - val_accuracy: 0.6941
Epoch 4/20
391/391 [==============================] - 17s 45ms/step - loss: 0.7438 - accuracy: 0.7839 - val_loss: 1.4804 - val_accuracy: 0.5939
Epoch 5/20
391/391 [==============================] - 18s 47ms/step - loss: 0.6655 - accuracy: 0.8118 - val_loss: 1.0215 - val_accuracy: 0.7029
Epoch 6/20
391/391 [==============================] - 18s 46ms/step - loss: 0.6030 - accuracy: 0.8339 - val_loss: 1.3131 - val_accuracy: 0.6538
Epoch 7/20
391/391 [==============================] - 18s 47ms/step - loss: 0.5612 - accuracy: 0.8501 - val_loss: 0.9786 - val_accuracy: 0.7368
Epoch 8/20
391/391 [==============================] - 18s 45ms/step - loss: 0.5128 - accuracy: 0.8705 - val_loss: 1.2962 - val_accuracy: 0.6735
Epoch 9/20
391/391 [==============================] - 18s 45ms/step - loss: 0.4838 - accuracy: 0.8817 - val_loss: 1.5207 - val_accuracy: 0.6318
Epoch 10/20
391/391 [==============================] - 18s 47ms/step - loss: 0.4482 - accuracy: 0.8953 - val_loss: 0.9014 - val_accuracy: 0.7674
Epoch 11/20
391/391 [==============================] - 18s 45ms/step - loss: 0.4241 - accuracy: 0.9053 - val_loss: 0.9986 - val_accuracy: 0.7666
Epoch 12/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3995 - accuracy: 0.9154 - val_loss: 1.0747 - val_accuracy: 0.7399
Epoch 13/20
391/391 [==============================] - 17s 45ms/step - loss: 0.3755 - accuracy: 0.9243 - val_loss: 1.4094 - val_accuracy: 0.7106
Epoch 14/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3634 - accuracy: 0.9308 - val_loss: 1.3048 - val_accuracy: 0.7187
Epoch 15/20
391/391 [==============================] - 18s 46ms/step - loss: 0.3418 - accuracy: 0.9394 - val_loss: 1.1310 - val_accuracy: 0.7498
Epoch 16/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3392 - accuracy: 0.9413 - val_loss: 1.1636 - val_accuracy: 0.7490
Epoch 17/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3301 - accuracy: 0.9456 - val_loss: 1.6518 - val_accuracy: 0.6921
Epoch 18/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3227 - accuracy: 0.9495 - val_loss: 1.2451 - val_accuracy: 0.7381
Epoch 19/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3123 - accuracy: 0.9547 - val_loss: 1.2867 - val_accuracy: 0.7464
Epoch 20/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3129 - accuracy: 0.9546 - val_loss: 1.6354 - val_accuracy: 0.6954
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

可视化准确率、损失函数

def plot_model_history(model_history):
    fig, axs = plt.subplots(1,2,figsize=(15,5))
    # summarize history for accuracy
    axs[0].plot(range(1,len(model_history.history['accuracy'])+1),model_history.history['accuracy'])
    axs[0].plot(range(1,len(model_history.history['val_accuracy'])+1),model_history.history['val_accuracy'])
    axs[0].set_title('Model Accuracy')
    axs[0].set_ylabel('Accuracy')
    axs[0].set_xlabel('Epoch')
   
    axs[0].legend(['train', 'val'], loc='best')
    # summarize history for loss
    axs[1].plot(range(1,len(model_history.history['loss'])+1),model_history.history['loss'])
    axs[1].plot(range(1,len(model_history.history['val_loss'])+1),model_history.history['val_loss'])
    axs[1].set_title('Model Loss')
    axs[1].set_ylabel('Loss')
    axs[1].set_xlabel('Epoch')
    axs[1].legend(['train', 'val'], loc='best')
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
plot_model_history(history)
  • 1

在这里插入图片描述

保存模型

model_path = output_dir + '/keras_cifar10_%s_model.h5'%model_name
model.save(model_path)
print('%s已保存' % model_path)
  • 1
  • 2
  • 3
./output/keras_cifar10_resnet_model.h5已保存
  • 1

预测结果

# 取验证集里面的图片拿来预测看看
name = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
        5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
n = 20  # 取多少张图片

x_test = x_val[:n]
y_test = y_val[:n]

# 预测
y_predict = model.predict(x_test, batch_size=n)

# 绘制预测结果
plt.figure(figsize=(18, 3))  # 指定画布大小
for i in range(n):
    plt.subplot(2, 10, i + 1)
    plt.axis('off')  # 取消x,y轴坐标
    plt.imshow(x_test[i])  # 显示图片
    if y_test[i].argmax() == y_predict[i].argmax():
        # 预测正确,用绿色标题
        plt.title('%s,%s' % (name[y_test[i].argmax()], name[y_predict[i].argmax()]), color='green')
    else:
        # 预测错误,用红色标题
        plt.title('%s,%s' % (name[y_test[i].argmax()], name[y_predict[i].argmax()]), color='red')
predict_img = output_dir + '/predict_%s.png'%(model_name)
print('%s已保存' % predict_img)
plt.savefig(predict_img)  # 保存预测图片
plt.show()  # 显示画布
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
./output/predict_resnet.png已保存
  • 1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jdVm1j3p-1668299839132)(./img/Keras%20CIFAR-10%E5%88%86%E7%B1%BB%EF%BC%88ResNet%EF%BC%89_41_1.png)]

res = model.evaluate(x_test,y_test)
print('{:2f}%'.format(res[1]*100))
  • 1
  • 2
1/1 [==============================] - 0s 54ms/step - loss: 0.8094 - accuracy: 0.7500
75.000000%
  • 1
  • 2
res = model.evaluate(x_val,y_val)
print('{:2f}%'.format(res[1]*100))
  • 1
  • 2
313/313 [==============================] - 3s 11ms/step - loss: 1.1061 - accuracy: 0.7747
77.469999%
  • 1
  • 2

数据增强

除了用原图片进行训练之外,我们还有一种方式可以增加准确性,也就是数据增强。下面来介绍一下数据增强

数据增强(Data Augmentation)是一种通过让有限的数据产生更多的等价数据来人工扩展训练数据集的技术。它是克服训练数据不足的有效手段,目前在深度学习的各个领域中应用广泛。但是由于生成的数据与真实数据之间的差异,也不可避免地带来了噪声问题。深度神经网络在许多任务中表现良好,但这些网络通常需要大量数据才能避免过度拟合。遗憾的是,许多场景无法获得大量数据,数据增强技术的存在是为了解决这个问题,这是针对有限数据问题的解决方案。数据增强一套技术,可提高训练数据集的大小和质量,以便您可以使用它们来构建更好的深度学习模型。

计算视觉领域的数据增强

计算视觉领域的数据增强算法大致可以分为两类:第一类是基于基本图像处理技术的数据增强,第二个类别是基于深度学习的数据增强算法。

下面先介绍基于基本图像处理技术的数据增强方法:

  • 几何变换(Geometric Transformations):由于训练集与测试集合中可能存在潜在的位置偏差,使得模型在测试集中很难达到训练集中的效果,几何变换可以有效地克服训练数据中存在的位置偏差,而且易于实现,许多图像处理库都包含这个功能。
  • 颜色变换(Color Space):图片在输入计算机之前,通常会被编码为张量(高度×宽度×颜色通道),所以可以在色彩通道空间进行数据增强,比如将某种颜色通道关闭,或者改变亮度值。
  • 旋转 | 反射变换(Rotation/Reflection):选择一个角度,左右旋转图像,可以改变图像内容朝向。关于旋转角度需要慎重考虑,角度太大或者太小都不合适,适宜的角度是1度 到 20度。
  • 噪声注入(Noise Injection):从高斯分布中采样出的随机值矩阵加入到图像的RGB像素中,通过向图像添加噪点可以帮助CNN学习更强大的功能。
  • 内核过滤器(Kernel Filters):内核滤镜是在图像处理中一种非常流行的技术,比如锐化和模糊。将特定功能的内核滤镜与图像进行卷积操作,就可以得到增强后的数据。直观上,数据增强生成的图像可能会使得模型面对这种类型的图像具有更高的鲁棒性。
  • 混合图像(Mix):通过平均图像像素值将图像混合在一起是一种非常违反直觉的数据增强方法。对于人来说,混合图像生成的数据似乎没有意义。虽然这种方法缺乏可解释性,但是作为一种简单有效的数据增强算法,有一系列的工作进行相关的研究。Inoue在图像每个像素点混合像素值来混合图像,Summers和Dinneen又尝试以非线性的方法来混合图像,Takahashi和Matsubara通过随机图像裁剪和拼接来混合图像,以及后来的mixup方法均取得了不错的成果。
  • 随机擦除(Random Erasing):随机擦除是Zhong等人开发的数据增强技术。他们受到Dropout机制的启发,随机选取图片中的一部分,将这部分图片删除,这项技术可以提高模型在图片被部分遮挡的情况下性能,除此之外还可以确保网络关注整个图像,而不只是其中的一部分。
  • 缩放变换(Zoom):图像按照一定的比例进行放大和缩小并不改变图像中的内容,可以增加模型的泛化性能。
  • 移动(Translation):向左,向右,向上或向下移动图像可以避免数据中的位置偏差,比如在人脸识别数据集合中,如果所有图像都居中,使用这种数据增强方法可以避免可能出现的位置偏差导致的错误。
  • 翻转变换(Flipping):通常是关于水平或者竖直的轴进行图像翻转操作,这种扩充是最容易实现的扩充,并且已经证明对ImageNet数据集有效。
  • 裁剪(Cropping):如果输入数据集合的大小是变化的,裁剪可以作为数据预处理的一个手段,通过裁剪图像的中央色块,可以得到新的数据。在实际使用过程之中,这些数据增强算法不是只使用一种,而是使用一套数据增强策略,在AutoAugment这篇文章中,作者尝试让模型自动选择数据增强策略。

img

第二个类别是基于深度学习的数据增强算法:

  • 特征空间增强(Feature Space Augmentation):神经网络可以将图像这种高维向量映射为低维向量,之前讨论的所有图像数据增强方法都应用于输入空间中的图像。现在可以在特征空间进行数据增强操作,例如:SMOTE算法,它是一种流行的增强方法,通过将k个最近的邻居合并以形成新实例来缓解类不平衡问题。
  • 对抗生成(Adversarial Training):对抗攻击表明,图像表示的健壮性远不及预期的健壮性,Moosavi-Dezfooli等人充分证明了这一点。对抗生成可以改善学习的决策边界中的薄弱环节,提高模型的鲁棒性。
  • 基于GAN的数据增强(GAN-based Data Augmentation):使用 GAN 生成模型来生成更多的数据,可用作解决类别不平衡问题的过采样技术。
  • 神经风格转换(Neural Style Transfer):通过神经网络风格迁移来生成不同风格的数据,防止模型过拟合。
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import LearningRateScheduler
# fit data with data augmentation or not
data_augmentation = True

# def lr_scheduler(epoch):
#     return lr * (0.1 ** (epoch // 50))

# reduce_lr = LearningRateScheduler(lr_scheduler)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.93,
                              patience=1, min_lr=1e-6, verbose=1)
checkpoint = ModelCheckpoint(output_dir + '/best_%s_data_augmentation.h5'%model_name,  # model filename
                             monitor='val_loss', # quantity to monitor
                             verbose=0, # verbosity - 0 or 1
                             save_best_only= True, # The latest best model will not be overwritten
                             mode='auto') # The decision to overwrite model is made 
                                          # automatically depending on the quantity to monitor 
batch_size = 64
epochs = 30
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
from keras.optimizers import gradient_descent_v2
weight_decay = 1e-4
model = ResNet32ForCIFAR10(input_shape=(32, 32, 3), classes=num_classes, weight_decay=weight_decay)
adam = gradient_descent_v2.SGD(lr = 0.1, momentum=0.9, nesterov=True)
model.compile(loss = 'categorical_crossentropy', optimizer = adam, metrics = ['accuracy'])
model.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
Model: "resnet32"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_33 (Conv2D)             (None, 32, 32, 16)   432         ['input_2[0][0]']                
                                                                                                  
 batch_normalization_33 (BatchN  (None, 32, 32, 16)  64          ['conv2d_33[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_31 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_33[0][0]'] 
                                                                                                  
 conv2d_34 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_31[0][0]']          
                                                                                                  
 batch_normalization_34 (BatchN  (None, 32, 32, 16)  64          ['conv2d_34[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_32 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_34[0][0]'] 
                                                                                                  
 conv2d_35 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_32[0][0]']          
                                                                                                  
 batch_normalization_35 (BatchN  (None, 32, 32, 16)  64          ['conv2d_35[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_15 (Add)                   (None, 32, 32, 16)   0           ['activation_31[0][0]',          
                                                                  'batch_normalization_35[0][0]'] 
                                                                                                  
 activation_33 (Activation)     (None, 32, 32, 16)   0           ['add_15[0][0]']                 
                                                                                                  
 conv2d_36 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_33[0][0]']          
                                                                                                  
 batch_normalization_36 (BatchN  (None, 32, 32, 16)  64          ['conv2d_36[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_34 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_36[0][0]'] 
                                                                                                  
 conv2d_37 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_34[0][0]']          
                                                                                                  
 batch_normalization_37 (BatchN  (None, 32, 32, 16)  64          ['conv2d_37[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_16 (Add)                   (None, 32, 32, 16)   0           ['activation_33[0][0]',          
                                                                  'batch_normalization_37[0][0]'] 
                                                                                                  
 activation_35 (Activation)     (None, 32, 32, 16)   0           ['add_16[0][0]']                 
                                                                                                  
 conv2d_38 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_35[0][0]']          
                                                                                                  
 batch_normalization_38 (BatchN  (None, 32, 32, 16)  64          ['conv2d_38[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_36 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_38[0][0]'] 
                                                                                                  
 conv2d_39 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_36[0][0]']          
                                                                                                  
 batch_normalization_39 (BatchN  (None, 32, 32, 16)  64          ['conv2d_39[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_17 (Add)                   (None, 32, 32, 16)   0           ['activation_35[0][0]',          
                                                                  'batch_normalization_39[0][0]'] 
                                                                                                  
 activation_37 (Activation)     (None, 32, 32, 16)   0           ['add_17[0][0]']                 
                                                                                                  
 conv2d_40 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_37[0][0]']          
                                                                                                  
 batch_normalization_40 (BatchN  (None, 32, 32, 16)  64          ['conv2d_40[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_38 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_40[0][0]'] 
                                                                                                  
 conv2d_41 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_38[0][0]']          
                                                                                                  
 batch_normalization_41 (BatchN  (None, 32, 32, 16)  64          ['conv2d_41[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_18 (Add)                   (None, 32, 32, 16)   0           ['activation_37[0][0]',          
                                                                  'batch_normalization_41[0][0]'] 
                                                                                                  
 activation_39 (Activation)     (None, 32, 32, 16)   0           ['add_18[0][0]']                 
                                                                                                  
 conv2d_42 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_39[0][0]']          
                                                                                                  
 batch_normalization_42 (BatchN  (None, 32, 32, 16)  64          ['conv2d_42[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_40 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_42[0][0]'] 
                                                                                                  
 conv2d_43 (Conv2D)             (None, 32, 32, 16)   2304        ['activation_40[0][0]']          
                                                                                                  
 batch_normalization_43 (BatchN  (None, 32, 32, 16)  64          ['conv2d_43[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_19 (Add)                   (None, 32, 32, 16)   0           ['activation_39[0][0]',          
                                                                  'batch_normalization_43[0][0]'] 
                                                                                                  
 activation_41 (Activation)     (None, 32, 32, 16)   0           ['add_19[0][0]']                 
                                                                                                  
 conv2d_45 (Conv2D)             (None, 16, 16, 32)   4608        ['activation_41[0][0]']          
                                                                                                  
 batch_normalization_45 (BatchN  (None, 16, 16, 32)  128         ['conv2d_45[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_42 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_45[0][0]'] 
                                                                                                  
 conv2d_44 (Conv2D)             (None, 16, 16, 32)   512         ['activation_41[0][0]']          
                                                                                                  
 conv2d_46 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_42[0][0]']          
                                                                                                  
 batch_normalization_44 (BatchN  (None, 16, 16, 32)  128         ['conv2d_44[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_46 (BatchN  (None, 16, 16, 32)  128         ['conv2d_46[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_20 (Add)                   (None, 16, 16, 32)   0           ['batch_normalization_44[0][0]', 
                                                                  'batch_normalization_46[0][0]'] 
                                                                                                  
 activation_43 (Activation)     (None, 16, 16, 32)   0           ['add_20[0][0]']                 
                                                                                                  
 conv2d_47 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_43[0][0]']          
                                                                                                  
 batch_normalization_47 (BatchN  (None, 16, 16, 32)  128         ['conv2d_47[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_44 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_47[0][0]'] 
                                                                                                  
 conv2d_48 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_44[0][0]']          
                                                                                                  
 batch_normalization_48 (BatchN  (None, 16, 16, 32)  128         ['conv2d_48[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_21 (Add)                   (None, 16, 16, 32)   0           ['activation_43[0][0]',          
                                                                  'batch_normalization_48[0][0]'] 
                                                                                                  
 activation_45 (Activation)     (None, 16, 16, 32)   0           ['add_21[0][0]']                 
                                                                                                  
 conv2d_49 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_45[0][0]']          
                                                                                                  
 batch_normalization_49 (BatchN  (None, 16, 16, 32)  128         ['conv2d_49[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_46 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_49[0][0]'] 
                                                                                                  
 conv2d_50 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_46[0][0]']          
                                                                                                  
 batch_normalization_50 (BatchN  (None, 16, 16, 32)  128         ['conv2d_50[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_22 (Add)                   (None, 16, 16, 32)   0           ['activation_45[0][0]',          
                                                                  'batch_normalization_50[0][0]'] 
                                                                                                  
 activation_47 (Activation)     (None, 16, 16, 32)   0           ['add_22[0][0]']                 
                                                                                                  
 conv2d_51 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_47[0][0]']          
                                                                                                  
 batch_normalization_51 (BatchN  (None, 16, 16, 32)  128         ['conv2d_51[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_48 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_51[0][0]'] 
                                                                                                  
 conv2d_52 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_48[0][0]']          
                                                                                                  
 batch_normalization_52 (BatchN  (None, 16, 16, 32)  128         ['conv2d_52[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_23 (Add)                   (None, 16, 16, 32)   0           ['activation_47[0][0]',          
                                                                  'batch_normalization_52[0][0]'] 
                                                                                                  
 activation_49 (Activation)     (None, 16, 16, 32)   0           ['add_23[0][0]']                 
                                                                                                  
 conv2d_53 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_49[0][0]']          
                                                                                                  
 batch_normalization_53 (BatchN  (None, 16, 16, 32)  128         ['conv2d_53[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_50 (Activation)     (None, 16, 16, 32)   0           ['batch_normalization_53[0][0]'] 
                                                                                                  
 conv2d_54 (Conv2D)             (None, 16, 16, 32)   9216        ['activation_50[0][0]']          
                                                                                                  
 batch_normalization_54 (BatchN  (None, 16, 16, 32)  128         ['conv2d_54[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_24 (Add)                   (None, 16, 16, 32)   0           ['activation_49[0][0]',          
                                                                  'batch_normalization_54[0][0]'] 
                                                                                                  
 activation_51 (Activation)     (None, 16, 16, 32)   0           ['add_24[0][0]']                 
                                                                                                  
 conv2d_56 (Conv2D)             (None, 8, 8, 64)     18432       ['activation_51[0][0]']          
                                                                                                  
 batch_normalization_56 (BatchN  (None, 8, 8, 64)    256         ['conv2d_56[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_52 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_56[0][0]'] 
                                                                                                  
 conv2d_55 (Conv2D)             (None, 8, 8, 64)     2048        ['activation_51[0][0]']          
                                                                                                  
 conv2d_57 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_52[0][0]']          
                                                                                                  
 batch_normalization_55 (BatchN  (None, 8, 8, 64)    256         ['conv2d_55[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_57 (BatchN  (None, 8, 8, 64)    256         ['conv2d_57[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_25 (Add)                   (None, 8, 8, 64)     0           ['batch_normalization_55[0][0]', 
                                                                  'batch_normalization_57[0][0]'] 
                                                                                                  
 activation_53 (Activation)     (None, 8, 8, 64)     0           ['add_25[0][0]']                 
                                                                                                  
 conv2d_58 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_53[0][0]']          
                                                                                                  
 batch_normalization_58 (BatchN  (None, 8, 8, 64)    256         ['conv2d_58[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_54 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_58[0][0]'] 
                                                                                                  
 conv2d_59 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_54[0][0]']          
                                                                                                  
 batch_normalization_59 (BatchN  (None, 8, 8, 64)    256         ['conv2d_59[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_26 (Add)                   (None, 8, 8, 64)     0           ['activation_53[0][0]',          
                                                                  'batch_normalization_59[0][0]'] 
                                                                                                  
 activation_55 (Activation)     (None, 8, 8, 64)     0           ['add_26[0][0]']                 
                                                                                                  
 conv2d_60 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_55[0][0]']          
                                                                                                  
 batch_normalization_60 (BatchN  (None, 8, 8, 64)    256         ['conv2d_60[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_56 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_60[0][0]'] 
                                                                                                  
 conv2d_61 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_56[0][0]']          
                                                                                                  
 batch_normalization_61 (BatchN  (None, 8, 8, 64)    256         ['conv2d_61[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_27 (Add)                   (None, 8, 8, 64)     0           ['activation_55[0][0]',          
                                                                  'batch_normalization_61[0][0]'] 
                                                                                                  
 activation_57 (Activation)     (None, 8, 8, 64)     0           ['add_27[0][0]']                 
                                                                                                  
 conv2d_62 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_57[0][0]']          
                                                                                                  
 batch_normalization_62 (BatchN  (None, 8, 8, 64)    256         ['conv2d_62[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_58 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_62[0][0]'] 
                                                                                                  
 conv2d_63 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_58[0][0]']          
                                                                                                  
 batch_normalization_63 (BatchN  (None, 8, 8, 64)    256         ['conv2d_63[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_28 (Add)                   (None, 8, 8, 64)     0           ['activation_57[0][0]',          
                                                                  'batch_normalization_63[0][0]'] 
                                                                                                  
 activation_59 (Activation)     (None, 8, 8, 64)     0           ['add_28[0][0]']                 
                                                                                                  
 conv2d_64 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_59[0][0]']          
                                                                                                  
 batch_normalization_64 (BatchN  (None, 8, 8, 64)    256         ['conv2d_64[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_60 (Activation)     (None, 8, 8, 64)     0           ['batch_normalization_64[0][0]'] 
                                                                                                  
 conv2d_65 (Conv2D)             (None, 8, 8, 64)     36864       ['activation_60[0][0]']          
                                                                                                  
 batch_normalization_65 (BatchN  (None, 8, 8, 64)    256         ['conv2d_65[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 add_29 (Add)                   (None, 8, 8, 64)     0           ['activation_59[0][0]',          
                                                                  'batch_normalization_65[0][0]'] 
                                                                                                  
 activation_61 (Activation)     (None, 8, 8, 64)     0           ['add_29[0][0]']                 
                                                                                                  
 average_pooling2d_1 (AveragePo  (None, 1, 1, 64)    0           ['activation_61[0][0]']          
 oling2D)                                                                                         
                                                                                                  
 flatten_1 (Flatten)            (None, 64)           0           ['average_pooling2d_1[0][0]']    
                                                                                                  
 dense_1 (Dense)                (None, 10)           650         ['flatten_1[0][0]']              
                                                                                                  
==================================================================================================
Total params: 469,370
Trainable params: 466,906
Non-trainable params: 2,464
__________________________________________________________________________________________________

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291

这段代码我们使用的是基于基本图像处理的数据增强,我们设置了一些,比如roattion_range也就是旋转的角度,以及左右偏移大概0.1,以及水平翻转等,这些都是可以在我们的ImageDataGenerator中进行设置

%%time
if data_augmentation:
    # datagen
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False, # randomly flip images
    ) 
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)
    print('train with data augmentation')
    history = model.fit_generator(generator=datagen.flow(x_train, y_train, batch_size=batch_size), 
                                epochs=epochs,
                                callbacks=[reduce_lr, checkpoint],
                                validation_data=(x_val, y_val)
                                )
else:
    print('train without data augmentation')
    history = model.fit(x_train, y_train, 
                      batch_size=batch_size, epochs=epochs, 
                      callbacks=[reduce_lr],
                      validation_data=(x_val, y_val)
                      )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
782/782 [==============================] - 38s 44ms/step - loss: 1.8886 - accuracy: 0.3671 - val_loss: 1.7345 - val_accuracy: 0.4173 - lr: 0.1000
Epoch 2/30
782/782 [==============================] - 34s 43ms/step - loss: 1.4087 - accuracy: 0.5443 - val_loss: 1.4261 - val_accuracy: 0.5444 - lr: 0.1000
Epoch 3/30
782/782 [==============================] - ETA: 0s - loss: 1.1361 - accuracy: 0.6509
Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.093000001385808.
782/782 [==============================] - 34s 44ms/step - loss: 1.1361 - accuracy: 0.6509 - val_loss: 1.5838 - val_accuracy: 0.5590 - lr: 0.1000
Epoch 4/30
782/782 [==============================] - 34s 44ms/step - loss: 0.9685 - accuracy: 0.7187 - val_loss: 1.0326 - val_accuracy: 0.7028 - lr: 0.0930
Epoch 5/30
781/782 [============================>.] - ETA: 0s - loss: 0.8858 - accuracy: 0.7521
Epoch 00005: ReduceLROnPlateau reducing learning rate to 0.08649000205099583.
782/782 [==============================] - 33s 43ms/step - loss: 0.8859 - accuracy: 0.7520 - val_loss: 1.5229 - val_accuracy: 0.5927 - lr: 0.0930
Epoch 6/30
782/782 [==============================] - ETA: 0s - loss: 0.8307 - accuracy: 0.7734
Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.0804357048869133.
782/782 [==============================] - 34s 44ms/step - loss: 0.8307 - accuracy: 0.7734 - val_loss: 1.0997 - val_accuracy: 0.7028 - lr: 0.0865
Epoch 7/30
782/782 [==============================] - ETA: 0s - loss: 0.7843 - accuracy: 0.7906
Epoch 00007: ReduceLROnPlateau reducing learning rate to 0.07480520859360695.
782/782 [==============================] - 33s 43ms/step - loss: 0.7843 - accuracy: 0.7906 - val_loss: 1.0626 - val_accuracy: 0.7179 - lr: 0.0804
Epoch 8/30
781/782 [============================>.] - ETA: 0s - loss: 0.7455 - accuracy: 0.8035
Epoch 00008: ReduceLROnPlateau reducing learning rate to 0.06956884302198887.
782/782 [==============================] - 33s 43ms/step - loss: 0.7454 - accuracy: 0.8035 - val_loss: 1.3238 - val_accuracy: 0.6693 - lr: 0.0748
Epoch 9/30
782/782 [==============================] - 34s 43ms/step - loss: 0.7134 - accuracy: 0.8177 - val_loss: 0.9464 - val_accuracy: 0.7547 - lr: 0.0696
Epoch 10/30
782/782 [==============================] - 34s 43ms/step - loss: 0.6992 - accuracy: 0.8234 - val_loss: 0.9212 - val_accuracy: 0.7511 - lr: 0.0696
Epoch 11/30
782/782 [==============================] - 34s 43ms/step - loss: 0.6823 - accuracy: 0.8302 - val_loss: 0.8771 - val_accuracy: 0.7730 - lr: 0.0696
Epoch 12/30
781/782 [============================>.] - ETA: 0s - loss: 0.6782 - accuracy: 0.8307
Epoch 00012: ReduceLROnPlateau reducing learning rate to 0.06469902366399766.
782/782 [==============================] - 33s 42ms/step - loss: 0.6782 - accuracy: 0.8307 - val_loss: 0.9172 - val_accuracy: 0.7625 - lr: 0.0696
Epoch 13/30
782/782 [==============================] - 36s 46ms/step - loss: 0.6558 - accuracy: 0.8412 - val_loss: 0.7392 - val_accuracy: 0.8187 - lr: 0.0647
Epoch 14/30
782/782 [==============================] - ETA: 0s - loss: 0.6469 - accuracy: 0.8444
Epoch 00014: ReduceLROnPlateau reducing learning rate to 0.06017009228467941.
782/782 [==============================] - 33s 43ms/step - loss: 0.6469 - accuracy: 0.8444 - val_loss: 0.9919 - val_accuracy: 0.7541 - lr: 0.0647
Epoch 15/30
781/782 [============================>.] - ETA: 0s - loss: 0.6270 - accuracy: 0.8516
Epoch 00015: ReduceLROnPlateau reducing learning rate to 0.05595818527042866.
782/782 [==============================] - 35s 45ms/step - loss: 0.6269 - accuracy: 0.8516 - val_loss: 0.7825 - val_accuracy: 0.8096 - lr: 0.0602
Epoch 16/30
782/782 [==============================] - ETA: 0s - loss: 0.6099 - accuracy: 0.8560
Epoch 00016: ReduceLROnPlateau reducing learning rate to 0.05204111237078905.
782/782 [==============================] - 33s 43ms/step - loss: 0.6099 - accuracy: 0.8560 - val_loss: 0.8544 - val_accuracy: 0.7907 - lr: 0.0560
Epoch 17/30
781/782 [============================>.] - ETA: 0s - loss: 0.5900 - accuracy: 0.8624
Epoch 00017: ReduceLROnPlateau reducing learning rate to 0.04839823544025421.
782/782 [==============================] - 34s 43ms/step - loss: 0.5899 - accuracy: 0.8625 - val_loss: 0.7734 - val_accuracy: 0.8060 - lr: 0.0520
Epoch 18/30
781/782 [============================>.] - ETA: 0s - loss: 0.5732 - accuracy: 0.8675
Epoch 00018: ReduceLROnPlateau reducing learning rate to 0.04501035757362843.
782/782 [==============================] - 34s 43ms/step - loss: 0.5733 - accuracy: 0.8675 - val_loss: 0.9039 - val_accuracy: 0.7739 - lr: 0.0484
Epoch 19/30
782/782 [==============================] - 34s 43ms/step - loss: 0.5543 - accuracy: 0.8740 - val_loss: 0.6435 - val_accuracy: 0.8511 - lr: 0.0450
Epoch 20/30
781/782 [============================>.] - ETA: 0s - loss: 0.5497 - accuracy: 0.8767
Epoch 00020: ReduceLROnPlateau reducing learning rate to 0.04185963302850723.
782/782 [==============================] - 34s 43ms/step - loss: 0.5497 - accuracy: 0.8767 - val_loss: 0.8291 - val_accuracy: 0.7898 - lr: 0.0450
Epoch 21/30
781/782 [============================>.] - ETA: 0s - loss: 0.5368 - accuracy: 0.8797
Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.03892945982515812.
782/782 [==============================] - 33s 43ms/step - loss: 0.5368 - accuracy: 0.8797 - val_loss: 0.6961 - val_accuracy: 0.8392 - lr: 0.0419
Epoch 22/30
781/782 [============================>.] - ETA: 0s - loss: 0.5215 - accuracy: 0.8839
Epoch 00022: ReduceLROnPlateau reducing learning rate to 0.03620439659804106.
782/782 [==============================] - 33s 42ms/step - loss: 0.5214 - accuracy: 0.8839 - val_loss: 0.7055 - val_accuracy: 0.8369 - lr: 0.0389
Epoch 23/30
782/782 [==============================] - ETA: 0s - loss: 0.5085 - accuracy: 0.8888
Epoch 00023: ReduceLROnPlateau reducing learning rate to 0.03367008984088898.
782/782 [==============================] - 34s 44ms/step - loss: 0.5085 - accuracy: 0.8888 - val_loss: 0.7107 - val_accuracy: 0.8380 - lr: 0.0362
Epoch 24/30
782/782 [==============================] - 33s 43ms/step - loss: 0.4883 - accuracy: 0.8942 - val_loss: 0.6060 - val_accuracy: 0.8619 - lr: 0.0337
Epoch 25/30
781/782 [============================>.] - ETA: 0s - loss: 0.4868 - accuracy: 0.8928
Epoch 00025: ReduceLROnPlateau reducing learning rate to 0.03131318382918835.
782/782 [==============================] - 34s 43ms/step - loss: 0.4871 - accuracy: 0.8927 - val_loss: 0.7552 - val_accuracy: 0.8338 - lr: 0.0337
Epoch 26/30
781/782 [============================>.] - ETA: 0s - loss: 0.4749 - accuracy: 0.8970
Epoch 00026: ReduceLROnPlateau reducing learning rate to 0.02912126172333956.
782/782 [==============================] - 34s 44ms/step - loss: 0.4750 - accuracy: 0.8969 - val_loss: 0.7968 - val_accuracy: 0.8160 - lr: 0.0313
Epoch 27/30
782/782 [==============================] - 33s 43ms/step - loss: 0.4645 - accuracy: 0.8998 - val_loss: 0.5906 - val_accuracy: 0.8637 - lr: 0.0291
Epoch 28/30
781/782 [============================>.] - ETA: 0s - loss: 0.4605 - accuracy: 0.9026
Epoch 00028: ReduceLROnPlateau reducing learning rate to 0.027082772813737395.
782/782 [==============================] - 33s 42ms/step - loss: 0.4606 - accuracy: 0.9026 - val_loss: 0.5991 - val_accuracy: 0.8682 - lr: 0.0291
Epoch 29/30
782/782 [==============================] - ETA: 0s - loss: 0.4414 - accuracy: 0.9069
Epoch 00029: ReduceLROnPlateau reducing learning rate to 0.025186978820711376.
782/782 [==============================] - 34s 44ms/step - loss: 0.4414 - accuracy: 0.9069 - val_loss: 0.7109 - val_accuracy: 0.8343 - lr: 0.0271
Epoch 30/30
782/782 [==============================] - ETA: 0s - loss: 0.4339 - accuracy: 0.9096
Epoch 00030: ReduceLROnPlateau reducing learning rate to 0.02342388980090618.
782/782 [==============================] - 33s 43ms/step - loss: 0.4339 - accuracy: 0.9096 - val_loss: 0.6589 - val_accuracy: 0.8427 - lr: 0.0252
CPU times: user 27min 58s, sys: 29.5 s, total: 28min 28s
Wall time: 17min 21s

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
plot_model_history(history)
  • 1

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eBZvonnc-1668299839133)(C:\Users\86137\AppData\Roaming\Typora\typora-user-images\image-20221113083648203.png)]

从结果可以看出,使用了数据增强之后,我们的结果是比没有进行数据增强是好的,已经达到了84.2%+的准确率,如果设置好数据增强的参数,很有可能可以得到更高的准确率,数据增强还是对结果有比较大的影响的,并且也更稳定

loss,acc = model.evaluate(x_val,y_val)
print('evaluate loss:%f acc:%f' % (loss, acc))
  • 1
  • 2

在这里插入图片描述

313/313 [==============================] - 3s 11ms/step - loss: 0.6589 - accuracy: 0.8427
evaluate loss:0.658943 acc:0.842700
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/598241
推荐阅读
相关标签
  

闽ICP备14008679号