赞
踩
参考了官方文档,有些细节问题需要阅读tf源码。
def model_fn(features, labels, mode, params)
其中features和labels是input_fn()的输出结果,params等是新建Estimator对象时填入的各种参数,mode在调用train/evaluate/predict时自动确定。
def _call_input_fn(self, input_fn, mode):
input_fn_args = function_utils.fn_args(input_fn)
kwargs = {}
if 'mode' in input_fn_args:
kwargs['mode'] = mode
if 'params' in input_fn_args:
kwargs['params'] = self.params
if 'config' in input_fn_args:
kwargs['config'] = self.config
with ops.device('/cpu:0'):
return input_fn(**kwargs)
可以看到,input_fn的params等参数来源于Estimator类的对应参数。回到bert源码,run_pretraining.py中新建了Estimator对象,
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size)
参考tpu_estimator.py中的注释,
`input_fn` and `model_fn` will receive `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`
因此在input_fn(params)中可以使用batch_size = params[“batch_size”],获取到train_batch_size或eval_batch_size的值。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。