当前位置:   article > 正文

tensorflow测试代码_天坑Tensorflow

should have graph attribute that is equal to current graph

251c3dde3cfd33389ef001dc72389f49.png

最近入坑了一个项目,是有关pretraining的,模型已经训练好了,需要做一些downstream任务的测试。然而项目是用tensorflow写的,作为一个重度pytorch用户,我只听说过一些诸如session,eager execution等云里雾里的名词,和Keras整合时的混乱,以及从TF1迁移到TF2时的痛苦。这次硬着头皮上了。

在搭好针对downstream任务的训练和测试代码后,开始finetune。Training accuracy看起来不错,但是validation accuracy怎么这么低,难道overfit这么严重吗?我把validation set的路径指向training data,按说这时的validation accuracy应该和training大致相等,结果还是低到不可思议!那应该是计算accuracy的代码出bug了吧。于是我想打印出predictions来看看,就这样挖出了一个无底的rabbit hole。。。

先说这个项目的一些结构。代码中用了一个叫Estimator的东西,它打包了从训练到测试再到部署一系列环节,而我们只要给它写两个函数:一个model_fn和一个input_fn,其中前者运行模型,后者提供输入数据。对于每一个batch, model_fn返回一个EstimatorSpec,包含了训练和记录所需的信息:

  1. return tf.contrib.tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,
  2. loss=losses['loss'],
  3. train_op=train_op,
  4. eval_metrics=(metric_fn, [evaluation[metric]]),
  5. scaffold_fn=scaffold_fn,
  6. host_call=host_call,
  7. predictions=evaluation)

首先试了试用print()打印batch的预测结果:

print(evaluation['predictions'])

然而并没有效果。原来是Estimator会以graph execution而不是eager execution来执行,在建图时只会保留必要的tensor operation,而舍弃像print()之类的效果。


然后我在TF中找出两个打印函数:tf.print()和tf.compat.v1.Print()。这里说一下compat.v1这个模块,它是为了从TF1迁移至TF2时的后向兼容性。所有的TF1函数都被移到了compat.v1中,保留了原有的函数签名和语法,并且还支持在TF2的runtime里运行(只是在优化度方面不及TF2原生的代码)。用compat.v1中的函数写出的代码,同时支持在TF1和TF2中运行,并且能够很方便地被升级成TF2原生代码。但这里的tf.compat.v1.Print()在文档中已被标明deprecated,如有需要应该用tf.print()。

试着用tf.print()打印:

tf.print(evaluation['predictions'])

然而还是没有效果。这个操作应该会被加到graph当中执行的啊?然后在tf.print()的文档里找到了这么一小段话:

800a984e4b20f3803b1d45e88c3c752d.png

好吧,谁让我在用TF1,而且是graph execution呢。Estimator没有session,那我就加一个control_dependencies()试试:

  1. print_op = tf.print(losses['loss'])
  2. with tf.control_dependencies([print_op]):
  3. losses['loss'] = tf.identity(losses['loss'])

好嘛,终于有动静了,给我糊了一脸stack trace,唯一有用的是这么一段:

  1. ERROR:tensorflow:Error recorded from evaluation_loop: From /job:worker/replica:0/task:0:
  2. Compilation failure: Detected unsupported operations when trying to compile graph _functionalize_body_2[] on XLA_TPU_JIT: StringFormat (No registered 'StringFormat' OpKernel for XLA_TPU_JIT devices compatible with node node StringFormat (defined at modeling.py:417)
  3. . Registered: device='CPU'
  4. )node StringFormat (defined at modeling.py:417)
  5. [[LoopCond]]
  6. TPU compilation failed
  7. [[tpu_compile_succeeded_assert/_6684163811970357570/_1671]]
  8. Errors may have originated from an input operation.
  9. Input Source operations connected to node StringFormat:
  10. truediv (defined at modeling.py:323)

我把这个“No registered StringFormat OpKernel“错误搜了一下,感觉完全没有相关的讨论。又走到了死胡同。


这时我上朋友圈问了一圈,有好几个好心的朋友提议我试试LoggingTensorHook。大概是长这样:

  1. hook = tf.train.LoggingTensorHook({'accuracy': evaluation['accuracy']}, every_n_iter=10)
  2. return tf.contrib.tpu.TPUEstimatorSpec(evaluation_hooks=[hook],
  3. ...)

结果这次运行到Restoring model parameters时竟然卡住了!没有任何报错信息,只是卡在这里十多分钟完全不走。加一个hook会影响parameter loading真的想不通。

