当前位置:   article > 正文

Tensorflow Object Detection API 代码阅读(零)_tensorflow官方后处理

tensorflow官方后处理

0 前言

这一系列的文章主要是对 Tensorflow Object Detection API (下简称为 TF-OD-API),它属于 tensorflow/models 仓库的一部分,也是基于tensorflow“官方”的检测代码仓库。虽然相比于进来很火的基于pytorchmmdetection等检测代码仓库,这个仓库的“知名度”比较低,毕竟众多detection的大牛都在facebook,基于pytorch的检测仓库受欢迎也是意料之中。但是在某一些部署的场景中,可能tensorflow会更加友好,虽然它的模型训练过程会比较“艰辛”。

本人实际使用的过程中是基于 tf 1.13.1,本系列文章是基于2019年6月份 a68f65f 版本的代码进行。

1 正文

本文是此系列文章的第一篇,作为“开篇之作”,本文会顺着训练的流程,介绍 TF-OD-API 的整体结构 (以单阶段的SSD模型为例)。在后续的文章中,会具体地介绍其中的每一个部分。关于如何制作tfrecords,如何开始训练,如何定义自己的模型等,在 models/research/object_detection/g3doc 中有详细的文档说明。

TF-OD-API 采用的是使用tf.estimator.Estimator来管理整个训练的过程,所以在介绍整体的流程之前,我们需要先了解一下使用tf.estimator.Estimator来管理训练的几个关键的步骤:

  1. 构建输入函数 (input_fn):input_fn函数的作用是为estimator构建输入,其返回的是一个tf.data.Dataset的实例,且返回的tf.data.Dataset实例的每一个元素应该包含两个组成部分,第一个组成部分是features, (作为后面的model_fn的features参数),第二个组成部分是labels (作为后面的model_fn的labels参数),features和labels可以是单个tensor,也可以是以tensor作为值的字典 (在TF-OD-API中是后者,因为检测任务使用到的信息比较多,所以使用字典来存储多个信息)。我们可以通过tf.data.Datasetmap方法来将其实例的每一个元素处理成由features和labels两部分组成的形式;
  2. 构建模型函数 (model_fn):model_fn函数的作用是为estimator构建模型,其具体的形式如下面代码所示,其中的参数features和labels由上述input_fn返回的tf.data.Dataset实例提供。这个函数内部主要定义了从features和labels到模型输出,loss计算,模型优化。最后使用tf.estimator.EstimatorSpec()对模型的loss以及训练的op进行封装,并作为返回值。这样当estimator拿到tf.estimator.EstimatorSpec实例的时候,就可以掌握待训练模型的模型信息了;
    def my_model(
       features, # This is batch_features from input_fn
       labels,   # This is batch_labels from input_fn
       mode,     # An instance of tf.estimator.ModeKeys
       params):  # Additional configuration
       
       ...
       
       return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
  3. 构建tf.estimator.Estimator实例:该实例的作用管理训练的过程,我们可以定义tf.estimator.Estimator实例如下,其中model_fn是上面定义的模型函数,config是一个tf.estimator.RunConfig的实例,用于管理训练过程一些相关的信息,比如:模型存放路径、GPU使用设置、ckpt保存周期等等;
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
    
    • 1
  4. (optional) 对input_fn进一步封装:这一步的作用是通过进一步封装来设置迭代次数以及钩子函数。可以通过tf.estimator.TrainSpec或者tf.estimator.EvalSpecinput_fn进行封装 (在TF-OD-API中是有进行进一步封装的);
  5. 开启训练过程:这一步的作用是基于上述定义的estimator,开启模型的训练。一般是通过tf.estimator.train_and_evaluate, tf.estimator.train等函数来实现;

了解了如何使用tf.estimator.Estimator来管理整个训练过程之后,我们可以开始进入正题,来看一看 TF-OD-API 整个训练的流程。从训练的入口脚本 model_main.py 开始,大致流程如下图所示:

image

