2. Dataset and transfer-learning requirements
The dataset consists of images from 5 classes in a particular scenario; the task is to classify them.
We fine-tune an existing pre-trained VGG model for this.
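At a high level, the recipe developed below is: load the ImageNet-pretrained VGG16 without its three fully connected layers (include_top=False), freeze its convolutional blocks, append a global average pooling layer plus a Dense(1024, relu) layer and a Dense(5, softmax) output, compile with sparse categorical cross-entropy, and train only the new head on our images. As a starting point, here is the standard Keras ImageDataGenerator / fit_generator pattern: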
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    'data/validation',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

Then train with fit_generator:

model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50,
    validation_data=validation_generator,
    validation_steps=800)
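flow_from_directory infers the class labels from the sub-directory names, so the data on disk is expected to look roughly like this (class folder names taken from the label mapping used later in the post; file names are illustrative):

data/
    train/
        bus/        100.jpg ...
        dinosaurs/  ...
        eleplants/  ...
        flowers/    ...
        horse/      ...
    test/
        bus/ ...
        ...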

def __init__(self):
    # How training and test images are transformed: normalization (data augmentation could be added here)
    self.train_generator = ImageDataGenerator(rescale=1.0 / 255)
    self.test_generator = ImageDataGenerator(rescale=1.0 / 255)
    # Directories holding the training and test sets
    self.train_dir = "./data/train"
    self.test_dir = "./data/test"
    # Network-related parameters for training
    self.image_size = (224, 224)
    self.batch_size = 32

def get_local_data(self):
    """
    Read the local image data and their classes
    :return: training and test data iterators
    """
    # Use flow_from_directory
    train_gen = self.train_generator.flow_from_directory(self.train_dir,
                                                         target_size=self.image_size,
                                                         batch_size=self.batch_size,
                                                         class_mode='binary',
                                                         shuffle=True)
    test_gen = self.test_generator.flow_from_directory(self.test_dir,
                                                       target_size=self.image_size,
                                                       batch_size=self.batch_size,
                                                       class_mode='binary',
                                                       shuffle=True)
    return train_gen, test_gen
# Base model for transfer learning:
# load VGG16 without its 3 fully connected layers, with pre-trained ImageNet weights
self.base_model = VGG16(weights='imagenet', include_top=False)
The relevant excerpt from the VGG16 model source:
if include_top:
    # Classification block
    x = layers.Flatten(name='flatten')(x)
    x = layers.Dense(4096, activation='relu', name='fc1')(x)
    x = layers.Dense(4096, activation='relu', name='fc2')(x)
    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
else:
    if pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)
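As this excerpt shows, when include_top=False the pooling argument decides whether VGG16 itself appends a global pooling layer. Passing pooling='avg' at construction time is therefore an alternative to adding GlobalAveragePooling2D manually in the refine step below; the post adds it explicitly, which amounts to the same computation. A minimal sketch of the alternative:

# Alternative: let VGG16 apply global average pooling itself,
# so the base model's output already has shape (None, 512)
base = VGG16(weights='imagenet', include_top=False, pooling='avg')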
def refine_vgg_model(self):
    """
    Append the new fully connected head
    :return: the new transfer model
    """
    # [<tf.Tensor 'block5_pool/MaxPool:0' shape=(?, ?, ?, 512) dtype=float32>]
    x = self.base_model.outputs[0]
    # Feed into the fully connected layers, with global pooling first: [None, ?, ?, 512] ---> [None, 1 * 512]
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
    y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x)

    model = keras.Model(inputs=self.base_model.inputs, outputs=y_predict)

    return model
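A quick sanity check of the refined model (a sketch; it assumes the method above is attached to the TransferModel class shown earlier — in the complete code at the end the same method is named refine_base_model). The output shape should end in 5, one unit per class:

tm = TransferModel()
model = tm.refine_vgg_model()   # refine_base_model() in the complete code
print(model.output_shape)       # expected: (None, 5)
model.summary()                 # the new GAP + Dense layers appear after block5_pool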
The signature of keras.callbacks.ModelCheckpoint:
keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
modelckpt = keras.callbacks.ModelCheckpoint('./ckpt/fine_tuning/weights.{epoch:02d}.hdf5',
                                            monitor='val_acc',
                                            save_weights_only=True,
                                            save_best_only=True,
                                            mode='auto',
                                            period=1)

model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt])
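A caveat about the monitored metric name: with tf.keras in TensorFlow 2.x, compiling with metrics=['accuracy'] produces log keys named 'accuracy' and 'val_accuracy', so monitor='val_acc' (the older standalone-Keras name used above) will generally not be found; ModelCheckpoint then warns that the monitored value is unavailable and skips saving the "best" checkpoint. The complete code below monitors 'accuracy' instead.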
def fit_generator(self, model, train_gen, test_gen):
    """
    Train the model with model.fit_generator() rather than model.fit()
    :param model:
    :param train_gen:
    :param test_gen:
    :return:
    """
    # Record an .h5 weights file per epoch, tagged with the accuracy
    modelckpt = keras.callbacks.ModelCheckpoint('./hdf5/transfer_{epoch:02d}-{accuracy:.2f}.h5',
                                                monitor='accuracy',
                                                save_weights_only=True,
                                                save_best_only=True,
                                                mode='auto',
                                                period=1)
    model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt])
    # Save the model
    # model.save_weights("./model/Transfer.h5")
    return None
fit_generator() accepts the ModelCheckpoint callback via its callbacks argument.
if __name__ == "__main__":
    tm = TransferModel()
    train_gen, test_gen = tm.get_local_data()
    model = tm.refine_base_model()
    tm.freeze_model()
    tm.compile(model)
    tm.fit_generator(model, train_gen, test_gen)
def predict(self, model):
    '''
    Predict the class of an image
    :param model:
    :return:
    '''
    # Load the trained transfer_model weights
    model.load_weights("./hdf5/transfer_03-0.98.h5")
    # Read and preprocess the image
    image = load_img("./data/test/dinosaurs/400.jpg", target_size=(224, 224))
    image = img_to_array(image)
    # Convert 3-D to 4-D: (224, 224, 3) ---> (1, 224, 224, 3)
    img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
    print(img)
    # Preprocess the input and predict
    image = preprocess_input(img)
    predictions = model.predict(image)
    print(predictions)
    res = np.argmax(predictions, axis=1)
    print(self.label_dict[str(res[0])])
Complete code:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input

tf.compat.v1.disable_eager_execution()


class TransferModel(object):

    def __init__(self):
        # How training and test images are transformed: normalization (data augmentation could be added here)
        self.train_generator = ImageDataGenerator(rescale=1.0 / 255)
        self.test_generator = ImageDataGenerator(rescale=1.0 / 255)
        # Directories holding the training and test sets
        self.train_dir = "./data/train"
        self.test_dir = "./data/test"
        # Network-related parameters for training
        self.image_size = (224, 224)
        self.batch_size = 32
        # Base model for transfer learning:
        # VGG16 without its 3 fully connected layers, loaded with pre-trained ImageNet weights
        self.base_model = VGG16(weights='imagenet', include_top=False)
        self.label_dict = {
            '0': 'bus',
            '1': 'dinosaurs',
            '2': 'eleplants',
            '3': 'flowers',
            '4': 'horse'
        }

    def get_local_data(self):
        """
        Read the local image data and their classes
        :return: training and test data iterators
        """
        # Use flow_from_directory
        train_gen = self.train_generator.flow_from_directory(self.train_dir,
                                                             target_size=self.image_size,
                                                             batch_size=self.batch_size,
                                                             class_mode='binary',
                                                             shuffle=True)
        test_gen = self.test_generator.flow_from_directory(self.test_dir,
                                                           target_size=self.image_size,
                                                           batch_size=self.batch_size,
                                                           class_mode='binary',
                                                           shuffle=True)
        return train_gen, test_gen

    def refine_base_model(self):
        """
        Fine-tune the VGG structure: after the 5 conv blocks, add global average pooling
        (to reduce the number of transferred parameters) plus two fully connected layers
        :return:
        """
        # 1. Take the output of the original notop model
        # [?, ?, ?, 512]
        x = self.base_model.outputs[0]
        # 2. Append our own structure to that output
        # [?, ?, ?, 512] ---> [?, 1*1*512]
        x = keras.layers.GlobalAveragePooling2D()(x)
        # 3. Define the new transfer model
        x = keras.layers.Dense(1024, activation=tf.compat.v1.nn.relu)(x)
        y_predict = keras.layers.Dense(5, activation=tf.compat.v1.nn.softmax)(x)
        # New model: VGG's inputs, our y_predict as output
        transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict)
        return transfer_model

    def freeze_model(self):
        """
        Freeze the VGG model (the 5 conv blocks)
        How much of VGG to freeze depends on the amount of data
        :return:
        """
        # Iterate over all layers (returned as a list)
        for layer in self.base_model.layers:
            layer.trainable = False

    def compile(self, model):
        '''
        Compile the model
        :return:
        '''
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.sparse_categorical_crossentropy,
                      metrics=['accuracy'])
        return None

    def fit_generator(self, model, train_gen, test_gen):
        """
        Train the model with model.fit_generator() rather than model.fit()
        :param model:
        :param train_gen:
        :param test_gen:
        :return:
        """
        # Record an .h5 weights file per epoch, tagged with the accuracy
        modelckpt = keras.callbacks.ModelCheckpoint('./hdf5/transfer_{epoch:02d}-{accuracy:.2f}.h5',
                                                    monitor='accuracy',
                                                    save_weights_only=True,
                                                    save_best_only=True,
                                                    mode='auto',
                                                    period=1)
        model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt])
        # Save the model
        # model.save_weights("./model/Transfer.h5")
        return None

    def predict(self, model):
        '''
        Predict the class of an image
        :param model:
        :return:
        '''
        # Load the trained transfer_model weights
        model.load_weights("./hdf5/transfer_03-0.98.h5")
        # Read and preprocess the image
        image = load_img("./data/test/dinosaurs/400.jpg", target_size=(224, 224))
        image = img_to_array(image)
        # Convert 3-D to 4-D: (224, 224, 3) ---> (1, 224, 224, 3)
        img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
        print(img)
        # Preprocess the input and predict
        image = preprocess_input(img)
        predictions = model.predict(image)
        print(predictions)
        res = np.argmax(predictions, axis=1)
        print(self.label_dict[str(res[0])])


if __name__ == "__main__":
    tm = TransferModel()
    # Training
    # train_gen, test_gen = tm.get_local_data()
    # print(train_gen, test_gen)
    # print(tm.base_model.summary())
    # model = tm.refine_base_model()
    # print(model)
    # tm.freeze_model()
    # tm.compile(model)
    # tm.fit_generator(model, train_gen, test_gen)
    # Test
    model = tm.refine_base_model()
    tm.predict(model)
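One inconsistency worth noting in the full pipeline: the generators normalize training images with rescale=1/255, while predict() runs the image through VGG16's preprocess_input (channel-wise mean subtraction), so the network sees differently scaled inputs at training and prediction time. Below is a minimal sketch, under the assumption that VGG16-style preprocessing is wanted in both places (preprocessing_function is a standard ImageDataGenerator argument); the other option is to keep rescale=1/255 everywhere and simply divide the test image by 255 in predict() instead of calling preprocess_input.

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import preprocess_input

# Use the same VGG16 preprocessing for the training/test generators
# that predict() applies to a single image
train_generator = ImageDataGenerator(preprocessing_function=preprocess_input)
test_generator = ImageDataGenerator(preprocessing_function=preprocess_input)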
