
【MindSpore】BERT base pretraining for GPU: AdamWeightDecay optimizer performance does not match the official number

model_zoo/official/nlp/bert

1. Version: r1.1; hardware: NVIDIA V100-SXM2-16GB × 8

2. Settings: batch size 32 × 8; optimizer: 'AdamWeightDecay'

3. Command: bash scripts/run_distributed_pretrain_for_gpu.sh 8 1 /data/wiki

Issue: the officially reported result is 290 ms/step, but my run gives 80 ms/step, a gap of roughly 3.6×.

With the optimizer set to 'Momentum', the run gives 288 ms/step, which matches the official number.
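The gap can be quantified directly. A quick arithmetic check, using only the numbers reported above:

```python
# Per-step times reported above (milliseconds).
official_adamw_ms = 290.0     # official AdamWeightDecay benchmark
measured_adamw_ms = 80.0      # AdamWeightDecay measured here
measured_momentum_ms = 288.0  # Momentum measured here

# AdamWeightDecay here is ~3.6x faster than the official number,
# while Momentum matches it almost exactly.
print(round(official_adamw_ms / measured_adamw_ms, 2))     # 3.62
print(round(official_adamw_ms / measured_momentum_ms, 2))  # 1.01
```

So the anomaly is not a small drift: one optimizer disagrees with the reference by a factor of 3.6 while the other agrees to within 1%.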

Could someone help locate the cause or offer suggestions? Thanks!

[Log information] (optional: paste the log content or attach a file)

Beginning of the log with the optimizer set to 'AdamWeightDecay':

args_opt: Namespace(accumulation_steps=1, batch_size=32, data_dir='/workspace/bert/data/wiki', data_sink_steps=20, device_id=0, device_num=1, device_target='GPU', distribute='true', do_shuffle='true', enable_data_sink='true', enable_global_norm='false', enable_graph_kernel='auto', enable_lossscale='false', enable_save_ckpt='false', epoch_size=1, load_checkpoint_path='', optimizer='AdamWeightDecay', save_checkpoint_num=1, save_checkpoint_path='', save_checkpoint_steps=10000, schema_dir='', train_steps=120)

cfg: {'batch_size': 32, 'bert_network': 'base', 'loss_scale_value': 65536, 'scale_factor': 2, 'scale_window': 1000, 'optimizer': 'AdamWeightDecay', 'enable_global_norm': False, 'AdamWeightDecay': {'learning_rate': 3e-05, 'end_learning_rate': 0.0, 'power': 5.0, 'weight_decay': 1e-05, 'decay_filter': <function <lambda> at 0x7ff9a8705320>, 'eps': 1e-06, 'warmup_steps': 10000}, 'Lamb': {'learning_rate': 3e-05, 'end_learning_rate': 0.0, 'power': 10.0, 'warmup_steps': 10000, 'weight_decay': 0.01, 'decay_filter': <function <lambda> at 0x7ff9a8714680>, 'eps': 1e-06}, 'Momentum': {'learning_rate': 2e-05, 'momentum': 0.9}}
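For reference, the AdamWeightDecay entry in the cfg above describes a warmup-plus-polynomial-decay learning-rate schedule (learning_rate 3e-05, end_learning_rate 0.0, power 5.0, warmup_steps 10000). A minimal sketch of such a schedule, assuming the usual linear-warmup-then-polynomial-decay form and a hypothetical total step count (this is an illustration, not the actual model_zoo implementation):

```python
def lr_at(step, lr=3e-05, end_lr=0.0, power=5.0, warmup=10000, total=100000):
    """Linear warmup to `lr`, then polynomial decay toward `end_lr`.

    `total` is a hypothetical total-step count for illustration.
    """
    if step < warmup:
        return lr * step / warmup
    frac = (total - step) / (total - warmup)
    return (lr - end_lr) * frac ** power + end_lr

print(lr_at(5000))   # halfway through warmup: 1.5e-05
print(lr_at(10000))  # peak: 3e-05
```

The schedule shape is irrelevant to the timing question, but it confirms the AdamWeightDecay and Momentum runs use different hyperparameter blocks, not just different update rules.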

origin dataset size:  5823104

origin dataset size:  5823104

origin dataset size:  5823104

origin dataset size:  5823104

origin dataset size:  5823104

origin dataset size:  5823104

origin dataset size:  5823104

origin dataset size:  5823104

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.4174595]

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.417459]

epoch time: 135128.564 ms, per step time: 6756.428 ms

epoch time: 135241.635 ms, per step time: 6762.082 ms

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.4174595]

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.417459]

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.4174595]

epoch time: 136646.635 ms, per step time: 6832.332 ms

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.4174595]

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.4174595]

epoch time: 134981.528 ms, per step time: 6749.076 ms

epoch time: 134997.352 ms, per step time: 6749.868 ms

epoch: 0, current epoch percent: 0.000, step: 20, outputs are [10.4174595]

epoch time: 136710.835 ms, per step time: 6835.542 ms

epoch time: 136442.407 ms, per step time: 6822.120 ms

epoch time: 136147.168 ms, per step time: 6807.358 ms

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.07456]

epoch time: 1611.637 ms, per step time: 80.582 ms

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.07456]

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.074541]

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.07456]

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.07456]

epoch time: 1611.219 ms, per step time: 80.561 ms

epoch time: 1611.087 ms, per step time: 80.554 ms

epoch time: 1612.109 ms, per step time: 80.605 ms

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.07456]

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.074541]

epoch time: 1611.183 ms, per step time: 80.559 ms

epoch: 0, current epoch percent: 0.000, step: 40, outputs are [10.074554]

epoch time: 1611.513 ms, per step time: 80.576 ms

epoch time: 1611.393 ms, per step time: 80.570 ms

epoch time: 1611.300 ms, per step time: 80.565 ms

epoch: 0, current epoch percent: 0.000, step: 60, outputs are [10.104292]

epoch time: 1622.121 ms, per step time: 81.106 ms

......
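To separate warm-up from steady state, the per-step times can be pulled out of the log with a short script (a sketch; the regex simply matches the `per step time` lines shown above, and `LOG` holds a few representative lines). The first data-sink block (~6750 ms/step) likely includes graph compilation, so only later blocks reflect steady-state throughput:

```python
import re

# A few representative lines from the log above.
LOG = """\
epoch time: 135128.564 ms, per step time: 6756.428 ms
epoch time: 1611.637 ms, per step time: 80.582 ms
epoch time: 1611.219 ms, per step time: 80.561 ms
epoch time: 1622.121 ms, per step time: 81.106 ms
"""

# Extract every "per step time" value (milliseconds).
times = [float(m) for m in re.findall(r"per step time: ([\d.]+) ms", LOG)]

# Drop the first block (compilation / warm-up) before averaging.
steady = times[1:]
print(round(sum(steady) / len(steady), 2))  # 80.75
```

The steady-state average of ~80.75 ms/step is what disagrees with the official 290 ms/step figure.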

Answer:

Found the problem: the code enables float16 compute by default, but only the AdamWeightDecay optimizer supports it.

So the AdamWeightDecay timing is a float16 result, while any other optimizer falls back to float32.
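The behavior described in the answer can be sketched as a plain-Python decision rule (a hypothetical helper for illustration, not the actual model_zoo code; the real switch lives in the BERT config and training scripts):

```python
def effective_compute_dtype(optimizer: str) -> str:
    """Sketch of the behavior described above: float16 compute is the
    default, but only AdamWeightDecay actually runs with it; every
    other optimizer falls back to float32."""
    return "float16" if optimizer == "AdamWeightDecay" else "float32"

print(effective_compute_dtype("AdamWeightDecay"))  # float16
print(effective_compute_dtype("Momentum"))         # float32
```

This explains the numbers: the 80 ms/step AdamWeightDecay run and the 288 ms/step Momentum run are measured under different compute precisions, so they are not comparable against the same official float32 baseline.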
