当前位置:   article > 正文

TensorFlow 使用预训练模型 ResNet-50_resnet50如何集成到tensorflow

resnet50如何集成到tensorflow

 

 

        升级版见:TensorFlow 使用 tf.estimator 训练模型(预训练 ResNet-50)

        前面的文章已经说明了怎么使用 TensorFlow 来构建、训练、保存、导出模型等,现在来说明怎么使用 TensorFlow 调用预训练模型来精调神经网络。为了简单起见,以调用预训练的 ResNet-50 用于图像分类为例,使用的模块仍然是 tf.contrib.slim

        TensorFlow 的所有用于图像分类的预训练模型的下载地址为 models/research/slim,包含常用的 VGG,Inception,ResNet,MobileNet 以及最新的 NasNet 模型等。要使用这些预训练模型的关键是将这些预训练的参数正确的导入到定义好的神经网络,这可以通过函数 slim.assign_from_checkpoint_fn 来方便的实现。下面,用代码来说明。

        所有代码见 GitHub/finetune_classification

一、Fine tuning 模型定义

        前已提及,TensorFlow 所有预训练模型均在 GitHub 项目 models/research/slim,而其对应的神经网络实现则在其子文件夹 nets。我们以调用 ResNet-50 为例(其它模型类似),首先来定义网络结构:

  1. import tensorflow as tf
  2. from tensorflow.contrib.slim import nets
  3. slim = tf.contrib.slim
  4. def predict(self, preprocessed_inputs):
  5. """Predict prediction tensors from inputs tensor.
  6. Outputs of this function can be passed to loss or postprocess functions.
  7. Args:
  8. preprocessed_inputs: A float32 tensor with shape [batch_size,
  9. height, width, num_channels] representing a batch of images.
  10. Returns:
  11. prediction_dict: A dictionary holding prediction tensors to be
  12. passed to the Loss or Postprocess functions.
  13. """
  14. net, endpoints = nets.resnet_v1.resnet_v1_50(
  15. preprocessed_inputs, num_classes=None,
  16. is_training=self._is_training)
  17. net = tf.squeeze(net, axis=[1, 2])
  18. net = slim.fully_connected(net, num_outputs=self.num_classes,
  19. activation_fn=None, scope='Predict')
  20. prediction_dict = {'logits': net}
  21. return prediction_dict

        我们假设要分类的图像有 self.num_classes 个类,随机选择一个批量的图像,对这些图像进行预处理后,把它们作为参数传入 predict 函数,此时直接调用 TensorFlow-slim 封装好的 nets.resnet_v1.resnet_v1_50 神经网络得到图像特征,因为 ResNet-50是用于 1000 个类的分类的,所以需要设置参数 num_classes=None 禁用它的最后一个输出层。我们假设输入的图像批量形状为 [None, 224, 224, 3],则 resnet_v1_50 函数返回的形状为 [None, 1, 1, 2048],为了输入到全连接层,需要用函数 tf.squeeze 去掉形状为 1 的第 1,2 个索引维度。最后,连接再一个全连接层得到 self.num_classes 个类的预测输出。

        可以看到,使用 tf.contrib.slim 模块,调用 ResNet-50 等神经网络变得异常简单。而接下来的关键问题是怎么导入预训练的参数,进而使用我们自己的数据来对预训练模型进行精调。在阐述怎么解决这个问题之前,先将整个模型定义的文件 model.py 列出以方便阅读:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Oct 11 17:21:12 2018
  4. @author: shirhe-lyh
  5. """
  6. import tensorflow as tf
  7. from tensorflow.contrib.slim import nets
  8. import preprocessing
  9. slim = tf.contrib.slim
  10. class Model(object):
  11. """xxx definition."""
  12. def __init__(self, num_classes, is_training,
  13. fixed_resize_side=368,
  14. default_image_size=336):
  15. """Constructor.
  16. Args:
  17. is_training: A boolean indicating whether the training version of
  18. computation graph should be constructed.
  19. num_classes: Number of classes.
  20. """
  21. self._num_classes = num_classes
  22. self._is_training = is_training
  23. self._fixed_resize_side = fixed_resize_side
  24. self._default_image_size = default_image_size
  25. @property
  26. def num_classes(self):
  27. return self._num_classes
  28. def preprocess(self, inputs):
  29. """preprocessing.
  30. Outputs of this function can be passed to loss or postprocess functions.
  31. Args:
  32. preprocessed_inputs: A float32 tensor with shape [batch_size,
  33. height, width, num_channels] representing a batch of images.
  34. Returns:
  35. prediction_dict: A dictionary holding prediction tensors to be
  36. passed to the Loss or Postprocess functions.
  37. """
  38. preprocessed_inputs = preprocessing.preprocess_images(
  39. inputs, self._default_image_size, self._default_image_size,
  40. resize_side_min=self._fixed_resize_side,
  41. is_training=self._is_training,
  42. border_expand=True, normalize=False,
  43. preserving_aspect_ratio_resize=False)
  44. preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
  45. return preprocessed_inputs
  46. def predict(self, preprocessed_inputs):
  47. """Predict prediction tensors from inputs tensor.
  48. Outputs of this function can be passed to loss or postprocess functions.
  49. Args:
  50. preprocessed_inputs: A float32 tensor with shape [batch_size,
  51. height, width, num_channels] representing a batch of images.
  52. Returns:
  53. prediction_dict: A dictionary holding prediction tensors to be
  54. passed to the Loss or Postprocess functions.
  55. """
  56. with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
  57. net, endpoints = nets.resnet_v1.resnet_v1_50(
  58. preprocessed_inputs, num_classes=None,
  59. is_training=self._is_training)
  60. net = tf.squeeze(net, axis=[1, 2])
  61. logits = slim.fully_connected(net, num_outputs=self.num_classes,
  62. activation_fn=None, scope='Predict')
  63. prediction_dict = {'logits': logits}
  64. return prediction_dict
  65. def postprocess(self, prediction_dict):
  66. """Convert predicted output tensors to final forms.
  67. Args:
  68. prediction_dict: A dictionary holding prediction tensors.
  69. **params: Additional keyword arguments for specific implementations
  70. of specified models.
  71. Returns:
  72. A dictionary containing the postprocessed results.
  73. """
  74. logits = prediction_dict['logits']
  75. logits = tf.nn.softmax(logits)
  76. classes = tf.argmax(logits, axis=1)
  77. postprocessed_dict = {'logits': logits,
  78. 'classes': classes}
  79. return postprocessed_dict
  80. def loss(self, prediction_dict, groundtruth_lists):
  81. """Compute scalar loss tensors with respect to provided groundtruth.
  82. Args:
  83. prediction_dict: A dictionary holding prediction tensors.
  84. groundtruth_lists_dict: A dict of tensors holding groundtruth
  85. information, with one entry for each image in the batch.
  86. Returns:
  87. A dictionary mapping strings (loss names) to scalar tensors
  88. representing loss values.
  89. """
  90. logits = prediction_dict['logits']
  91. slim.losses.sparse_softmax_cross_entropy(
  92. logits=logits,
  93. labels=groundtruth_lists,
  94. scope='Loss')
  95. loss = slim.losses.get_total_loss()
  96. loss_dict = {'loss': loss}
  97. return loss_dict
  98. def accuracy(self, postprocessed_dict, groundtruth_lists):
  99. """Calculate accuracy.
  100. Args:
  101. postprocessed_dict: A dictionary containing the postprocessed
  102. results
  103. groundtruth_lists: A dict of tensors holding groundtruth
  104. information, with one entry for each image in the batch.
  105. Returns:
  106. accuracy: The scalar accuracy.
  107. """
  108. classes = postprocessed_dict['classes']
  109. accuracy = tf.reduce_mean(
  110. tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
  111. return accuracy

二、预训练模型导入

        要将预训练模型 ResNet-50 的参数导入到前面定义好的模型,需要继续借助 tf.contrib.slim 模块,而且方法很简单,只需要在训练函数 slim.learning.train 中指定初始化参数来源函数 init_fn 即可,而这可以通过函数

  1. slim.assign_from_checkpoint_fn(model_path, var_list,
  2. ignore_missing_vars=False,
  3. reshape_variables=False)

很方便的实现。其中,第一个参数 model_path 指定预训练模型 xxx.ckpt 文件的路径,第二个参数 var_list 指定需要导入对应预训练参数的所有变量,通过函数

  1. slim.get_variables_to_restore(include=None,
  2. exclude=None)

可以快速指定,如果需要排除一些变量,也就是如果想让某些变量随机初始化而不是直接使用预训练模型来初始化,则直接在参数 exclude 中指定即可。第三个参数ignore_missing_vars 非常重要,一定要将其设置为 True,也就是说,一定要忽略那些在定义的模型结构中可能存在的而在预训练模型中没有的变量,因为如果自己定义的模型结构中存在一个参数,而这些参数在预训练模型文件 xxx.ckpt 中没有,那么如果不忽略的话,就会导入失败(这样的变量很多,比如卷积层的偏置项 bias,一般预训练模型中没有,所以需要忽略,即使用默认的零初始化)。最后一个参数 reshape_variabels指定对某些变量进行变形,这个一般用不到,使用默认的 False 即可。

        有了以上的基础,而且你还阅读过上一篇文章 TensorFlow-slim 训练 CNN 分类模型(续) 的话,那么整个使用预训练模型的训练文件 train.py 就很容易写出了,如下(重点在函数 get_init_fn):

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Oct 11 17:21:35 2018
  4. @author: shirhe-lyh
  5. """
  6. """Train a CNN classification model via pretrained ResNet-50 model.
  7. Example Usage:
  8. ---------------
  9. python3 train.py \
  10. --checkpoint_path: Path to pretrained ResNet-50 model.
  11. --record_path: Path to training tfrecord file.
  12. --logdir: Path to log directory.
  13. """
  14. import os
  15. import tensorflow as tf
  16. import model
  17. import preprocessing
  18. slim = tf.contrib.slim
  19. flags = tf.app.flags
  20. flags.DEFINE_string('record_path',
  21. '/data2/raycloud/jingxiong_datasets/AIChanllenger/' +
  22. 'AgriculturalDisease_trainingset/train.record',
  23. 'Path to training tfrecord file.')
  24. flags.DEFINE_string('checkpoint_path',
  25. '/home/jingxiong/python_project/model_zoo/' +
  26. 'resnet_v1_50.ckpt',
  27. 'Path to pretrained ResNet-50 model.')
  28. flags.DEFINE_string('logdir', './training', 'Path to log directory.')
  29. flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
  30. flags.DEFINE_float(
  31. 'learning_rate_decay_factor', 0.1, 'Learning rate decay factor.')
  32. flags.DEFINE_float(
  33. 'num_epochs_per_decay', 3.0,
  34. 'Number of epochs after which learning rate decays. Note: this flag counts '
  35. 'epochs per clone but aggregates per sync replicas. So 1.0 means that '
  36. 'each clone will go over full epoch individually, but replicas will go '
  37. 'once across all replicas.')
  38. flags.DEFINE_integer('num_samples', 32739, 'Number of samples.')
  39. flags.DEFINE_integer('num_steps', 10000, 'Number of steps.')
  40. flags.DEFINE_integer('batch_size', 48, 'Batch size')
  41. FLAGS = flags.FLAGS
  42. def get_record_dataset(record_path,
  43. reader=None,
  44. num_samples=50000,
  45. num_classes=7):
  46. """Get a tensorflow record file.
  47. Args:
  48. """
  49. if not reader:
  50. reader = tf.TFRecordReader
  51. keys_to_features = {
  52. 'image/encoded':
  53. tf.FixedLenFeature((), tf.string, default_value=''),
  54. 'image/format':
  55. tf.FixedLenFeature((), tf.string, default_value='jpeg'),
  56. 'image/class/label':
  57. tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
  58. dtype=tf.int64))}
  59. items_to_handlers = {
  60. 'image': slim.tfexample_decoder.Image(image_key='image/encoded',
  61. format_key='image/format'),
  62. 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
  63. decoder = slim.tfexample_decoder.TFExampleDecoder(
  64. keys_to_features, items_to_handlers)
  65. labels_to_names = None
  66. items_to_descriptions = {
  67. 'image': 'An image with shape image_shape.',
  68. 'label': 'A single integer.'}
  69. return slim.dataset.Dataset(
  70. data_sources=record_path,
  71. reader=reader,
  72. decoder=decoder,
  73. num_samples=num_samples,
  74. num_classes=num_classes,
  75. items_to_descriptions=items_to_descriptions,
  76. labels_to_names=labels_to_names)
  77. def configure_learning_rate(num_samples_per_epoch, global_step):
  78. """Configures the learning rate.
  79. Modified from:
  80. https://github.com/tensorflow/models/blob/master/research/slim/
  81. train_image_classifier.py
  82. Args:
  83. num_samples_per_epoch: he number of samples in each epoch of training.
  84. global_step: The global_step tensor.
  85. Returns:
  86. A `Tensor` representing the learning rate.
  87. """
  88. decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
  89. FLAGS.batch_size)
  90. return tf.train.exponential_decay(FLAGS.learning_rate,
  91. global_step,
  92. decay_steps,
  93. FLAGS.learning_rate_decay_factor,
  94. staircase=True,
  95. name='exponential_decay_learning_rate')
  96. def get_init_fn():
  97. """Returns a function run by che chief worker to warm-start the training.
  98. Modified from:
  99. https://github.com/tensorflow/models/blob/master/research/slim/
  100. train_image_classifier.py
  101. Note that the init_fn is only run when initializing the model during the
  102. very first global step.
  103. Returns:
  104. An init function run by the supervisor.
  105. """
  106. if FLAGS.checkpoint_path is None:
  107. return None
  108. # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  109. # ignoring the checkpoint anyway.
  110. if tf.train.latest_checkpoint(FLAGS.logdir):
  111. tf.logging.info(
  112. 'Ignoring --checkpoint_path because a checkpoint already exists ' +
  113. 'in %s' % FLAGS.logdir)
  114. return None
  115. if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
  116. checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  117. else:
  118. checkpoint_path = FLAGS.checkpoint_path
  119. tf.logging.info('Fine-tuning from %s' % checkpoint_path)
  120. variables_to_restore = slim.get_variables_to_restore()
  121. return slim.assign_from_checkpoint_fn(
  122. checkpoint_path,
  123. variables_to_restore,
  124. ignore_missing_vars=True)
  125. def get_trainable_variables(checkpoint_exclude_scopes=None):
  126. """Return the trainable variables.
  127. Args:
  128. checkpoint_exclude_scopes: Comma-separated list of scopes of variables
  129. to exclude when restoring from a checkpoint.
  130. Returns:
  131. The trainable variables.
  132. """
  133. exclusions = []
  134. if checkpoint_exclude_scopes:
  135. exclusions = [scope.strip() for scope in
  136. checkpoint_exclude_scopes.split(',')]
  137. variables_to_train = []
  138. for var in tf.trainable_variables():
  139. excluded = False
  140. for exclusion in exclusions:
  141. if var.op.name.startswith(exclusion):
  142. excluded = True
  143. if not excluded:
  144. variables_to_train.append(var)
  145. return variables_to_train
  146. def main(_):
  147. # Specify which gpu to be used
  148. os.environ["CUDA_VISIBLE_DEVICES"] = '1'
  149. num_samples = FLAGS.num_samples
  150. dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples,
  151. num_classes=61)
  152. data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
  153. image, label = data_provider.get(['image', 'label'])
  154. # Border expand and resize
  155. image = preprocessing.border_expand(image, resize=True, output_height=368,
  156. output_width=368)
  157. inputs, labels = tf.train.batch([image, label],
  158. batch_size=FLAGS.batch_size,
  159. #capacity=5*FLAGS.batch_size,
  160. allow_smaller_final_batch=True)
  161. cls_model = model.Model(is_training=True, num_classes=61)
  162. preprocessed_inputs = cls_model.preprocess(inputs)
  163. prediction_dict = cls_model.predict(preprocessed_inputs)
  164. loss_dict = cls_model.loss(prediction_dict, labels)
  165. loss = loss_dict['loss']
  166. postprocessed_dict = cls_model.postprocess(prediction_dict)
  167. acc = cls_model.accuracy(postprocessed_dict, labels)
  168. tf.summary.scalar('loss', loss)
  169. tf.summary.scalar('accuracy', acc)
  170. global_step = slim.create_global_step()
  171. learning_rate = configure_learning_rate(num_samples, global_step)
  172. optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
  173. momentum=0.9)
  174. # optimizer = tf.train.AdamOptimizer(learning_rate=0.00001)
  175. vars_to_train = get_trainable_variables()
  176. train_op = slim.learning.create_train_op(loss, optimizer,
  177. summarize_gradients=True,
  178. variables_to_train=vars_to_train)
  179. tf.summary.scalar('learning_rate', learning_rate)
  180. init_fn = get_init_fn()
  181. slim.learning.train(train_op=train_op, logdir=FLAGS.logdir,
  182. init_fn=init_fn, number_of_steps=FLAGS.num_steps,
  183. save_summaries_secs=20,
  184. save_interval_secs=600)
  185. if __name__ == '__main__':
  186. tf.app.run()

        函数 get_init_fn 从指定路径下读取预训练模型。如果没有指定预训练模型路径(FLAGS.checkpoint_path),则返回 None(表示随机初始化参数)。如果在训练路径下(FLAGS.logdir)已经保存过训练后的模型,也返回 None(即忽略预训练模型参数,而使用最后训练保存下来的模型初始化参数)。函数 get_trainable_variables 的作用是获取需要训练的变量,它默认返回所有可训练的变量。当你需要冻结一些层,让这些层的参数不更新时,通过参数 checkpoint_exclude_scopes 指定,比如我想让 ResNet-50 的 block1 和 block2/unit_1 冻结时,通过:

  1. scopes_to_freeze = 'resnet_v1_50/block1,resnet_v1_50/block2/unit_1'
  2. vars_to_train = get_trainable_variables(scopes_to_freeze )

