赞
踩
keras是目前流行的深度学习框架之一,目前已经整合到Tensorflow2.0版本中,用户通过安装Tensorflow包即可实现对Keras方便的调用。
Keras为用户提供了多种深度学习模型调用的接口,用户通过简单的编辑即可实现经典模型的调用和搭建。目前Keras提供的模型接口有如下几个:
from tensorflow.python.keras.applications.densenet import DenseNet121 from tensorflow.python.keras.applications.densenet import DenseNet169 from tensorflow.python.keras.applications.densenet import DenseNet201 from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2 from tensorflow.python.keras.applications.inception_v3 import InceptionV3 from tensorflow.python.keras.applications.mobilenet import MobileNet from tensorflow.python.keras.applications.mobilenet_v2 import MobileNetV2 from tensorflow.python.keras.applications.nasnet import NASNetLarge from tensorflow.python.keras.applications.nasnet import NASNetMobile from tensorflow.python.keras.applications.resnet import ResNet101 from tensorflow.python.keras.applications.resnet import ResNet152 from tensorflow.python.keras.applications.resnet import ResNet50 from tensorflow.python.keras.applications.resnet_v2 import ResNet101V2 from tensorflow.python.keras.applications.resnet_v2 import ResNet152V2 from tensorflow.python.keras.applications.resnet_v2 import ResNet50V2 from tensorflow.python.keras.applications.vgg16 import VGG16 from tensorflow.python.keras.applications.vgg19 import VGG19 from tensorflow.python.keras.applications.xception import Xception
我们以Resnet50为例,从头搭建一个深度学习模型,实现对Cifar10数据集的分类:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
import os
运行此代码会自动下载Cifar10数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255 #预处理
x_test = x_test.astype('float32') / 255
y_train = to_categorical(y_train, 10) #对数据集进行onehot编码
y_test = to_categorical(y_test, 10)
若想使用其他模型,只需要把下述代码中的tf.keras.applications.ResNet50改成tf.keras.applications.其他模型,即可。
def resnet50_model(): #include_top为是否包括原始Resnet50模型的全连接层,如果不需要自己定义可以设为True #不需要预训练模型可以将weights设为None resnet50=tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(32,32,3), ) #设置预训练模型冻结的层,可根据自己的需要自行设置 for layer in resnet50.layers[:15]: layer.trainable = False # #选择模型连接到全连接层的位置 last=resnet50.get_layer(index=30).output #建立新的全连接层 x=tf.keras.layers.Flatten(name='flatten')(last) x=tf.keras.layers.Dense(1024,activation='relu')(x) x=tf.keras.layers.Dropout(0.5)(x) x=tf.keras.layers.Dense(128,activation='relu',name='dense1')(x) x=tf.keras.layers.Dropout(0.5,name='dense_dropout')(x) x=tf.keras.layers.Dense(10,activation='softmax')(x) model = tf.keras.models.Model(inputs=resnet50.input, outputs=x) model.summary() #打印模型结构 return model
由于Cifar10数据集的输入图像较小,为了防止采样过度,我们选取resnet50的前30层,连接全连接层构成新的模型。
也可对Resnet50前几层的池化层进行修改,但是这需要对keras提供的代码进行修改,具体操作如下:
进入resnet.py
找到其中的Resnet函数,找到其中的以下代码:
x = layers.ZeroPadding2D(
padding=((3, 3), (3, 3)), name='conv1_pad')(img_input)
x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)
if not preact:
x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, name='conv1_bn')(x)
x = layers.Activation('relu', name='conv1_relu')(x)
x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)
依次为头层padding、卷积层、padding层、池化层,可以根据需要进行修改。
选用SGD梯度下降优化器进行训练过程的优化
model=resnet50_model()
model.compile(
loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.SGD(lr=0.1, decay=1e-4, momentum=0.9, nesterov=True),
metrics=['accuracy'])
checkpoint_save_path = "./checkpoint/resnet50_cifar.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-------------------')
model.load_weights(checkpoint_save_path)
checkpointer = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_save_path, verbose=1, save_best_only=True,save_weights_only = True)
运行200个epochs,batch_size设为64
model.fit(
x=x_train,
y=y_train,
batch_size=64,
epochs=200,
verbose=1,
#callbacks=[checkpointer],
validation_split=0.1,
shuffle=True)
至此,完整的训练流程就构建完了,如果想对训练好的模型进行预测的话,可通过以下代码执行:
preds=model.predict(x_test)
可以通过进行数据增强、优化学习率等来进一步的改善模型的性能,将model.fit替换为如下代码,可实现相关功能:
#设置学习率,从0.1开始,每5次准确率不上升,降低一半学习率,最小下降到1e-20
lr_reducer = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=5,
mode='max', min_lr=1e-20)
#进行图像增强
aug = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2,
horizontal_flip=True, zoom_range=0.2)
aug.fit(x_train)
gen = aug.flow(x_train, y_train, batch_size=32)
model.fit_generator(generator=gen, epochs=200, validation_data=(x_test, y_test),
verbose=1,callbacks=[lr_reducer,checkpointer])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。