赞
踩
最近入坑了一个项目,是有关pretraining的,模型已经训练好了,需要做一些downstream任务的测试。然而项目是用tensorflow写的,作为一个重度pytorch用户,我只听说过一些诸如session,eager execution等云里雾里的名词,和Keras整合时的混乱,以及从TF1迁移到TF2时的痛苦。这次硬着头皮上了。
在搭好针对downstream任务的训练和测试代码后,开始finetune。Training accuracy看起来不错,但是validation accuracy怎么这么低,难道overfit这么严重吗?我把validation set的路径指向training data,按说这时的validation accuracy应该和training大致相等,结果还是低到不可思议!那应该是计算accuracy的代码出bug了吧。于是我想打印出predictions来看看,就这样挖出了一个无底的rabbit hole。。。
先说这个项目的一些结构。代码中用了一个叫Estimator的东西,它打包了从训练到测试再到部署一系列环节,而我们只要给它写两个函数:一个model_fn
和一个input_fn
,其中前者运行模型,后者提供输入数据。对于每一个batch, model_fn
返回一个EstimatorSpec,包含了训练和记录所需的信息:
- return tf.contrib.tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,
- loss=losses['loss'],
- train_op=train_op,
- eval_metrics=(metric_fn, [evaluation[metric]]),
- scaffold_fn=scaffold_fn,
- host_call=host_call,
- predictions=evaluation)
首先试了试用print()打印batch的预测结果:
print(evaluation['predictions'])
然而并没有效果。原来是Estimator会以graph execution而不是eager execution来执行,在建图时只会保留必要的tensor operation,而舍弃像print()之类的效果。
然后我在TF中找出两个打印函数:tf.print()和tf.compat.v1.Print()。这里说一下compat.v1这个模块,它是为了从TF1迁移至TF2时的后向兼容性。所有的TF1函数都被移到了compat.v1中,保留了原有的函数签名和语法,并且还支持在TF2的runtime里运行(只是在优化度方面不及TF2原生的代码)。用compat.v1中的函数写出的代码,同时支持在TF1和TF2中运行,并且能够很方便地被升级成TF2原生代码。但这里的tf.compat.v1.Print()在文档中已被标明deprecated,如有需要应该用tf.print()。
试着用tf.print()打印:
tf.print(evaluation['predictions'])
然而还是没有效果。这个操作应该会被加到graph当中执行的啊?然后在tf.print()的文档里找到了这么一小段话:
好吧,谁让我在用TF1,而且是graph execution呢。Estimator没有session,那我就加一个control_dependencies()试试:
- print_op = tf.print(losses['loss'])
- with tf.control_dependencies([print_op]):
- losses['loss'] = tf.identity(losses['loss'])
好嘛,终于有动静了,给我糊了一脸stack trace,唯一有用的是这么一段:
- ERROR:tensorflow:Error recorded from evaluation_loop: From /job:worker/replica:0/task:0:
- Compilation failure: Detected unsupported operations when trying to compile graph _functionalize_body_2[] on XLA_TPU_JIT: StringFormat (No registered 'StringFormat' OpKernel for XLA_TPU_JIT devices compatible with node node StringFormat (defined at modeling.py:417)
- . Registered: device='CPU'
- )node StringFormat (defined at modeling.py:417)
- [[LoopCond]]
- TPU compilation failed
- [[tpu_compile_succeeded_assert/_6684163811970357570/_1671]]
-
- Errors may have originated from an input operation.
- Input Source operations connected to node StringFormat:
- truediv (defined at modeling.py:323)
我把这个“No registered StringFormat OpKernel“错误搜了一下,感觉完全没有相关的讨论。又走到了死胡同。
这时我上朋友圈问了一圈,有好几个好心的朋友提议我试试LoggingTensorHook。大概是长这样:
- hook = tf.train.LoggingTensorHook({'accuracy': evaluation['accuracy']}, every_n_iter=10)
- return tf.contrib.tpu.TPUEstimatorSpec(evaluation_hooks=[hook],
- ...)
结果这次运行到Restoring model parameters时竟然卡住了!没有任何报错信息,只是卡在这里十多分钟完全不走。加一个hook会影响parameter loading真的想不通。
这里再吐槽一下Tensorflow的doc。我是在StackOverflow里看到的tf.train.LoggingTensorHook,然而这个类在TF2的文档里根本找不到,原来这个是TF1中的命名,在TF2中被挪到了tf.estimator.LoggingTensorHook。然而,在compat.v1的文档中,LoggingTensorHook并没有出现在tf.train或tf.estimator模块下,仿佛它在TF1中从没有出现过一般。最后我是特意翻出TF 1.14的文档来(已经被archive到了GitHub里)才找到,并发现了这段信息:
讲道理,改名这种事情应该放在最新版文档的显眼位置,要翻这么久才能找到的信息,不能不让人感到confusing。
实在不行,把Estimator拆成一个自己写的loop,然后用eager execution行不行呢?这样上面的print()以及tf.print()都应该work了吧?
因为TF1默认graph execution,我们先要启用eager execution:
tf.compat.v1.enable_eager_execution()
把TPU设置好:
- resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
- tf.config.experimental_connect_to_host('10.240.1.10:8470')
- tf.tpu.experimental.initialize_tpu_system(resolver)
- strategy = tf.distribute.experimental.TPUStrategy(resolver)
然后写我们的loop:
- with strategy.scope():
- input_fn = tvqa_dataloader.input_fn_builder(config, is_training=False)
- dataset = input_fn(params={'batch_size': 8})
- for data in dataset:
- features, labels = data
- spec = model_fn(features, labels, tf.estimator.ModeKeys.EVAL, params={})
- metric_fn, tensors = spec.eval_metrics
- acc = tensors[0]
- print(tf.make_ndarray(acc))
- tf.get_variable_scope().reuse_variables()
这次运行的时候卡在了model_fn()中创建model的环节,仍然是没有任何报错信息。
既然TF1的文档和支持这么差了,如果把代码升级成TF2会不会解决这个无法print的问题呢。。
<One hour later...>
把代码库升级到TF2了,过程还算顺利,唯一的glitch是slim库作为contrib被踢出了TF2,于是我得重新安装tf_slim包,并更新了一遍各处tfexample.decoder的函数名。
- Exception in thread Thread-1:
- Traceback (most recent call last):
- File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/threading.py", line 917, in _bootstrap_inner
- self.run()
- File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/preempted_hook.py", line 87, in run
- recoverable = self._cluster._cloud_tpu_client.recoverable() # pylint: disable=protected-access
- File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/tensorflow/python/tpu/client/client.py", line 264, in recoverable
- elif FLAGS.runtime_oom_exit and self._oom_event():
- File "/home/LJC/conda/envs/VLpretrain-tf2/lib/python3.7/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
- raise _exceptions.UnparsedFlagAccessError(error_message)
- absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --runtime_oom_exit before flags were parsed.
2. 再试LoggingTensorHook,现在给了这样一个错误:
ValueError: Passed Tensor("Cast_372:0", shape=(8,), dtype=float32) should have graph attribute that is equal to current graph <tensorflow.python.framework.ops.Graph object at 0x7f5c5cf12ef0>.
又去StackOverflow搜了一圈,感觉这个错误是因为tensor不是从model_fn()里出来的。可我这个accuracy和loss确实都是model_fn()计算出来的啊。TF的GitHub上有一个issue和我遇到的问题相似,但是楼主并没有follow-up自己是如何解决的。
3. 写成eager execution,错误是:
- ValueError: Attempt to convert a value (functools.partial(<tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay object at 0x7fee44640470>, TPUMirroredVariable:{
- 0: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>,
- 1: <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64, numpy=0>,
- 2: <tf.Variable 'global_step/replica_2:0' shape=() dtype=int64, numpy=0>,
- 3: <tf.Variable 'global_step/replica_3:0' shape=() dtype=int64, numpy=0>,
- 4: <tf.Variable 'global_step/replica_4:0' shape=() dtype=int64, numpy=0>,
- 5: <tf.Variable 'global_step/replica_5:0' shape=() dtype=int64, numpy=0>,
- 6: <tf.Variable 'global_step/replica_6:0' shape=() dtype=int64, numpy=0>,
- 7: <tf.Variable 'global_step/replica_7:0' shape=() dtype=int64, numpy=0>
- })) with an unsupported type (<class 'functools.partial'>) to a Tensor.
总结一下,这次上手Tensorflow的体验可以说是极差的,官方文档在处理TF1到2的迁移时存在很多混乱,很多函数的详细功能、使用条件、常见问题说得都不是很清楚,很多常见问题在StackOverflow上的讨论也不充分。
结论:还是Pytorch香!
最后祭出这张图:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。