调用即可。

三、数据集以及训练

        本文 GitHub/finetune_classification 上的代码默认使用 AI Challenger 全球AI挑战赛/农作物病害检测 数据集。下载好数据集之后,执行如下指令:

  1. $ python3 generate_tfrecord.py \
  2. --images_dir Path/to/AgriculturalDisease_trainingset/images \
  3. --annotation_path Path/to/AgriculturalDisease_train_annotations.json \
  4. --output_path Path/to/train.record

将训练集图像写入到 train.record 文件中。之后,执行:

  1. $ python3 train.py \
  2. --record_path Path/to/train.record \
  3. --checkpoint_path Path/to/pretrained_ResNet-50_model/resnet_v1_50.ckpt

开始训练。训练开始之后,会在当前 train.py 路径下生成一个文件夹 training 用来保存训练模型。需要额外说明的是,训练过程不会在终端输出准确率、损失等数据,需要在终端执行:

$ tensorboard --logdir Path/to/training

之后,打开返回的 http 链接在浏览器查看准确率、损失等训练曲线(训练过程中,训练结束后都可查看)。训练正常启动后,每 10 分钟会保存一次模型到 training 文件夹(诸如 model.ckpt-xxx 之类的文件),你可以选择使用其中的 model.ckpt-xxx 模型来直接进行预测,也可以选择将 model.ckpt-xxx 转化为 .pb 文件之后再进行预测,如果选择转化,执行:

  1. $ python3 export_inference_graph.py \
  2. --trained_checkpoint_prefix Path/to/model.ckpt-xxx \
  3. --output_directory Path/to/exported_pb_file_directory

之后,在指定的输出路径下(Path/to/exported_pb_file_directory)会生成一个文件夹,该文件内的 frozen_inference_graph.pb 即是转化成的固化模型文件(固化指的是所有参数都转化成了常数)。之后就可以使用 evaluate.py 或者 predict.py 进行验证或预测了。

        如果你使用其它数据集,整个训练过程和上面的步骤一样,只需要根据具体的标注文件来修改文件data_provider.py 中函数 provide,该函数返回一个字典,其中 key 代表训练数据集中图像的路径,value 代表图像对应的类标号;其它参数,比如训练图像个数,类别数目,学习率等,在 train.py 中修改。

预告:下一篇文章将要介绍如何用 TensorFlow 来训练多任务多标签模型,敬请期待!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/92791
推荐阅读
相关标签
  

闽ICP备14008679号