当前位置:   article > 正文

TensorFlow Hub迁移学习_feature_extractor_layer = tf.keras.sequential([hub

feature_extractor_layer = tf.keras.sequential([hub.keraslayer(feature_extrac

参考:https://tensorflow.google.cn/tutorials/images/transfer_learning_with_hub

安装模块

import matplotlib.pylab as plt

!pip install -q tf-nightly
import tensorflow as tf

  • 1
  • 2
  • 3
  • 4
  • 5

如果报错:ERROR: tensorflow 2.1.0 has requirement gast==0.2.2, but you’ll have gast 0.3.3

!pip install -q -U tf-hub-nightly
!pip install -q tfds-nightly
import tensorflow_hub as hub

from tensorflow.keras import layers
  • 1
  • 2
  • 3
  • 4
  • 5

下载图像分类模型

任何来自hub.tensorflow.google.cn的兼容于TensorFlow 2的图像分类器URL都可以运行。

classifier_url ="https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}

IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))
])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

单张图片试运行

import numpy as np
import PIL.Image as Image
# 下载图片
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
# 打开图片
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
# 查看图片
grace_hopper

# 图片转换为0-1的值
grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape

# 添加批处理维度,并将图像传递给模型。
result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
# 结果是一个logits的1001元素向量,对图像的每个类的概率进行评级。
# 查找概率最高的项id
predicted_class = np.argmax(result[0], axis=-1)
predicted_class
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

解码预测

根据id对应标签

# 下载读取标签
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
# 图片与标签对应并显示
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

迁移训练

下载数据集

下载花朵的数据集

data_root = tf.keras.utils.get_file(
  'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
  • 1
  • 2
  • 3

使用加载图片数据tf.keras.preprocessing.image.ImageDataGenerator

TensorFlow Hub的所有图像模型输入格式为[0,1]。使用ImageDataGenerator的rescale参数进行转换。

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE)

#结果对象是一个迭代器,返回image_batch、label_batch对。
for image_batch, label_batch in image_data:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

运行模型,输出对应标签

# 喂入数据运行分类器,输出标签id
result_batch = classifier.predict(image_batch)
result_batch.shape

# id对应标签,完成预测
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
predicted_class_names
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

查看预测结果


# 查看打印结果
 plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

无顶层分类层

模型地址

feature_extractor_url = "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/feature_vector/2" #@param {type:"string"}
  • 1

创建特征提取器

feature_extractor_layer = hub.KerasLayer(feature_extractor_url,
                                         input_shape=(224,224,3))

feature_batch = feature_extractor_layer(image_batch)
#(32, 1280) 它为每张图像返回一个长度为1280的向量:
print(feature_batch.shape)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

冻结特征提取器层中的变量,使训练只修改新的分类器层。

feature_extractor_layer.trainable = False
  • 1

添加新的分类层

使用tf.keras.Sequential(),添加新的分类层

model = tf.keras.Sequential([
  feature_extractor_layer,
  layers.Dense(image_data.num_classes)
])

model.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

模型预测结果形状

predictions = model(image_batch)
predictions.shape
  • 1
  • 2

模型训练

配置训练过程,优化器,损失函数等

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
  metrics=['acc'])
  • 1
  • 2
  • 3
  • 4

为了可视化训练过程,使用自定义回调来分别记录每个批处理的损失和准确度。

class CollectBatchStats(tf.keras.callbacks.Callback):
  def __init__(self):
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)

batch_stats_callback = CollectBatchStats()

history = model.fit_generator(image_data, epochs=2,
                              steps_per_epoch=steps_per_epoch,
                              callbacks = [batch_stats_callback])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

打印loss、acc值

plt.figure()
plt.ylabel("Loss")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(batch_stats_callback.batch_losses)

plt.figure()
plt.ylabel("Accuracy")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(batch_stats_callback.batch_acc)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

定义有序列表

class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
class_names
  • 1
  • 2
  • 3

通过模型运行图像批处理并将索引转换为类名

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
  • 1
  • 2
  • 3

打印结果

label_id = np.argmax(label_batch, axis=-1)

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  color = "green" if predicted_id[n] == label_id[n] else "red"
  plt.title(predicted_label_batch[n].title(), color=color)
  plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

保存模型

import time
t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path, save_format='tf')

export_path
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

加载模型

reloaded = tf.keras.models.load_model(export_path)

result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)

abs(reloaded_result_batch - result_batch).max()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/532007
推荐阅读
  

闽ICP备14008679号