当前位置:   article > 正文

Keras官方文档-迁移学习和微调-Xception_keras迁移模型学习文档

keras迁移模型学习文档

Keras官方文档-迁移学习和微调
github官方模型链接

Keras文档

端到端示例:微调猫狗数据集上的图像分类模型

为了巩固这些概念,让我们为您介绍一个具体的端到端转移学习和微调示例。我们将加载在ImageNet上经过预先训练的Xception模型,并将其用于Kaggle“猫与狗”分类数据集中。

获取数据
首先,让我们使用TFDS来获取“猫与狗”数据集。如果您拥有自己的数据集,则可能要使用该实用程序 tf.keras.preprocessing.image_dataset_from_directory从磁盘上提交到特定于类的文件夹中的一组图像中生成相似的带标签的数据集对象。

当使用非常小的数据集时,转移学习最有用。为了使数据集较小,我们将使用原始训练数据的40%(25,000张图像)进行训练,将10%用于验证,将10%用于测试。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

这些是训练数据集中的前9张图像-如您所见,它们都是不同的大小。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

在这里插入图片描述

我们还可以看到标签1是“ dog”,标签0是“ cat”。

标准化数据

我们的原始图像有各种尺寸。另外,每个像素由0到255之间的3个整数值(RGB级别值)组成。这不太适合提供神经网络。我们需要做两件事:

标准化为固定的图像尺寸。我们选择150x150。
归一化介于-1和1之间的像素值。我们将使用Normalization图层作为模型本身的一部分来进行此操作。
通常,与采用已预处理数据的模型相反,开发以原始数据为输入的模型是一个好习惯。原因是,如果模型需要预处理的数据,则每次导出模型以在其他地方使用它(在Web浏览器,移动应用程序中)时,都需要重新实现完全相同的预处理管道。这很快就变得非常棘手。因此,在达到模型之前,我们应该进行尽可能少的预处理。

在这里,我们将在数据管道中进行图像大小调整(因为深度神经网络只能处理连续的数据批处理),并且在创建模型时将其作为模型的一部分进行输入值缩放。

注意:Xception网络需要送入之前标准化到-1和1之间

让我们将图像尺寸调整为150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
  • 1
  • 2
  • 3
  • 4
  • 5

此外,让我们分批处理数据并使用缓存和预取来优化加载速度。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
  • 1
  • 2
  • 3
  • 4
  • 5

使用随机数据扩充data_augmentation 序列模型层

当您没有大型图像数据集时,通过对训练图像进行随机但逼真的变换(例如随机水平翻转或小的随机旋转)来人为引入样本多样性是一种很好的做法。这有助于使模型暴露于训练数据的不同方面,同时减慢过度拟合的速度。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

让我们直观地看到经过各种随机转换后的第一批图像:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[i]))
        plt.axis("off")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

在这里插入图片描述

建立模型

现在让我们建立一个遵循我们先前解释的蓝图的模型。

注意:

我们添加一个Normalization图层以将输入值(最初在该[0, 255] 范围内)缩放到该[-1, 1]范围。
我们Dropout在分类层之前添加一层以进行正则化。
我们确保training=False在调用基本模型时通过,以便它以推理模式运行,这样即使我们取消冻结基本模型以进行微调,batchnorm统计信息也不会得到更新。

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmenta
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/865682
推荐阅读
相关标签
  

闽ICP备14008679号