当前位置:   article > 正文

tf.estimator基本用法_tf.contrib.tpu.tpuestimator

tf.contrib.tpu.tpuestimator

参考了官方文档,有些细节问题需要阅读tf源码。

  • 如何定义模型model_fn()?需要满足什么格式?官方文档写的还算清楚,
def model_fn(features, labels, mode, params)
  • 1

其中features和labels是input_fn()的输出结果,params等是新建Estimator对象时填入的各种参数,mode在调用train/evaluate/predict时自动确定。

  • 如何定义输入input_fn()?bert源码中input_fn(params)是如何传递params的?通过阅读tf源码终于理解了。这是Estimator类中调用input_fn()的源码,
  def _call_input_fn(self, input_fn, mode):
    input_fn_args = function_utils.fn_args(input_fn)
    kwargs = {}
    if 'mode' in input_fn_args:
      kwargs['mode'] = mode
    if 'params' in input_fn_args:
      kwargs['params'] = self.params
    if 'config' in input_fn_args:
      kwargs['config'] = self.config
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

可以看到,input_fn的params等参数来源于Estimator类的对应参数。回到bert源码,run_pretraining.py中新建了Estimator对象,

  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

参考tpu_estimator.py中的注释,

`input_fn` and `model_fn` will receive `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`
  • 1

因此在input_fn(params)中可以使用batch_size = params[“batch_size”],获取到train_batch_size或eval_batch_size的值。

  • result = estimator.evaluate()结果的数据结构?
  • result = estimator.predict()结果的数据结构?
  • tf.estimator.Estimator和tf.estimator.EstimatorSpec
  • export_saved_model()没看明白
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/312303
推荐阅读
相关标签
  

闽ICP备14008679号