这里再吐槽一下Tensorflow的doc。我是在StackOverflow里看到的tf.train.LoggingTensorHook,然而这个类在TF2的文档里根本找不到,原来这个是TF1中的命名,在TF2中被挪到了tf.estimator.LoggingTensorHook。然而,在compat.v1的文档中,LoggingTensorHook并没有出现在tf.train或tf.estimator模块下,仿佛它在TF1中从没有出现过一般。最后我是特意翻出TF 1.14的文档来(已经被archive到了GitHub里)才找到,并发现了这段信息:

000809668d2f864139470f09f0e96cc0.png

讲道理,改名这种事情应该放在最新版文档的显眼位置,要翻这么久才能找到的信息,不能不让人感到confusing。


实在不行,把Estimator拆成一个自己写的loop,然后用eager execution行不行呢?这样上面的print()以及tf.print()都应该work了吧?

因为TF1默认graph execution,我们先要启用eager execution:

tf.compat.v1.enable_eager_execution()

把TPU设置好:

  1. resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  2. tf.config.experimental_connect_to_host('10.240.1.10:8470')
  3. tf.tpu.experimental.initialize_tpu_system(resolver)
  4. strategy = tf.distribute.experimental.TPUStrategy(resolver)

然后写我们的loop:

  1. with strategy.scope():
  2. input_fn = tvqa_dataloader.input_fn_builder(config, is_training=False)
  3. dataset = input_fn(params={'batch_size': 8})
  4. for data in dataset:
  5. features, labels = data
  6. spec = model_fn(features, labels, tf.estimator.ModeKeys.EVAL, params={})
  7. metric_fn, tensors = spec.eval_metrics
  8. acc = tensors[0]
  9. print(tf.make_ndarray(acc))
  10. tf.get_variable_scope().reuse_variables()

这次运行的时候卡在了model_fn()中创建model的环节,仍然是没有任何报错信息。


既然TF1的文档和支持这么差了,如果把代码升级成TF2会不会解决这个无法print的问题呢。。

<One hour later...>

把代码库升级到TF2了,过程还算顺利,唯一的glitch是slim库作为contrib被踢出了TF2,于是我得重新安装tf_slim包,并更新了一遍各处tfexample.decoder的函数名。

  1. 先上tf.print(),这次应该不用control_dependency了,但是这回又出了些新东西:
  1. Exception in thread Thread-1:
  2. Traceback (most recent call last):
  3. File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/threading.py", line 917, in _bootstrap_inner
  4. self.run()
  5. File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/preempted_hook.py", line 87, in run
  6. recoverable = self._cluster._cloud_tpu_client.recoverable() # pylint: disable=protected-access
  7. File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/client/client.py", line 264, in recoverable
  8. elif FLAGS.runtime_oom_exit and self._oom_event():
  9. File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
  10. raise _exceptions.UnparsedFlagAccessError(error_message)
  11. absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --runtime_oom_exit before flags were parsed.

2. 再试LoggingTensorHook,现在给了这样一个错误:

ValueError: Passed Tensor("Cast_372:0", shape=(8,), dtype=float32) should have graph attribute that is equal to current graph <tensorflow.python.framework.ops.Graph object at 0x7f5c5cf12ef0>.

又去StackOverflow搜了一圈,感觉这个错误是因为tensor不是从model_fn()里出来的。可我这个accuracy和loss确实都是model_fn()计算出来的啊。TF的GitHub上有一个issue和我遇到的问题相似,但是楼主并没有follow-up自己是如何解决的。

3. 写成eager execution,错误是:

  1. ValueError: Attempt to convert a value (functools.partial(<tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay object at 0x7fee44640470>, TPUMirroredVariable:{
  2. 0: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>,
  3. 1: <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64, numpy=0>,
  4. 2: <tf.Variable 'global_step/replica_2:0' shape=() dtype=int64, numpy=0>,
  5. 3: <tf.Variable 'global_step/replica_3:0' shape=() dtype=int64, numpy=0>,
  6. 4: <tf.Variable 'global_step/replica_4:0' shape=() dtype=int64, numpy=0>,
  7. 5: <tf.Variable 'global_step/replica_5:0' shape=() dtype=int64, numpy=0>,
  8. 6: <tf.Variable 'global_step/replica_6:0' shape=() dtype=int64, numpy=0>,
  9. 7: <tf.Variable 'global_step/replica_7:0' shape=() dtype=int64, numpy=0>
  10. })) with an unsupported type (<class 'functools.partial'>) to a Tensor.

总结一下,这次上手Tensorflow的体验可以说是极差的,官方文档在处理TF1到2的迁移时存在很多混乱,很多函数的详细功能、使用条件、常见问题说得都不是很清楚,很多常见问题在StackOverflow上的讨论也不充分。

结论:还是Pytorch香!

最后祭出这张图:

f9e3ef602d0f995c9cf456d80a1d6cbb.png
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/312310
推荐阅读
相关标签
  

闽ICP备14008679号