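The source code discussed in this article comes from the run_pretraining.py file in the BERT project on GitHub. To follow it, you should have some familiarity with two papers, Attention Is All You Need and BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, as well as some background in deep learning, natural language processing, and TensorFlow.
Omitted.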
flags.DEFINE_string(
"bert_config_file", None,
"The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
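The config file is ordinary JSON. To give a rough idea of what it contains, the hypothetical MiniBertConfig class below (a stand-in written for illustration, not the project's own config class) reads a subset of the keys found in the released bert_config.json files and ignores the rest. The numbers in the example are the BERT-Base values.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class MiniBertConfig:
    """Hypothetical subset of the fields found in a BERT config JSON file."""
    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    max_position_embeddings: int
    type_vocab_size: int

    @classmethod
    def from_json_string(cls, text: str) -> "MiniBertConfig":
        raw = json.loads(text)
        # Keep only the keys this sketch knows about; other keys such as the
        # dropout probabilities, hidden_act and initializer_range are ignored.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})


example = """{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}"""

config = MiniBertConfig.from_json_string(example)
# Each attention head works on hidden_size / num_attention_heads dimensions.
head_size = config.hidden_size // config.num_attention_heads
print(config.num_hidden_layers, head_size)  # 12 64
```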
flags.DEFINE_string(
"input_file", None,
"Input TF example files (can be a glob or comma separated).")
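According to the help text, input_file accepts either a glob or a comma-separated list. One plausible way to expand such a value is to split on commas and glob each piece. The sketch below does this with the standard-library glob module. expand_input_files is a hypothetical helper, and the real script relies on TensorFlow's file utilities rather than glob.

```python
import glob
import os
import tempfile


def expand_input_files(input_file: str) -> list:
    """Expand a comma-separated list of glob patterns into file paths."""
    paths = []
    for pattern in input_file.split(","):
        pattern = pattern.strip()
        if pattern:  # skip empty pieces such as a trailing comma
            paths.extend(sorted(glob.glob(pattern)))
    return paths


# Usage: create a few empty files, then expand two patterns at once.
with tempfile.TemporaryDirectory() as d:
    for name in ("a.tfrecord", "b.tfrecord", "c.txt"):
        open(os.path.join(d, name), "w").close()
    spec = os.path.join(d, "*.tfrecord") + "," + os.path.join(d, "c.txt")
    found = [os.path.basename(p) for p in expand_input_files(spec)]
    print(found)  # ['a.tfrecord', 'b.tfrecord', 'c.txt']
```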
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written.")
flags.DEFINE_string(
"init_checkpoint", None,
"Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded. Must match data generation.")
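The help text describes two operations: long sequences are truncated and short ones are padded up to max_seq_length. Both happen when the pre-training data is written, which is why the value must match data generation. The following sketch shows the basic idea for a single sequence. pad_or_truncate is a hypothetical name, the token ids are arbitrary, and the sketch simply cuts the tail, whereas the real data pipeline truncates sentence pairs in a more elaborate way.

```python
def pad_or_truncate(token_ids, max_seq_length, pad_id=0):
    """Fix a token id sequence to exactly max_seq_length entries.

    Returns (input_ids, input_mask): real tokens get mask 1, padding gets 0.
    """
    ids = list(token_ids[:max_seq_length])  # truncate if too long
    mask = [1] * len(ids)
    while len(ids) < max_seq_length:        # pad if too short
        ids.append(pad_id)
        mask.append(0)
    return ids, mask


ids, mask = pad_or_truncate([101, 7592, 2088, 102], max_seq_length=6)
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```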
flags.DEFINE_integer(
"max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence. "
"Must match data generation.")
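The maximum number of MLM (masked language model) predictions per sequence. This value must match the one used when the pre-training data was generated. For details on the MLM task, see the BERT paper.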
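Each example stores its masked positions and target ids in fixed-size arrays, so a sequence with fewer masked tokens is padded up to max_predictions_per_seq. A weight array then marks which slots hold real predictions. The plain-Python sketch below illustrates the idea. pad_masked_lm and weighted_mlm_loss are hypothetical names, and the per-position losses in the usage example are made up. The loss is averaged over real predictions only, and a small constant (1e-5 here) is added to the denominator to avoid dividing by zero.

```python
def pad_masked_lm(positions, label_ids, max_predictions_per_seq):
    """Pad masked-LM targets to a fixed length max_predictions_per_seq.

    The returned weights are 1.0 for a real prediction and 0.0 for padding,
    so padded slots contribute nothing to the loss.
    """
    if len(positions) != len(label_ids):
        raise ValueError("positions and label_ids must have the same length")
    if len(positions) > max_predictions_per_seq:
        raise ValueError("more masked tokens than max_predictions_per_seq")
    pad = max_predictions_per_seq - len(positions)
    return (list(positions) + [0] * pad,
            list(label_ids) + [0] * pad,
            [1.0] * len(positions) + [0.0] * pad)


def weighted_mlm_loss(per_position_loss, weights):
    """Average the loss over real (weight 1.0) predictions only."""
    numerator = sum(l * w for l, w in zip(per_position_loss, weights))
    denominator = sum(weights) + 1e-5
    return numerator / denominator


pos, labels, w = pad_masked_lm([3, 7], [2088, 1010], max_predictions_per_seq=4)
print(pos)  # [3, 7, 0, 0]
print(w)    # [1.0, 1.0, 0.0, 0.0]
# The padded slots carry large made-up losses, but their weight is 0.0.
loss = weighted_mlm_loss([2.0, 4.0, 9.0, 9.0], w)
print(round(loss, 3))  # 3.0
```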
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
flags.DEFINE_integer("save_checkpoints_steps", 1000, "How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.")
flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
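Together, learning_rate, num_train_steps and num_warmup_steps define the learning-rate schedule. A common reading of BERT's optimizer setup is a linear warmup from 0 to the initial rate over num_warmup_steps, followed by linear decay to 0 at num_train_steps. The function below sketches that shape under this assumption. learning_rate_at is a hypothetical helper, and the project builds its schedule from TensorFlow ops rather than plain Python. Because the decay in this sketch is computed from the global step, the rate drops from about init_lr to init_lr * (1 - num_warmup_steps / num_train_steps) when warmup ends.

```python
def learning_rate_at(step, init_lr=5e-5, num_train_steps=100000,
                     num_warmup_steps=10000):
    """Linear warmup followed by linear decay to 0 (sketch, assumed shape)."""
    if num_warmup_steps and step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to init_lr.
        return init_lr * step / num_warmup_steps
    # Decay: linear in the global step, reaching 0 at num_train_steps.
    step = min(step, num_train_steps)
    return init_lr * (1.0 - step / num_train_steps)


# The defaults above match the default flag values.
for s in (0, 5000, 10000, 50000, 100000):
    print(s, learning_rate_at(s))
```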
tf.flags.DEFINE_string(
"tpu_name", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string(
"tpu_zone", None,
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string(
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.