BERT is the NLP model Google released last year. As soon as it came out, it crushed the competition across a wide range of benchmarks, and it is open source. The only catch: training BERT is prohibitively expensive.
Originally it took 64 TPUs training for 4 days. Google later used parallelization to cut the time down to a little over an hour, but the number of TPUs required jumped to a staggering 1,024.
So how much does that cost in total? Cloud TPUs are priced at $6.5 per TPU per hour, so training the full model comes to nearly $40,000 (roughly 64 TPUs × 96 hours × $6.5 ≈ $39,900), an astronomical price.
Now for the bargain: someone on Medium has found a way to train BERT for only about $1, with the resulting model saved in your Google Cloud storage for later use.
To take advantage of this, you need a Google Cloud Storage bucket. Follow the Google Cloud TPU quickstart guide to create a Google Cloud Platform account and a Google Cloud Storage bucket. New Google Cloud Platform users get $300 in free credit.
Pretraining a BERT-Base model on a TPUv2 takes about 54 hours. Google Colab is not designed for long-running jobs and will interrupt the training process roughly every 8 hours. For uninterrupted training, consider using a paid, always-on TPUv2 instead.
In other words, with a Colab TPU you can store the model and data in Google Cloud for about $1 and pretrain a BERT model from scratch at practically negligible cost.
The code for the whole process is shown below; it runs in a Colab Jupyter environment.
First, install the packages needed to train the model. Jupyter lets you run bash commands directly from the notebook with '!':
!pip install sentencepiece
!git clone https://github.com/google-research/bert
Import the packages and authorize with Google Cloud:
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder

auth.authenticate_user()

# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s : %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False
Next, fetch a text corpus from the web. For this experiment we use the OpenSubtitles dataset, which covers 65 languages.
Unlike more commonly used text datasets such as Wikipedia, it requires no complex preprocessing, and it comes pre-formatted with one sentence per line.
AVAILABLE = {'af','ar','bg','bn','br','bs','ca','cs',
             'da','de','el','en','eo','es','et','eu',
             'fa','fi','fr','gl','he','hi','hr','hu',
             'hy','id','is','it','ja','ka','kk','ko',
             'lt','lv','mk','ml','ms','nl','no','pl',
             'pt','pt_br','ro','ru','si','sk','sl','sq',
             'sr','sv','ta','te','th','tl','tr','uk',
             'ur','vi','ze_en','ze_zh','zh','zh_cn',
             'zh_en','zh_tw','zh_zh'}

LANG_CODE = "en" #@param {type:"string"}

assert LANG_CODE in AVAILABLE, "Invalid language code selected"

!wget http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.'$LANG_CODE'.gz -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt
You can pick whichever language you need by changing the setting in the code. For demonstration purposes, the code defaults to using only a small fraction of the corpus. When actually training a model, be sure to uncheck the DEMO_MODE checkbox to use a dataset 100 times larger.
Rest assured, 100M lines of data are enough to train a reasonably good BERT-Base model.
DEMO_MODE = True #@param {type:"boolean"}

if DEMO_MODE:
  CORPUS_SIZE = 1000000
else:
  CORPUS_SIZE = 100000000 #@param {type: "integer"}

!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt
The raw text we downloaded contains punctuation, uppercase letters and non-UTF symbols, which we strip out before moving on. At inference time we will apply the same procedure to new data.
If your use case calls for different preprocessing (for example, if uppercase letters or punctuation are expected during inference), modify the code below to fit your needs.
regex_tokenizer = nltk.RegexpTokenizer("\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # remove non-UTF
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

def count_lines(filename):
  count = 0
  with open(filename) as fi:
    for line in fi:
      count += 1
  return count
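As a quick sanity check (a hypothetical example, not part of the original notebook), the normalization lowercases the text and strips punctuation:

# hypothetical usage example of the normalization above
print(normalize_text("Colorless, GREEN ideas sleep furiously!"))
# -> colorless green ideas sleep furiously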
Now let's preprocess the whole dataset:
RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}

# apply normalization to the dataset
# this will take a minute or two
total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH, encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w", encoding="utf-8") as fo:
    for l in fi:
      fo.write(normalize_text(l) + "\n")
      bar.add(1)
Next, we train the model a new vocabulary to represent our dataset.
The BERT paper uses a WordPiece tokenizer, which is not available in open source. Instead, we use the SentencePiece tokenizer in unigram mode. Although it is not directly compatible with BERT, a small hack makes it work.
SentencePiece needs quite a lot of RAM, so running it on the full dataset in Colab will crash the kernel.
To avoid this, we randomly subsample a fraction of the dataset to build the vocabulary. Another option is to use a machine with more memory for this step.
Also, SentencePiece adds BOS and EOS control symbols to the vocabulary by default. We disable them by setting their indices to -1.
Typical values of VOC_SIZE lie between 32,000 and 128,000. We reserve NUM_PLACEHOLDERS tokens in case we want to update the vocabulary and fine-tune the model after the pretraining stage is finished.
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true '
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX,
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)
Now, let's see how to make SentencePiece work with the BERT model.
Below is a sentence tokenized with the WordPiece vocabulary from the official pretrained English BERT-Base model:
>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")
['color', '##less', 'geo', '##thermal', 'sub', '##station', '##s', 'are', 'generating', 'furiously']
The WordPiece tokenizer prepends '##' to subwords that occur in the middle of a word; subwords at the beginning of a word are left unchanged. If a subword occurs both at the beginning and in the middle of words, both versions (with and without '##') are added to the vocabulary.
SentencePiece produces two files: tokenizer.model and tokenizer.vocab. Let's look at the vocabulary it has learned:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))
Output:
Learnt vocab size: 31743
Sample tokens: ['▁cafe', '▁slippery', 'xious', '▁resonate', '▁terrier', '▁feat', '▁frequencies', 'ainty', '▁punning', 'modern']
SentencePiece does the exact opposite of WordPiece. From its documentation: SentencePiece first escapes whitespace with the meta symbol '▁' (U+2581), like this:
Hello▁World.
Then the text is segmented into pieces:
[Hello] [▁Wor] [ld] [.]
Subwords that occur after whitespace (which is also where most words begin) are prefixed with '▁', while other subwords are left unchanged. This excludes subwords that appear only at the very beginning of a sentence and nowhere else, but such cases should be very rare.
So, to obtain a WordPiece-like vocabulary, we need a simple conversion: remove '▁' from the tokens that contain it and add '##' to the ones that don't.
We also add some special control symbols required by the BERT architecture. By convention, we put them at the beginning of the vocabulary.
In addition, we append some placeholder tokens to the vocabulary.
These come in handy if you later want to update the pretrained model with new, task-specific tokens.
In that case, the placeholder tokens are replaced with the new ones, the pretraining data is regenerated, and the model is fine-tuned on the new data.
def parse_sentencepiece_token(token):
  if token.startswith("▁"):
    return token[1:]
  else:
    return "##" + token

bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))

ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))
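To see what the conversion does (a hypothetical check, not in the original notebook): tokens that start with the SentencePiece meta symbol lose it, while all other tokens gain the '##' prefix.

# hypothetical check of the conversion rule
print(parse_sentencepiece_token("▁cafe"))   # -> cafe
print(parse_sentencepiece_token("xious"))   # -> ##xious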
Finally, we write the resulting vocabulary to a file.
VOC_FNAME = "vocab.txt" #@param {type:"string"}

with open(VOC_FNAME, "w") as fo:
  for token in bert_vocab:
    fo.write(token + "\n")
Now let's see how the new vocabulary works in practice:
>>> testcase = "Colorless geothermal substations are generating furiously"
>>> bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
>>> bert_tokenizer.tokenize(testcase)
['color', '##less', 'geo', '##ther', '##mal', 'sub', '##station', '##s', 'are', 'generat', '##ing', 'furious', '##ly']
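If you also need numeric model inputs rather than just tokens, the same tokenizer can map tokens to vocabulary ids (a hypothetical follow-up, not in the original notebook):

# hypothetical follow-up: map tokens to ids with the new vocabulary
tokens = bert_tokenizer.tokenize(testcase)
input_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
print(input_ids)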
With the vocabulary in hand, we can generate pretraining data for the BERT model.
Since our dataset may be quite large, we split it into shards:
mkdir ./shards
split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
Now, for each shard we need to call the create_pretraining_data.py script from the BERT repository, which we do with the xargs command.
Before generating the data, we need to set a few parameters to pass to the script. You can find more about what they mean in the repository's README.
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}

PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
# controls how many parallel processes xargs can create
PROCESSES = 2 #@param {type:"integer"}
Running this can take quite a while, depending on the size of the dataset.
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}',
                             VOC_FNAME, DO_LOWER_CASE,
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)

tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD
To preserve our hard-earned model, we store it in Google Cloud Storage.
Create two directories in Google Cloud Storage, one for the data and one for the model. In the model directory we place the model vocabulary and configuration file.
Before proceeding, configure the BUCKET_NAME variable, otherwise you will not be able to train the model.
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)

if not BUCKET_NAME:
  log.warning("WARNING: BUCKET_NAME is not set. "
              "You will not be able to train the model.")
Below is a sample hyperparameter configuration for BERT-Base:
# use this for BERT-base
bert_base_config = {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": VOC_SIZE
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_base_config, fo, indent=2)

with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
  for token in bert_vocab:
    fo.write(token + "\n")
Now we are ready to push the model and data to Google Cloud:
if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME
Note that some of the parameters from the previous steps reappear here. Make sure they are set to exactly the same values throughout the experiment.
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR, '*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))
Prepare the training run configuration, build the estimator and the input function, and fire up BERT!
model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)

train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)
Run it!
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
With the default parameters, training the model takes 1 million steps, about 54 hours of runtime. If the kernel restarts for whatever reason, training can be resumed from the last checkpoint.
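A minimal resume sketch (my own addition, assuming the setup cells above have been re-run with the same BUCKET_NAME and MODEL_DIR): TPUEstimator picks up the latest checkpoint found in its model_dir, so simply calling train again continues from where it stopped.

# Hypothetical resume check: confirm which checkpoint training will restart from,
# then continue training; TPUEstimator resumes from the latest checkpoint in model_dir.
latest = tf.train.latest_checkpoint(BERT_GCS_DIR)
log.info("Resuming from checkpoint: {}".format(latest))
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)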
That wraps up the guide to pretraining BERT from scratch on a Cloud TPU.
So, we have trained a model. What can we do with it next?
1. Use the pretrained model as a general-purpose natural language understanding module (see the sketch after this list);
2. Fine-tune the model for a specific classification task;
3. Use BERT as a building block in another deep learning model.
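As an illustration of option 1, here is a minimal sketch (my own addition, not from the original article) of restoring the pretrained weights into a BertModel for feature extraction. It assumes the cloned bert repo and the variables produced in the cells above (bert_config, INIT_CHECKPOINT, MAX_SEQ_LENGTH).

# Minimal sketch: build a BERT graph and restore the pretrained checkpoint.
input_ids = tf.placeholder(tf.int32, shape=[None, MAX_SEQ_LENGTH])
model = modeling.BertModel(
    config=bert_config,
    is_training=False,
    input_ids=input_ids,
    use_one_hot_embeddings=True)

tvars = tf.trainable_variables()
assignment_map, _ = modeling.get_assignment_map_from_checkpoint(tvars, INIT_CHECKPOINT)
tf.train.init_from_checkpoint(INIT_CHECKPOINT, assignment_map)
# After creating a session and feeding token ids, model.get_pooled_output() gives a
# sentence-level embedding and model.get_sequence_output() gives per-token embeddings.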
Original article:
https://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379
Colab notebook:
https://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz