Besides PyTorch, we can also use TensorFlow for image classification, and Keras, the high-level API that runs on top of TensorFlow, makes it especially simple. In this post we use Keras to classify the CIFAR-10 dataset.
Keras is one of the most widely used deep-learning frameworks in Python. It is extremely capable and covers essentially all of the commonly used model types:
1. The same code can run on CPU or GPU without changes;
2. Models can be defined either with the functional API or with the Sequential class (see the short sketch after this list);
3. Arbitrary network architectures are supported, such as multiple inputs and multiple outputs;
4. Convolutional networks, recurrent networks, and combinations of the two are all available.
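As a minimal sketch of the two model-definition styles mentioned in point 2 (layer sizes here are arbitrary, chosen only for illustration):

from keras.models import Sequential, Model
from keras.layers import Dense, Input

# Sequential class: a linear stack of layers
seq_model = Sequential([Dense(32, activation='relu', input_shape=(16,)),
                        Dense(10, activation='softmax')])

# Functional API: layers are called on tensors, allowing arbitrary graphs
inp = Input(shape=(16,))
hidden = Dense(32, activation='relu')(inp)
out = Dense(10, activation='softmax')(hidden)
fn_model = Model(inp, out)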
Keras is a model-level library: it provides the high-level building blocks for developing models, but it does not itself handle low-level operations such as tensor manipulation or differentiation. Yet all data ultimately flows through Keras as tensors, so how does a library that does not implement tensor operations perform tensor computation?
Keras relies on a backend engine that specializes in tensor operations; all tensor computation is delegated to it. That is why installing Keras also requires TensorFlow or Theano: TensorFlow, Theano, and CNTK are all backend engines for numerical tensor computation.
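You can ask Keras which backend it is running on; a quick check (in recent versions this simply reports 'tensorflow'):

from keras import backend as K
print(K.backend())  # e.g. 'tensorflow'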
Installation is straightforward: we can install Keras directly. If you also need TensorFlow, install that as well.
pip install keras
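If the TensorFlow backend is not already present in your environment, it can be added the same way (assuming a pip-based setup):

pip install tensorflow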
import keras
from keras.models import Sequential
from keras.datasets import cifar10
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense, Activation
from keras.optimizers import adam_v2
from keras.utils.vis_utils import plot_model
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import matplotlib
matplotlib.style.use('ggplot')
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0)  # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
The snippet below tells TensorFlow which GPU to use, which matters on multi-GPU machines, and also controls how GPU memory is allocated.
This line enables dynamic memory growth, so memory is allocated on demand:
config.gpu_options.allow_growth = True
This line instead caps our GPU memory usage at 50% of the card:
config.gpu_options.per_process_gpu_memory_fraction = 0.5
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress low-level TensorFlow warnings
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# The GPU id to use, usually either "0" or "1"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
config = tf.compat.v1.ConfigProto()
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
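As an aside, on TensorFlow 2.x the same memory-growth behavior can be requested without the compat session; a minimal sketch:

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    # must be called before the GPU is first used
    tf.config.experimental.set_memory_growth(gpu, True)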
CIFAR-10 is a small dataset for everyday-object recognition, assembled by Alex Krizhevsky and Ilya Sutskever, students of Hinton. It contains RGB color images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. Each image is 32×32, and the dataset holds 50,000 training images and 10,000 test images.
Compared with MNIST, CIFAR-10 images are 3-channel color photographs of real-world objects; the objects vary in pose, scale, and background and the images contain far more noise, which makes classification noticeably harder.
num_classes = 10  # number of classes
(x_train, y_train), (x_val, y_val) = cifar10.load_data()
print("Training set shape:", x_train.shape)
print("Validation set shape:", x_val.shape)
Training set shape: (50000, 32, 32, 3)
Validation set shape: (10000, 32, 32, 3)
class_names = ['airplane','automobile','bird','cat','deer',
'dog','frog','horse','ship','truck']
fig = plt.figure(figsize=(20,5))
for i in range(num_classes):
    ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
    idx = np.where(y_train[:] == i)[0]                   # indices of samples in class i
    features_idx = x_train[idx, ::]                      # images of that class
    img_num = np.random.randint(features_idx.shape[0])   # pick one image at random
    im = features_idx[img_num, ::]
    ax.set_title(class_names[i])
    plt.imshow(im)
plt.show()
x_train = x_train.astype('float32')/255
x_val = x_val.astype('float32')/255
# convert the class vectors to one-hot encoded matrices
y_train = to_categorical(y_train, num_classes)
y_val = to_categorical(y_val, num_classes)
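For intuition, to_categorical turns each integer label into a one-hot row vector; a tiny worked example:

to_categorical(np.array([0, 2]), num_classes=4)
# array([[1., 0., 0., 0.],
#        [0., 0., 1., 0.]], dtype=float32)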
output_dir = './output'  # output directory
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    print('%s created' % output_dir)
else:
    # shutil.rmtree(output_dir)
    # print('%s folder already existed, but no matter, we deleted it' % output_dir)
    # os.mkdir(output_dir)
    print('%s folder already exists' % output_dir)
model_name = 'resnet'
./output created
./output folder already exists
While everyone was still marveling at GoogLeNet's Inception modules, researchers at Microsoft Research Asia were already designing ResNet, a network that is deeper yet structurally simpler. It won first place in both the classification and detection tasks of that year's ImageNet competition, as well as first place in detection and segmentation on the COCO dataset.
For a detailed walkthrough of the paper, see my other post, "Paper survey: ResNet, deep residual networks".
Below is a simplified diagram of the 18-layer ResNet architecture,
followed by the ResNet-34 model.
The ResNet architecture has the following highlights:
(1) It introduces the residual structure and uses it to build extremely deep networks (beyond 1,000 layers).
(2) It uses Batch Normalization to speed up training (and drops dropout).
Before ResNet, traditional convolutional networks were built by stacking convolutional layers and downsampling layers. Past a certain depth, two problems appear:
(1) vanishing or exploding gradients;
(2) the degradation problem.
The ResNet paper argues that data preprocessing together with BN (Batch Normalization) layers handles vanishing and exploding gradients, while the residual structure mitigates degradation. With a residual block, if the desired underlying mapping is H(x), the block instead fits the residual F(x) = H(x) − x and outputs F(x) + x.
One important detail: the second layer of a block has no ReLU of its own; its output is first added to x via the shortcut, and only then is the ReLU activation applied.
Let us first analyze the residual block used in ResNet-18/34. As shown in the figure, its main branch consists of two 3x3 convolutional layers, and the line on the right of the block is the shortcut branch (note that for the main-branch output to be added to the shortcut output, the two feature maps must have the same shape). Some blocks are drawn with a dashed shortcut; the paper implements these with a 1x1 convolution that adjusts the dimensions, as detailed in the figure.
Next, the residual block used in ResNet-50/101/152, shown in the figure below. Its main branch has three convolutional layers: a 1x1 convolution that compresses the channel dimension, a 3x3 convolution, and a 1x1 convolution that restores the channel dimension (the first and second convolutions use the same number of filters, while the third uses four times as many as the first). This design is called a bottleneck block.
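Since the code later in this post only implements the basic two-layer block, here is a minimal sketch of what a bottleneck block could look like in Keras; bottleneck_block and its argument names are illustrative, not part of the training code below:

from keras.layers import Conv2D, BatchNormalization, Activation, add

def bottleneck_block(x, filters, stride=1):
    shortcut = x
    # 1x1 conv compresses the channel dimension
    y = Conv2D(filters, 1, strides=stride, padding='same', use_bias=False)(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    # 3x3 conv with the same number of filters
    y = Conv2D(filters, 3, strides=1, padding='same', use_bias=False)(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    # 1x1 conv restores channels to 4x the compressed width
    y = Conv2D(4 * filters, 1, strides=1, padding='same', use_bias=False)(y)
    y = BatchNormalization()(y)
    # projection shortcut (the "dashed" path) when the shape changes
    if stride != 1 or x.shape[-1] != 4 * filters:
        shortcut = Conv2D(4 * filters, 1, strides=stride, use_bias=False)(x)
        shortcut = BatchNormalization()(shortcut)
    # add first, then ReLU, as noted above
    y = add([y, shortcut])
    return Activation('relu')(y)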
The table of deeper residual networks for ImageNet lists, for each residual block, the kernel sizes and filter counts on the main branch; xN in the table means the block is repeated N times.
For ResNet-18/34/50/101/152, the first residual block in each of the stages conv3_x, conv4_x, and conv5_x is a dashed (projection) block, because the first block of each stage has to reshape the input feature map: it halves the height and width and adjusts the channel depth to what the following blocks expect.
input_shape = (32,32,3)
from keras.layers import BatchNormalization, AveragePooling2D, Input
from keras.models import Model
from keras.regularizers import l2
from keras import layers

def conv2d_bn(x, filters, kernel_size, weight_decay=.0, strides=(1, 1)):
    layer = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same',
                   use_bias=False,
                   kernel_regularizer=l2(weight_decay)
                   )(x)
    layer = BatchNormalization()(layer)
    return layer

def conv2d_bn_relu(x, filters, kernel_size, weight_decay=.0, strides=(1, 1)):
    layer = conv2d_bn(x, filters, kernel_size, weight_decay, strides)
    layer = Activation('relu')(layer)
    return layer
def ResidualBlock(x, filters, kernel_size, weight_decay, downsample=True):
    if downsample:
        # residual_x = conv2d_bn_relu(x, filters, kernel_size=1, strides=2)
        residual_x = conv2d_bn(x, filters, kernel_size=1, strides=2)
        stride = 2
    else:
        residual_x = x
        stride = 1
    residual = conv2d_bn_relu(x,
                              filters=filters,
                              kernel_size=kernel_size,
                              weight_decay=weight_decay,
                              strides=stride,
                              )
    residual = conv2d_bn(residual,
                         filters=filters,
                         kernel_size=kernel_size,
                         weight_decay=weight_decay,
                         strides=1,
                         )
    out = layers.add([residual_x, residual])
    out = Activation('relu')(out)
    return out
def ResNet18(classes, input_shape, weight_decay=1e-4):
    input = Input(shape=input_shape)
    x = input
    # x = conv2d_bn_relu(x, filters=64, kernel_size=(7, 7), weight_decay=weight_decay, strides=(2, 2))
    # x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = conv2d_bn_relu(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, strides=(1, 1))
    # conv 2
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # conv 3
    x = ResidualBlock(x, filters=128, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    x = ResidualBlock(x, filters=128, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # conv 4
    x = ResidualBlock(x, filters=256, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    x = ResidualBlock(x, filters=256, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # conv 5
    x = ResidualBlock(x, filters=512, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    x = ResidualBlock(x, filters=512, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    x = AveragePooling2D(pool_size=(4, 4), padding='valid')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax')(x)
    model = Model(input, x, name='ResNet18')
    return model

def ResNetForCIFAR10(classes, name, input_shape, block_layers_num, weight_decay):
    input = Input(shape=input_shape)
    x = input
    x = conv2d_bn_relu(x, filters=16, kernel_size=(3, 3), weight_decay=weight_decay, strides=(1, 1))
    # conv 2
    for i in range(block_layers_num):
        x = ResidualBlock(x, filters=16, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # conv 3
    x = ResidualBlock(x, filters=32, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    for i in range(block_layers_num - 1):
        x = ResidualBlock(x, filters=32, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    # conv 4
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=True)
    for i in range(block_layers_num - 1):
        x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
    x = AveragePooling2D(pool_size=(8, 8), padding='valid')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax')(x)
    model = Model(input, x, name=name)
    return model

def ResNet20ForCIFAR10(classes, input_shape, weight_decay):
    return ResNetForCIFAR10(classes, 'resnet20', input_shape, 3, weight_decay)

def ResNet32ForCIFAR10(classes, input_shape, weight_decay):
    return ResNetForCIFAR10(classes, 'resnet32', input_shape, 5, weight_decay)

def ResNet56ForCIFAR10(classes, input_shape, weight_decay):
    return ResNetForCIFAR10(classes, 'resnet56', input_shape, 9, weight_decay)
input_shape = (32,32,3)
weight_decay = 1e-4
model = ResNet32ForCIFAR10(input_shape=(32, 32, 3), classes=num_classes, weight_decay=weight_decay)
model.summary()
Model: "resnet32" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 32, 32, 3)] 0 [] conv2d (Conv2D) (None, 32, 32, 16) 432 ['input_1[0][0]'] batch_normalization (BatchNorm (None, 32, 32, 16) 64 ['conv2d[0][0]'] alization) activation (Activation) (None, 32, 32, 16) 0 ['batch_normalization[0][0]'] conv2d_1 (Conv2D) (None, 32, 32, 16) 2304 ['activation[0][0]'] batch_normalization_1 (BatchNo (None, 32, 32, 16) 64 ['conv2d_1[0][0]'] rmalization) activation_1 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_1[0][0]'] conv2d_2 (Conv2D) (None, 32, 32, 16) 2304 ['activation_1[0][0]'] batch_normalization_2 (BatchNo (None, 32, 32, 16) 64 ['conv2d_2[0][0]'] rmalization) add (Add) (None, 32, 32, 16) 0 ['activation[0][0]', 'batch_normalization_2[0][0]'] activation_2 (Activation) (None, 32, 32, 16) 0 ['add[0][0]'] conv2d_3 (Conv2D) (None, 32, 32, 16) 2304 ['activation_2[0][0]'] batch_normalization_3 (BatchNo (None, 32, 32, 16) 64 ['conv2d_3[0][0]'] rmalization) activation_3 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_3[0][0]'] conv2d_4 (Conv2D) (None, 32, 32, 16) 2304 ['activation_3[0][0]'] batch_normalization_4 (BatchNo (None, 32, 32, 16) 64 ['conv2d_4[0][0]'] rmalization) add_1 (Add) (None, 32, 32, 16) 0 ['activation_2[0][0]', 'batch_normalization_4[0][0]'] activation_4 (Activation) (None, 32, 32, 16) 0 ['add_1[0][0]'] conv2d_5 (Conv2D) (None, 32, 32, 16) 2304 ['activation_4[0][0]'] batch_normalization_5 (BatchNo (None, 32, 32, 16) 64 ['conv2d_5[0][0]'] rmalization) activation_5 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_5[0][0]'] conv2d_6 (Conv2D) (None, 32, 32, 16) 2304 ['activation_5[0][0]'] batch_normalization_6 (BatchNo (None, 32, 32, 16) 64 ['conv2d_6[0][0]'] rmalization) add_2 (Add) (None, 32, 32, 16) 0 ['activation_4[0][0]', 'batch_normalization_6[0][0]'] activation_6 (Activation) (None, 32, 32, 16) 0 ['add_2[0][0]'] conv2d_7 (Conv2D) (None, 32, 32, 16) 2304 ['activation_6[0][0]'] batch_normalization_7 (BatchNo (None, 32, 32, 16) 64 ['conv2d_7[0][0]'] rmalization) activation_7 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_7[0][0]'] conv2d_8 (Conv2D) (None, 32, 32, 16) 2304 ['activation_7[0][0]'] batch_normalization_8 (BatchNo (None, 32, 32, 16) 64 ['conv2d_8[0][0]'] rmalization) add_3 (Add) (None, 32, 32, 16) 0 ['activation_6[0][0]', 'batch_normalization_8[0][0]'] activation_8 (Activation) (None, 32, 32, 16) 0 ['add_3[0][0]'] conv2d_9 (Conv2D) (None, 32, 32, 16) 2304 ['activation_8[0][0]'] batch_normalization_9 (BatchNo (None, 32, 32, 16) 64 ['conv2d_9[0][0]'] rmalization) activation_9 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_9[0][0]'] conv2d_10 (Conv2D) (None, 32, 32, 16) 2304 ['activation_9[0][0]'] batch_normalization_10 (BatchN (None, 32, 32, 16) 64 ['conv2d_10[0][0]'] ormalization) add_4 (Add) (None, 32, 32, 16) 0 ['activation_8[0][0]', 'batch_normalization_10[0][0]'] activation_10 (Activation) (None, 32, 32, 16) 0 ['add_4[0][0]'] conv2d_12 (Conv2D) (None, 16, 16, 32) 4608 ['activation_10[0][0]'] batch_normalization_12 (BatchN (None, 16, 16, 32) 128 ['conv2d_12[0][0]'] ormalization) activation_11 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_12[0][0]'] conv2d_11 (Conv2D) (None, 16, 16, 32) 512 ['activation_10[0][0]'] conv2d_13 (Conv2D) (None, 16, 16, 32) 9216 ['activation_11[0][0]'] 
batch_normalization_11 (BatchN (None, 16, 16, 32) 128 ['conv2d_11[0][0]'] ormalization) batch_normalization_13 (BatchN (None, 16, 16, 32) 128 ['conv2d_13[0][0]'] ormalization) add_5 (Add) (None, 16, 16, 32) 0 ['batch_normalization_11[0][0]', 'batch_normalization_13[0][0]'] activation_12 (Activation) (None, 16, 16, 32) 0 ['add_5[0][0]'] conv2d_14 (Conv2D) (None, 16, 16, 32) 9216 ['activation_12[0][0]'] batch_normalization_14 (BatchN (None, 16, 16, 32) 128 ['conv2d_14[0][0]'] ormalization) activation_13 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_14[0][0]'] conv2d_15 (Conv2D) (None, 16, 16, 32) 9216 ['activation_13[0][0]'] batch_normalization_15 (BatchN (None, 16, 16, 32) 128 ['conv2d_15[0][0]'] ormalization) add_6 (Add) (None, 16, 16, 32) 0 ['activation_12[0][0]', 'batch_normalization_15[0][0]'] activation_14 (Activation) (None, 16, 16, 32) 0 ['add_6[0][0]'] conv2d_16 (Conv2D) (None, 16, 16, 32) 9216 ['activation_14[0][0]'] batch_normalization_16 (BatchN (None, 16, 16, 32) 128 ['conv2d_16[0][0]'] ormalization) activation_15 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_16[0][0]'] conv2d_17 (Conv2D) (None, 16, 16, 32) 9216 ['activation_15[0][0]'] batch_normalization_17 (BatchN (None, 16, 16, 32) 128 ['conv2d_17[0][0]'] ormalization) add_7 (Add) (None, 16, 16, 32) 0 ['activation_14[0][0]', 'batch_normalization_17[0][0]'] activation_16 (Activation) (None, 16, 16, 32) 0 ['add_7[0][0]'] conv2d_18 (Conv2D) (None, 16, 16, 32) 9216 ['activation_16[0][0]'] batch_normalization_18 (BatchN (None, 16, 16, 32) 128 ['conv2d_18[0][0]'] ormalization) activation_17 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_18[0][0]'] conv2d_19 (Conv2D) (None, 16, 16, 32) 9216 ['activation_17[0][0]'] batch_normalization_19 (BatchN (None, 16, 16, 32) 128 ['conv2d_19[0][0]'] ormalization) add_8 (Add) (None, 16, 16, 32) 0 ['activation_16[0][0]', 'batch_normalization_19[0][0]'] activation_18 (Activation) (None, 16, 16, 32) 0 ['add_8[0][0]'] conv2d_20 (Conv2D) (None, 16, 16, 32) 9216 ['activation_18[0][0]'] batch_normalization_20 (BatchN (None, 16, 16, 32) 128 ['conv2d_20[0][0]'] ormalization) activation_19 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_20[0][0]'] conv2d_21 (Conv2D) (None, 16, 16, 32) 9216 ['activation_19[0][0]'] batch_normalization_21 (BatchN (None, 16, 16, 32) 128 ['conv2d_21[0][0]'] ormalization) add_9 (Add) (None, 16, 16, 32) 0 ['activation_18[0][0]', 'batch_normalization_21[0][0]'] activation_20 (Activation) (None, 16, 16, 32) 0 ['add_9[0][0]'] conv2d_23 (Conv2D) (None, 8, 8, 64) 18432 ['activation_20[0][0]'] batch_normalization_23 (BatchN (None, 8, 8, 64) 256 ['conv2d_23[0][0]'] ormalization) activation_21 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_23[0][0]'] conv2d_22 (Conv2D) (None, 8, 8, 64) 2048 ['activation_20[0][0]'] conv2d_24 (Conv2D) (None, 8, 8, 64) 36864 ['activation_21[0][0]'] batch_normalization_22 (BatchN (None, 8, 8, 64) 256 ['conv2d_22[0][0]'] ormalization) batch_normalization_24 (BatchN (None, 8, 8, 64) 256 ['conv2d_24[0][0]'] ormalization) add_10 (Add) (None, 8, 8, 64) 0 ['batch_normalization_22[0][0]', 'batch_normalization_24[0][0]'] activation_22 (Activation) (None, 8, 8, 64) 0 ['add_10[0][0]'] conv2d_25 (Conv2D) (None, 8, 8, 64) 36864 ['activation_22[0][0]'] batch_normalization_25 (BatchN (None, 8, 8, 64) 256 ['conv2d_25[0][0]'] ormalization) activation_23 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_25[0][0]'] conv2d_26 (Conv2D) (None, 8, 8, 64) 36864 ['activation_23[0][0]'] batch_normalization_26 (BatchN (None, 8, 8, 64) 
256 ['conv2d_26[0][0]'] ormalization) add_11 (Add) (None, 8, 8, 64) 0 ['activation_22[0][0]', 'batch_normalization_26[0][0]'] activation_24 (Activation) (None, 8, 8, 64) 0 ['add_11[0][0]'] conv2d_27 (Conv2D) (None, 8, 8, 64) 36864 ['activation_24[0][0]'] batch_normalization_27 (BatchN (None, 8, 8, 64) 256 ['conv2d_27[0][0]'] ormalization) activation_25 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_27[0][0]'] conv2d_28 (Conv2D) (None, 8, 8, 64) 36864 ['activation_25[0][0]'] batch_normalization_28 (BatchN (None, 8, 8, 64) 256 ['conv2d_28[0][0]'] ormalization) add_12 (Add) (None, 8, 8, 64) 0 ['activation_24[0][0]', 'batch_normalization_28[0][0]'] activation_26 (Activation) (None, 8, 8, 64) 0 ['add_12[0][0]'] conv2d_29 (Conv2D) (None, 8, 8, 64) 36864 ['activation_26[0][0]'] batch_normalization_29 (BatchN (None, 8, 8, 64) 256 ['conv2d_29[0][0]'] ormalization) activation_27 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_29[0][0]'] conv2d_30 (Conv2D) (None, 8, 8, 64) 36864 ['activation_27[0][0]'] batch_normalization_30 (BatchN (None, 8, 8, 64) 256 ['conv2d_30[0][0]'] ormalization) add_13 (Add) (None, 8, 8, 64) 0 ['activation_26[0][0]', 'batch_normalization_30[0][0]'] activation_28 (Activation) (None, 8, 8, 64) 0 ['add_13[0][0]'] conv2d_31 (Conv2D) (None, 8, 8, 64) 36864 ['activation_28[0][0]'] batch_normalization_31 (BatchN (None, 8, 8, 64) 256 ['conv2d_31[0][0]'] ormalization) activation_29 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_31[0][0]'] conv2d_32 (Conv2D) (None, 8, 8, 64) 36864 ['activation_29[0][0]'] batch_normalization_32 (BatchN (None, 8, 8, 64) 256 ['conv2d_32[0][0]'] ormalization) add_14 (Add) (None, 8, 8, 64) 0 ['activation_28[0][0]', 'batch_normalization_32[0][0]'] activation_30 (Activation) (None, 8, 8, 64) 0 ['add_14[0][0]'] average_pooling2d (AveragePool (None, 1, 1, 64) 0 ['activation_30[0][0]'] ing2D) flatten (Flatten) (None, 64) 0 ['average_pooling2d[0][0]'] dense (Dense) (None, 10) 650 ['flatten[0][0]'] ================================================================================================== Total params: 469,370 Trainable params: 466,906 Non-trainable params: 2,464 __________________________________________________________________________________________________
model_img = output_dir + '/cifar10_%s.png' % model_name  # where to save the architecture diagram
plot_model(model, to_file=model_img, show_shapes=True)   # save the model structure as an image
print('%s saved' % model_img)
./output/cifar10_resnet.png saved
First we set the number of epochs and the batch size:
epochs = 20  # number of epochs
batch_size = 128  # batch size
Next we configure a checkpoint callback that controls how models are saved during training:
checkpoint = ModelCheckpoint(output_dir + '/best_%s_simple.h5'%model_name, # model filename
monitor='val_loss', # quantity to monitor
verbose=0, # verbosity - 0 or 1
save_best_only= True, # The latest best model will not be overwritten
mode='auto') # The decision to overwrite model is made
# automatically depending on the quantity to monitor
Now we can define the optimizer and loss function. In Keras this is a one-liner; we simply also ask it to track accuracy as the metric.
adam = adam_v2.Adam(lr = 0.001)
model.compile(loss = 'categorical_crossentropy', optimizer = adam, metrics = ['accuracy'])
Finally, the built-in fit function plus our hyperparameters is all it takes to train the model.
history = model.fit(x_train,y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_val,y_val),
shuffle=True,
callbacks=[checkpoint])
Epoch 1/20
391/391 [==============================] - 25s 48ms/step - loss: 1.5136 - accuracy: 0.4872 - val_loss: 2.3308 - val_accuracy: 0.3384
Epoch 2/20
391/391 [==============================] - 18s 46ms/step - loss: 1.0520 - accuracy: 0.6656 - val_loss: 1.3796 - val_accuracy: 0.5948
Epoch 3/20
391/391 [==============================] - 18s 46ms/step - loss: 0.8567 - accuracy: 0.7406 - val_loss: 1.0479 - val_accuracy: 0.6941
Epoch 4/20
391/391 [==============================] - 17s 45ms/step - loss: 0.7438 - accuracy: 0.7839 - val_loss: 1.4804 - val_accuracy: 0.5939
Epoch 5/20
391/391 [==============================] - 18s 47ms/step - loss: 0.6655 - accuracy: 0.8118 - val_loss: 1.0215 - val_accuracy: 0.7029
Epoch 6/20
391/391 [==============================] - 18s 46ms/step - loss: 0.6030 - accuracy: 0.8339 - val_loss: 1.3131 - val_accuracy: 0.6538
Epoch 7/20
391/391 [==============================] - 18s 47ms/step - loss: 0.5612 - accuracy: 0.8501 - val_loss: 0.9786 - val_accuracy: 0.7368
Epoch 8/20
391/391 [==============================] - 18s 45ms/step - loss: 0.5128 - accuracy: 0.8705 - val_loss: 1.2962 - val_accuracy: 0.6735
Epoch 9/20
391/391 [==============================] - 18s 45ms/step - loss: 0.4838 - accuracy: 0.8817 - val_loss: 1.5207 - val_accuracy: 0.6318
Epoch 10/20
391/391 [==============================] - 18s 47ms/step - loss: 0.4482 - accuracy: 0.8953 - val_loss: 0.9014 - val_accuracy: 0.7674
Epoch 11/20
391/391 [==============================] - 18s 45ms/step - loss: 0.4241 - accuracy: 0.9053 - val_loss: 0.9986 - val_accuracy: 0.7666
Epoch 12/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3995 - accuracy: 0.9154 - val_loss: 1.0747 - val_accuracy: 0.7399
Epoch 13/20
391/391 [==============================] - 17s 45ms/step - loss: 0.3755 - accuracy: 0.9243 - val_loss: 1.4094 - val_accuracy: 0.7106
Epoch 14/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3634 - accuracy: 0.9308 - val_loss: 1.3048 - val_accuracy: 0.7187
Epoch 15/20
391/391 [==============================] - 18s 46ms/step - loss: 0.3418 - accuracy: 0.9394 - val_loss: 1.1310 - val_accuracy: 0.7498
Epoch 16/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3392 - accuracy: 0.9413 - val_loss: 1.1636 - val_accuracy: 0.7490
Epoch 17/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3301 - accuracy: 0.9456 - val_loss: 1.6518 - val_accuracy: 0.6921
Epoch 18/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3227 - accuracy: 0.9495 - val_loss: 1.2451 - val_accuracy: 0.7381
Epoch 19/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3123 - accuracy: 0.9547 - val_loss: 1.2867 - val_accuracy: 0.7464
Epoch 20/20
391/391 [==============================] - 18s 45ms/step - loss: 0.3129 - accuracy: 0.9546 - val_loss: 1.6354 - val_accuracy: 0.6954
def plot_model_history(model_history):
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    # summarize history for accuracy
    axs[0].plot(range(1, len(model_history.history['accuracy']) + 1), model_history.history['accuracy'])
    axs[0].plot(range(1, len(model_history.history['val_accuracy']) + 1), model_history.history['val_accuracy'])
    axs[0].set_title('Model Accuracy')
    axs[0].set_ylabel('Accuracy')
    axs[0].set_xlabel('Epoch')
    axs[0].legend(['train', 'val'], loc='best')
    # summarize history for loss
    axs[1].plot(range(1, len(model_history.history['loss']) + 1), model_history.history['loss'])
    axs[1].plot(range(1, len(model_history.history['val_loss']) + 1), model_history.history['val_loss'])
    axs[1].set_title('Model Loss')
    axs[1].set_ylabel('Loss')
    axs[1].set_xlabel('Epoch')
    axs[1].legend(['train', 'val'], loc='best')
    plt.show()
plot_model_history(history)
model_path = output_dir + '/keras_cifar10_%s_model.h5'%model_name
model.save(model_path)
print('%s saved' % model_path)
./output/keras_cifar10_resnet_model.h5 saved
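To reuse the trained network later, it can be reloaded with Keras' load_model; a short sketch using the paths created above:

from keras.models import load_model
model = load_model('./output/keras_cifar10_resnet_model.h5')
# or the best checkpoint kept by ModelCheckpoint:
# model = load_model('./output/best_resnet_simple.h5')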
# take some images from the validation set and predict them
name = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
        5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
n = 20  # how many images to take
x_test = x_val[:n]
y_test = y_val[:n]
# predict
y_predict = model.predict(x_test, batch_size=n)
# plot the predictions
plt.figure(figsize=(18, 3))  # canvas size
for i in range(n):
    plt.subplot(2, 10, i + 1)
    plt.axis('off')  # hide the x and y axes
    plt.imshow(x_test[i])  # show the image
    if y_test[i].argmax() == y_predict[i].argmax():
        # correct prediction: green title
        plt.title('%s,%s' % (name[y_test[i].argmax()], name[y_predict[i].argmax()]), color='green')
    else:
        # wrong prediction: red title
        plt.title('%s,%s' % (name[y_test[i].argmax()], name[y_predict[i].argmax()]), color='red')
predict_img = output_dir + '/predict_%s.png' % model_name
print('%s saved' % predict_img)
plt.savefig(predict_img)  # save the prediction figure
plt.show()  # show the canvas
./output/predict_resnet.png saved
(Figure: the 20 validation images with "true,predicted" labels; green titles mark correct predictions, red titles mark errors.)
res = model.evaluate(x_test,y_test)
print('{:.2f}%'.format(res[1]*100))
1/1 [==============================] - 0s 54ms/step - loss: 0.8094 - accuracy: 0.7500
75.00%
res = model.evaluate(x_val,y_val)
print('{:.2f}%'.format(res[1]*100))
313/313 [==============================] - 3s 11ms/step - loss: 1.1061 - accuracy: 0.7747
77.47%
Besides training on the original images, there is another way to improve accuracy: data augmentation. Let us introduce it briefly.
Data augmentation artificially enlarges a training set by generating additional, equivalent samples from the limited data available. It is an effective remedy for insufficient training data and is now used widely across deep learning. Deep networks perform well on many tasks, but they generally need large amounts of data to avoid overfitting, and in many settings that much data simply is not available; data augmentation exists to solve exactly this limited-data problem. It is a family of techniques that increase both the size and the quality of a training set so that better models can be built from it. The trade-off is that generated samples differ from real ones, which inevitably introduces some noise.
Data augmentation in computer vision
Augmentation algorithms in computer vision fall roughly into two categories: the first is based on basic image-processing techniques, the second on deep learning.
The first category covers standard geometric and photometric transforms such as flipping, rotation, translation, cropping, scaling, and color jittering.
The second category, deep-learning-based augmentation, covers approaches such as GAN-based sample synthesis and learned augmentation policies. In this post we use the first kind; a quick way to preview what such a generator produces is sketched below.
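A minimal sketch for eyeballing the augmented images before training (demo_gen is a name introduced here for illustration; the transform values match the training setup below):

from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

demo_gen = ImageDataGenerator(rotation_range=15, width_shift_range=0.1,
                              height_shift_range=0.1, horizontal_flip=True)
# draw one augmented version of the first training image
batch = next(demo_gen.flow(x_train[:1], batch_size=1))
plt.imshow(batch[0])
plt.axis('off')
plt.show()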
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import LearningRateScheduler

# fit data with data augmentation or not
data_augmentation = True

# def lr_scheduler(epoch):
#     return lr * (0.1 ** (epoch // 50))
# reduce_lr = LearningRateScheduler(lr_scheduler)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.93, patience=1, min_lr=1e-6, verbose=1)
checkpoint = ModelCheckpoint(output_dir + '/best_%s_data_augmentation.h5' % model_name,  # model filename
                             monitor='val_loss',   # quantity to monitor
                             verbose=0,            # verbosity - 0 or 1
                             save_best_only=True,  # the latest best model will not be overwritten
                             mode='auto')          # the decision to overwrite is made automatically
                                                   # depending on the quantity monitored
batch_size = 64
epochs = 30
from keras.optimizers import gradient_descent_v2

weight_decay = 1e-4
model = ResNet32ForCIFAR10(input_shape=(32, 32, 3), classes=num_classes, weight_decay=weight_decay)
sgd = gradient_descent_v2.SGD(lr=0.1, momentum=0.9, nesterov=True)  # SGD with Nesterov momentum
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
model.summary()
Model: "resnet32" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_2 (InputLayer) [(None, 32, 32, 3)] 0 [] conv2d_33 (Conv2D) (None, 32, 32, 16) 432 ['input_2[0][0]'] batch_normalization_33 (BatchN (None, 32, 32, 16) 64 ['conv2d_33[0][0]'] ormalization) activation_31 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_33[0][0]'] conv2d_34 (Conv2D) (None, 32, 32, 16) 2304 ['activation_31[0][0]'] batch_normalization_34 (BatchN (None, 32, 32, 16) 64 ['conv2d_34[0][0]'] ormalization) activation_32 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_34[0][0]'] conv2d_35 (Conv2D) (None, 32, 32, 16) 2304 ['activation_32[0][0]'] batch_normalization_35 (BatchN (None, 32, 32, 16) 64 ['conv2d_35[0][0]'] ormalization) add_15 (Add) (None, 32, 32, 16) 0 ['activation_31[0][0]', 'batch_normalization_35[0][0]'] activation_33 (Activation) (None, 32, 32, 16) 0 ['add_15[0][0]'] conv2d_36 (Conv2D) (None, 32, 32, 16) 2304 ['activation_33[0][0]'] batch_normalization_36 (BatchN (None, 32, 32, 16) 64 ['conv2d_36[0][0]'] ormalization) activation_34 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_36[0][0]'] conv2d_37 (Conv2D) (None, 32, 32, 16) 2304 ['activation_34[0][0]'] batch_normalization_37 (BatchN (None, 32, 32, 16) 64 ['conv2d_37[0][0]'] ormalization) add_16 (Add) (None, 32, 32, 16) 0 ['activation_33[0][0]', 'batch_normalization_37[0][0]'] activation_35 (Activation) (None, 32, 32, 16) 0 ['add_16[0][0]'] conv2d_38 (Conv2D) (None, 32, 32, 16) 2304 ['activation_35[0][0]'] batch_normalization_38 (BatchN (None, 32, 32, 16) 64 ['conv2d_38[0][0]'] ormalization) activation_36 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_38[0][0]'] conv2d_39 (Conv2D) (None, 32, 32, 16) 2304 ['activation_36[0][0]'] batch_normalization_39 (BatchN (None, 32, 32, 16) 64 ['conv2d_39[0][0]'] ormalization) add_17 (Add) (None, 32, 32, 16) 0 ['activation_35[0][0]', 'batch_normalization_39[0][0]'] activation_37 (Activation) (None, 32, 32, 16) 0 ['add_17[0][0]'] conv2d_40 (Conv2D) (None, 32, 32, 16) 2304 ['activation_37[0][0]'] batch_normalization_40 (BatchN (None, 32, 32, 16) 64 ['conv2d_40[0][0]'] ormalization) activation_38 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_40[0][0]'] conv2d_41 (Conv2D) (None, 32, 32, 16) 2304 ['activation_38[0][0]'] batch_normalization_41 (BatchN (None, 32, 32, 16) 64 ['conv2d_41[0][0]'] ormalization) add_18 (Add) (None, 32, 32, 16) 0 ['activation_37[0][0]', 'batch_normalization_41[0][0]'] activation_39 (Activation) (None, 32, 32, 16) 0 ['add_18[0][0]'] conv2d_42 (Conv2D) (None, 32, 32, 16) 2304 ['activation_39[0][0]'] batch_normalization_42 (BatchN (None, 32, 32, 16) 64 ['conv2d_42[0][0]'] ormalization) activation_40 (Activation) (None, 32, 32, 16) 0 ['batch_normalization_42[0][0]'] conv2d_43 (Conv2D) (None, 32, 32, 16) 2304 ['activation_40[0][0]'] batch_normalization_43 (BatchN (None, 32, 32, 16) 64 ['conv2d_43[0][0]'] ormalization) add_19 (Add) (None, 32, 32, 16) 0 ['activation_39[0][0]', 'batch_normalization_43[0][0]'] activation_41 (Activation) (None, 32, 32, 16) 0 ['add_19[0][0]'] conv2d_45 (Conv2D) (None, 16, 16, 32) 4608 ['activation_41[0][0]'] batch_normalization_45 (BatchN (None, 16, 16, 32) 128 ['conv2d_45[0][0]'] ormalization) activation_42 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_45[0][0]'] conv2d_44 (Conv2D) (None, 16, 16, 32) 512 
['activation_41[0][0]'] conv2d_46 (Conv2D) (None, 16, 16, 32) 9216 ['activation_42[0][0]'] batch_normalization_44 (BatchN (None, 16, 16, 32) 128 ['conv2d_44[0][0]'] ormalization) batch_normalization_46 (BatchN (None, 16, 16, 32) 128 ['conv2d_46[0][0]'] ormalization) add_20 (Add) (None, 16, 16, 32) 0 ['batch_normalization_44[0][0]', 'batch_normalization_46[0][0]'] activation_43 (Activation) (None, 16, 16, 32) 0 ['add_20[0][0]'] conv2d_47 (Conv2D) (None, 16, 16, 32) 9216 ['activation_43[0][0]'] batch_normalization_47 (BatchN (None, 16, 16, 32) 128 ['conv2d_47[0][0]'] ormalization) activation_44 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_47[0][0]'] conv2d_48 (Conv2D) (None, 16, 16, 32) 9216 ['activation_44[0][0]'] batch_normalization_48 (BatchN (None, 16, 16, 32) 128 ['conv2d_48[0][0]'] ormalization) add_21 (Add) (None, 16, 16, 32) 0 ['activation_43[0][0]', 'batch_normalization_48[0][0]'] activation_45 (Activation) (None, 16, 16, 32) 0 ['add_21[0][0]'] conv2d_49 (Conv2D) (None, 16, 16, 32) 9216 ['activation_45[0][0]'] batch_normalization_49 (BatchN (None, 16, 16, 32) 128 ['conv2d_49[0][0]'] ormalization) activation_46 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_49[0][0]'] conv2d_50 (Conv2D) (None, 16, 16, 32) 9216 ['activation_46[0][0]'] batch_normalization_50 (BatchN (None, 16, 16, 32) 128 ['conv2d_50[0][0]'] ormalization) add_22 (Add) (None, 16, 16, 32) 0 ['activation_45[0][0]', 'batch_normalization_50[0][0]'] activation_47 (Activation) (None, 16, 16, 32) 0 ['add_22[0][0]'] conv2d_51 (Conv2D) (None, 16, 16, 32) 9216 ['activation_47[0][0]'] batch_normalization_51 (BatchN (None, 16, 16, 32) 128 ['conv2d_51[0][0]'] ormalization) activation_48 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_51[0][0]'] conv2d_52 (Conv2D) (None, 16, 16, 32) 9216 ['activation_48[0][0]'] batch_normalization_52 (BatchN (None, 16, 16, 32) 128 ['conv2d_52[0][0]'] ormalization) add_23 (Add) (None, 16, 16, 32) 0 ['activation_47[0][0]', 'batch_normalization_52[0][0]'] activation_49 (Activation) (None, 16, 16, 32) 0 ['add_23[0][0]'] conv2d_53 (Conv2D) (None, 16, 16, 32) 9216 ['activation_49[0][0]'] batch_normalization_53 (BatchN (None, 16, 16, 32) 128 ['conv2d_53[0][0]'] ormalization) activation_50 (Activation) (None, 16, 16, 32) 0 ['batch_normalization_53[0][0]'] conv2d_54 (Conv2D) (None, 16, 16, 32) 9216 ['activation_50[0][0]'] batch_normalization_54 (BatchN (None, 16, 16, 32) 128 ['conv2d_54[0][0]'] ormalization) add_24 (Add) (None, 16, 16, 32) 0 ['activation_49[0][0]', 'batch_normalization_54[0][0]'] activation_51 (Activation) (None, 16, 16, 32) 0 ['add_24[0][0]'] conv2d_56 (Conv2D) (None, 8, 8, 64) 18432 ['activation_51[0][0]'] batch_normalization_56 (BatchN (None, 8, 8, 64) 256 ['conv2d_56[0][0]'] ormalization) activation_52 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_56[0][0]'] conv2d_55 (Conv2D) (None, 8, 8, 64) 2048 ['activation_51[0][0]'] conv2d_57 (Conv2D) (None, 8, 8, 64) 36864 ['activation_52[0][0]'] batch_normalization_55 (BatchN (None, 8, 8, 64) 256 ['conv2d_55[0][0]'] ormalization) batch_normalization_57 (BatchN (None, 8, 8, 64) 256 ['conv2d_57[0][0]'] ormalization) add_25 (Add) (None, 8, 8, 64) 0 ['batch_normalization_55[0][0]', 'batch_normalization_57[0][0]'] activation_53 (Activation) (None, 8, 8, 64) 0 ['add_25[0][0]'] conv2d_58 (Conv2D) (None, 8, 8, 64) 36864 ['activation_53[0][0]'] batch_normalization_58 (BatchN (None, 8, 8, 64) 256 ['conv2d_58[0][0]'] ormalization) activation_54 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_58[0][0]'] conv2d_59 
(Conv2D) (None, 8, 8, 64) 36864 ['activation_54[0][0]'] batch_normalization_59 (BatchN (None, 8, 8, 64) 256 ['conv2d_59[0][0]'] ormalization) add_26 (Add) (None, 8, 8, 64) 0 ['activation_53[0][0]', 'batch_normalization_59[0][0]'] activation_55 (Activation) (None, 8, 8, 64) 0 ['add_26[0][0]'] conv2d_60 (Conv2D) (None, 8, 8, 64) 36864 ['activation_55[0][0]'] batch_normalization_60 (BatchN (None, 8, 8, 64) 256 ['conv2d_60[0][0]'] ormalization) activation_56 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_60[0][0]'] conv2d_61 (Conv2D) (None, 8, 8, 64) 36864 ['activation_56[0][0]'] batch_normalization_61 (BatchN (None, 8, 8, 64) 256 ['conv2d_61[0][0]'] ormalization) add_27 (Add) (None, 8, 8, 64) 0 ['activation_55[0][0]', 'batch_normalization_61[0][0]'] activation_57 (Activation) (None, 8, 8, 64) 0 ['add_27[0][0]'] conv2d_62 (Conv2D) (None, 8, 8, 64) 36864 ['activation_57[0][0]'] batch_normalization_62 (BatchN (None, 8, 8, 64) 256 ['conv2d_62[0][0]'] ormalization) activation_58 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_62[0][0]'] conv2d_63 (Conv2D) (None, 8, 8, 64) 36864 ['activation_58[0][0]'] batch_normalization_63 (BatchN (None, 8, 8, 64) 256 ['conv2d_63[0][0]'] ormalization) add_28 (Add) (None, 8, 8, 64) 0 ['activation_57[0][0]', 'batch_normalization_63[0][0]'] activation_59 (Activation) (None, 8, 8, 64) 0 ['add_28[0][0]'] conv2d_64 (Conv2D) (None, 8, 8, 64) 36864 ['activation_59[0][0]'] batch_normalization_64 (BatchN (None, 8, 8, 64) 256 ['conv2d_64[0][0]'] ormalization) activation_60 (Activation) (None, 8, 8, 64) 0 ['batch_normalization_64[0][0]'] conv2d_65 (Conv2D) (None, 8, 8, 64) 36864 ['activation_60[0][0]'] batch_normalization_65 (BatchN (None, 8, 8, 64) 256 ['conv2d_65[0][0]'] ormalization) add_29 (Add) (None, 8, 8, 64) 0 ['activation_59[0][0]', 'batch_normalization_65[0][0]'] activation_61 (Activation) (None, 8, 8, 64) 0 ['add_29[0][0]'] average_pooling2d_1 (AveragePo (None, 1, 1, 64) 0 ['activation_61[0][0]'] oling2D) flatten_1 (Flatten) (None, 64) 0 ['average_pooling2d_1[0][0]'] dense_1 (Dense) (None, 10) 650 ['flatten_1[0][0]'] ================================================================================================== Total params: 469,370 Trainable params: 466,906 Non-trainable params: 2,464 __________________________________________________________________________________________________
The code below uses augmentation based on basic image processing. We configure a few transforms, such as rotation_range (the rotation angle), horizontal and vertical shifts of about 0.1 of the image size, and horizontal flipping; all of these are set directly on the ImageDataGenerator.
%%time
if data_augmentation:
    # datagen
    datagen = ImageDataGenerator(
        featurewise_center=False,             # set input mean to 0 over the dataset
        samplewise_center=False,              # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,   # divide each input by its std
        zca_whitening=False,                  # apply ZCA whitening
        rotation_range=15,       # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,   # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,    # randomly flip images
        vertical_flip=False,     # randomly flip images
    )
    # (std, mean, and principal components if ZCA whitening is applied)
    datagen.fit(x_train)
    print('train with data augmentation')
    history = model.fit_generator(generator=datagen.flow(x_train, y_train, batch_size=batch_size),
                                  epochs=epochs,
                                  callbacks=[reduce_lr, checkpoint],
                                  validation_data=(x_val, y_val)
                                  )
else:
    print('train without data augmentation')
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        callbacks=[reduce_lr],
                        validation_data=(x_val, y_val)
                        )
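Note that fit_generator is deprecated in recent Keras releases; model.fit accepts the generator directly with the same arguments, so on newer versions the call above can simply be replaced with model.fit(...).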
Epoch 1/30
782/782 [==============================] - 38s 44ms/step - loss: 1.8886 - accuracy: 0.3671 - val_loss: 1.7345 - val_accuracy: 0.4173 - lr: 0.1000
Epoch 2/30
782/782 [==============================] - 34s 43ms/step - loss: 1.4087 - accuracy: 0.5443 - val_loss: 1.4261 - val_accuracy: 0.5444 - lr: 0.1000
Epoch 3/30
782/782 [==============================] - ETA: 0s - loss: 1.1361 - accuracy: 0.6509
Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.093000001385808.
782/782 [==============================] - 34s 44ms/step - loss: 1.1361 - accuracy: 0.6509 - val_loss: 1.5838 - val_accuracy: 0.5590 - lr: 0.1000
Epoch 4/30
782/782 [==============================] - 34s 44ms/step - loss: 0.9685 - accuracy: 0.7187 - val_loss: 1.0326 - val_accuracy: 0.7028 - lr: 0.0930
Epoch 5/30
781/782 [============================>.] - ETA: 0s - loss: 0.8858 - accuracy: 0.7521
Epoch 00005: ReduceLROnPlateau reducing learning rate to 0.08649000205099583.
782/782 [==============================] - 33s 43ms/step - loss: 0.8859 - accuracy: 0.7520 - val_loss: 1.5229 - val_accuracy: 0.5927 - lr: 0.0930
Epoch 6/30
782/782 [==============================] - ETA: 0s - loss: 0.8307 - accuracy: 0.7734
Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.0804357048869133.
782/782 [==============================] - 34s 44ms/step - loss: 0.8307 - accuracy: 0.7734 - val_loss: 1.0997 - val_accuracy: 0.7028 - lr: 0.0865
Epoch 7/30
782/782 [==============================] - ETA: 0s - loss: 0.7843 - accuracy: 0.7906
Epoch 00007: ReduceLROnPlateau reducing learning rate to 0.07480520859360695.
782/782 [==============================] - 33s 43ms/step - loss: 0.7843 - accuracy: 0.7906 - val_loss: 1.0626 - val_accuracy: 0.7179 - lr: 0.0804
Epoch 8/30
781/782 [============================>.] - ETA: 0s - loss: 0.7455 - accuracy: 0.8035
Epoch 00008: ReduceLROnPlateau reducing learning rate to 0.06956884302198887.
782/782 [==============================] - 33s 43ms/step - loss: 0.7454 - accuracy: 0.8035 - val_loss: 1.3238 - val_accuracy: 0.6693 - lr: 0.0748
Epoch 9/30
782/782 [==============================] - 34s 43ms/step - loss: 0.7134 - accuracy: 0.8177 - val_loss: 0.9464 - val_accuracy: 0.7547 - lr: 0.0696
Epoch 10/30
782/782 [==============================] - 34s 43ms/step - loss: 0.6992 - accuracy: 0.8234 - val_loss: 0.9212 - val_accuracy: 0.7511 - lr: 0.0696
Epoch 11/30
782/782 [==============================] - 34s 43ms/step - loss: 0.6823 - accuracy: 0.8302 - val_loss: 0.8771 - val_accuracy: 0.7730 - lr: 0.0696
Epoch 12/30
781/782 [============================>.] - ETA: 0s - loss: 0.6782 - accuracy: 0.8307
Epoch 00012: ReduceLROnPlateau reducing learning rate to 0.06469902366399766.
782/782 [==============================] - 33s 42ms/step - loss: 0.6782 - accuracy: 0.8307 - val_loss: 0.9172 - val_accuracy: 0.7625 - lr: 0.0696
Epoch 13/30
782/782 [==============================] - 36s 46ms/step - loss: 0.6558 - accuracy: 0.8412 - val_loss: 0.7392 - val_accuracy: 0.8187 - lr: 0.0647
Epoch 14/30
782/782 [==============================] - ETA: 0s - loss: 0.6469 - accuracy: 0.8444
Epoch 00014: ReduceLROnPlateau reducing learning rate to 0.06017009228467941.
782/782 [==============================] - 33s 43ms/step - loss: 0.6469 - accuracy: 0.8444 - val_loss: 0.9919 - val_accuracy: 0.7541 - lr: 0.0647
Epoch 15/30
781/782 [============================>.] - ETA: 0s - loss: 0.6270 - accuracy: 0.8516
Epoch 00015: ReduceLROnPlateau reducing learning rate to 0.05595818527042866.
782/782 [==============================] - 35s 45ms/step - loss: 0.6269 - accuracy: 0.8516 - val_loss: 0.7825 - val_accuracy: 0.8096 - lr: 0.0602
Epoch 16/30
782/782 [==============================] - ETA: 0s - loss: 0.6099 - accuracy: 0.8560
Epoch 00016: ReduceLROnPlateau reducing learning rate to 0.05204111237078905.
782/782 [==============================] - 33s 43ms/step - loss: 0.6099 - accuracy: 0.8560 - val_loss: 0.8544 - val_accuracy: 0.7907 - lr: 0.0560
Epoch 17/30
781/782 [============================>.] - ETA: 0s - loss: 0.5900 - accuracy: 0.8624
Epoch 00017: ReduceLROnPlateau reducing learning rate to 0.04839823544025421.
782/782 [==============================] - 34s 43ms/step - loss: 0.5899 - accuracy: 0.8625 - val_loss: 0.7734 - val_accuracy: 0.8060 - lr: 0.0520
Epoch 18/30
781/782 [============================>.] - ETA: 0s - loss: 0.5732 - accuracy: 0.8675
Epoch 00018: ReduceLROnPlateau reducing learning rate to 0.04501035757362843.
782/782 [==============================] - 34s 43ms/step - loss: 0.5733 - accuracy: 0.8675 - val_loss: 0.9039 - val_accuracy: 0.7739 - lr: 0.0484
Epoch 19/30
782/782 [==============================] - 34s 43ms/step - loss: 0.5543 - accuracy: 0.8740 - val_loss: 0.6435 - val_accuracy: 0.8511 - lr: 0.0450
Epoch 20/30
781/782 [============================>.] - ETA: 0s - loss: 0.5497 - accuracy: 0.8767
Epoch 00020: ReduceLROnPlateau reducing learning rate to 0.04185963302850723.
782/782 [==============================] - 34s 43ms/step - loss: 0.5497 - accuracy: 0.8767 - val_loss: 0.8291 - val_accuracy: 0.7898 - lr: 0.0450
Epoch 21/30
781/782 [============================>.] - ETA: 0s - loss: 0.5368 - accuracy: 0.8797
Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.03892945982515812.
782/782 [==============================] - 33s 43ms/step - loss: 0.5368 - accuracy: 0.8797 - val_loss: 0.6961 - val_accuracy: 0.8392 - lr: 0.0419
Epoch 22/30
781/782 [============================>.] - ETA: 0s - loss: 0.5215 - accuracy: 0.8839
Epoch 00022: ReduceLROnPlateau reducing learning rate to 0.03620439659804106.
782/782 [==============================] - 33s 42ms/step - loss: 0.5214 - accuracy: 0.8839 - val_loss: 0.7055 - val_accuracy: 0.8369 - lr: 0.0389
Epoch 23/30
782/782 [==============================] - ETA: 0s - loss: 0.5085 - accuracy: 0.8888
Epoch 00023: ReduceLROnPlateau reducing learning rate to 0.03367008984088898.
782/782 [==============================] - 34s 44ms/step - loss: 0.5085 - accuracy: 0.8888 - val_loss: 0.7107 - val_accuracy: 0.8380 - lr: 0.0362
Epoch 24/30
782/782 [==============================] - 33s 43ms/step - loss: 0.4883 - accuracy: 0.8942 - val_loss: 0.6060 - val_accuracy: 0.8619 - lr: 0.0337
Epoch 25/30
781/782 [============================>.] - ETA: 0s - loss: 0.4868 - accuracy: 0.8928
Epoch 00025: ReduceLROnPlateau reducing learning rate to 0.03131318382918835.
782/782 [==============================] - 34s 43ms/step - loss: 0.4871 - accuracy: 0.8927 - val_loss: 0.7552 - val_accuracy: 0.8338 - lr: 0.0337
Epoch 26/30
781/782 [============================>.] - ETA: 0s - loss: 0.4749 - accuracy: 0.8970
Epoch 00026: ReduceLROnPlateau reducing learning rate to 0.02912126172333956.
782/782 [==============================] - 34s 44ms/step - loss: 0.4750 - accuracy: 0.8969 - val_loss: 0.7968 - val_accuracy: 0.8160 - lr: 0.0313
Epoch 27/30
782/782 [==============================] - 33s 43ms/step - loss: 0.4645 - accuracy: 0.8998 - val_loss: 0.5906 - val_accuracy: 0.8637 - lr: 0.0291
Epoch 28/30
781/782 [============================>.] - ETA: 0s - loss: 0.4605 - accuracy: 0.9026
Epoch 00028: ReduceLROnPlateau reducing learning rate to 0.027082772813737395.
782/782 [==============================] - 33s 42ms/step - loss: 0.4606 - accuracy: 0.9026 - val_loss: 0.5991 - val_accuracy: 0.8682 - lr: 0.0291
Epoch 29/30
782/782 [==============================] - ETA: 0s - loss: 0.4414 - accuracy: 0.9069
Epoch 00029: ReduceLROnPlateau reducing learning rate to 0.025186978820711376.
782/782 [==============================] - 34s 44ms/step - loss: 0.4414 - accuracy: 0.9069 - val_loss: 0.7109 - val_accuracy: 0.8343 - lr: 0.0271
Epoch 30/30
782/782 [==============================] - ETA: 0s - loss: 0.4339 - accuracy: 0.9096
Epoch 00030: ReduceLROnPlateau reducing learning rate to 0.02342388980090618.
782/782 [==============================] - 33s 43ms/step - loss: 0.4339 - accuracy: 0.9096 - val_loss: 0.6589 - val_accuracy: 0.8427 - lr: 0.0252
CPU times: user 27min 58s, sys: 29.5 s, total: 28min 28s
Wall time: 17min 21s
plot_model_history(history)
(Figure: training and validation accuracy/loss curves plotted by plot_model_history for the data-augmentation run.)
The results show that training with data augmentation beats training without it: validation accuracy reaches 84.27%, and with better-tuned augmentation parameters it could likely go higher. Data augmentation has a substantial effect on the result, and the training curves are also more stable.
loss,acc = model.evaluate(x_val,y_val)
print('evaluate loss:%f acc:%f' % (loss, acc))
313/313 [==============================] - 3s 11ms/step - loss: 0.6589 - accuracy: 0.8427
evaluate loss:0.658943 acc:0.842700