下面对整个流程进行介绍:
model_main.py作为入口,在其中主要干了三件事情:(1) 创建estimator和各种input_fn;(2) 对input_fn进行进一步封装;(3) 开启训练过程。有了上面对tf.estimator.Estimator管理训练过程的了解,这里的三个步骤就很好理解了,对应的就是在搭建基于tf.estimator.Estimator的训练流程。这其中内容最多的当然是第一部分,即创建estimator和各种input_fn。下面对三个部分进行介绍:

  1. 创建estimator和各种input_fn (model_lib.create_estimator_and_inputs()):此函数主要可以分为四个部分:(1) 解析参数配置文件;(2) 创建各种input_fn;(3) 创建model_fn;(4) 创建estimator。每一部分具体如下:
    1. 解析参数配置文件 (get_configs_from_pipeline_file()&merge_external_params_with_configs())。这是google的仓库,所以当然是使用自家的Protobuf来进行参数管理,其中,get_configs_from_pipeline_file()返回的是一个包含5个键的字典 (‘model’, ‘train_config’, ‘train_input_config’, ‘eval_config’, ‘eval_input_configs’),值是对应的proto实例,分别对应参数配置文件 (xxx.config)中的五个部分。merge_external_params_with_configs()根据输入参数对解析参数做进一步调整。解析出来的参数将用于后面的步骤;
    2. 创建各种input_fn (create_train/eval/predict_input_fn())。创建input_fn的函数的思路都是类似的,这里以create_train_input_fn()作为例子。其中使用train_input()函数来得到tf.data.Dataset类实例,而其中又主要是用到了dataset_builder.build()函数,主要的流程如下 (这一部分内容比较多,后面文章再进行详细介绍,下面是主要的思路):
      1. 创建以tfrecords作为输入的Dataset实例。使用tf.data.TFRecordDataset构建基于tfrecord的dataset;
      2. 对tfrecords中的元素进行解析。使用slim模块对tfrecord的元素进行decode,主要是TfExampleDecoder类,其decode方法返回的是一个字典 (tensor_dict);
      3. 数据处理和features&labels构建。使用transform_and_pad_input_data_fn()函数对上一步解码出来的tensor_dict进一步处理,主要就是进行数据处理,并将Dataset的每一个元素构建成features和labels两个部分;
      4. 添加batch维度。使用dataset.apply(tf.contrib.data.batch_and_drop_remainder())函数为Dataset的每一个元素添加batch维度;
    3. 创建model_fn (model_fn_creator())。model_fn_creator()中定义了model_fn(),其主要包括两个部分:(1) 检测模型构建;(2) 基于检测模型获取训练所需的tensor和op (同样,这一部分的内容很多,下面只是大概的流程,具体的细节再后面的文章进行介绍);
      1. 检测模型构建 (detection_model_fn())。基于detection_model_fn() (实际上就是model_builder.build()函数)来构建一个SSDMetaArch实例 (如果是双阶段,则是FasterRCNNMetaArch)。其中又包含了对 feature_extractor, box_coder, matcher, anchor_generator, box_predictor, target_assigner等等组件的构建 (这些在后续的文章中再进行介绍)。这里主要清楚SSDMetaArch是整个检测模型的类即可,其包含了preprocess(), predict(), postprocess(), loss()等方法,可以对于检测器相关的一些tensor进行获取;
      2. 基于检测模型获取训练所需的tensor和op。基于上面得到的SSDMetaArch实例,对各种tensor和op进行获取,主要是训练的loss以及优化网络的train_op,最后如上所述,使用tf.estimator.EstimatorSpec()进行封装;
    4. 创建estimator (tf.estimator.Estimator())。使用上面构建好的model_fn进行tf.estimator.Estimator实例的创建。
  2. input_fn进行进一步封装 (model_lib.create_train_and_eval_specs()):这一步和下一步基本上就是上面介绍的,这里就不再赘述;
  3. 开启训练过程 (tf.estimator.train_and_evaluate())。

以上就是顺着训练的流程,对 TF-OD-API 的整体介绍。本文主要是对整体思路的介绍,一些细节上的东西,如参数文件如何解析,如何构建input_fnSSDMetaArch类内部的具体组件等等,在后续的文章中进行介绍。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/123763
推荐阅读
相关标签
  

闽ICP备14008679号