当前位置:   article > 正文

BERT源码深度剖析之create_pretraining_data.py_bert create

bert create

  在开始之前,建议大家先阅读专栏的第一篇文章:预训练模型代码深度剖析之开宗明义:新学常见误区和正确的学习姿势

创建预训练数据
一个类
.
八大函数
训练例对象写入文件
创建整型特征
创建浮点型特征
创建训练例
从文档中创建训练例
创建对掩码语言模型的预测
截断序列对
主函数

1. 训练例对象写入文件(write_instance_to_example_files)

1.1 重要概念

  为了方便新手理解,在讲解代码之前,会对该函数用到的神经网络的常见知识点论文中提及到的知识点进行介绍和解释。

1.1.1 序列填充(padding)和参数max_seq_length

  在神经网络的训练中,为了加速训练并且确保网络能够达到收敛的状态,往往是进行批量样本(batch)的学习,也就是说输入样本通过张量(矩阵)的形式进行输入,则张量的维度需要保持一致,最终需要使得必须让同一个batch中的每个样本的序列长度一致。

  在保持序列长度一致性的方法中,最常见的就是通过padding补零,把同一个batch中的所有样本都变成同一个长度,从而便于批量计算。对于填充值(如零),可使用mask机制来避免模型对填充值进行训练。

  那么问题来了,既然是填充零值,是往哪个方向填充呢?是往后面进行填充(post)呢还是往前面进行填充(pre)。根据huggingface的BERT文档中第一页的前半部分,链接为https://huggingface.co/docs/transformers/model_doc/bert,如下图所示:
在这里插入图片描述
  由于BERT使用的是绝对位置嵌入,所以通常建议在右边进行填充,也就是说之前说的post填充(往后面填充)。往后填充直至达到序列的最大长度,这里则引出BERT中的一个重要参数:max_seq_length,代码定义为:

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
  • 1

  读到这里,大家可能会产生一个新的疑问,即开源的BERT模型支持序列的最大长度是512,这里为什么是128呢?具体是怎么实现的呢?本质上是通过混合批次训练(Mixed-Batch Training)来实现的。具体如原文所示:

在这里插入图片描述

  为了扩展大家的视野,介绍一篇使用非常规填充方法的论文:Effect of sequence padding on the performance of deep learning models in archaeal protein functional prediction,本文介绍了strf、zoom、ext、rnd四种填充方法,链接为https://www.nature.com/articles/s41598-020-71450-8,代码链接为https://github.com/b2slab/padding_benchmark

1.1.2 TensorFlow的写入操作

  TensorFlow为了提升对大数据的处理能力,对很多基本操作进行了重写,底层代码是C++形式的,而上层调用依然保留了Python等高级语言的形式。具体来说,比如文件的写入操作中使用的API为tf.python_io.TFRecordWriter。但同时需要注意的是,写入操作是针对文本内容和二进制内容,如果涉及到一定的数据结构,那么就是要写入二进制内容,所以在代码write(tf_example.SerializeToString())中用到了tf_example.SerializeToString(),也就是将tf_example结构转换成了二进制形式。

1.1.3 BERT模型的两大预训练任务

1.1.3.1 任务一:MLM和参数max_predictions_per_seq、masked_lm_prob

   任务一为MLM(掩码语言模型),用通俗的话来说就是做完形填空。但对于每个序列(句子)而言,完形填空遮住的词的数量最多不能超过一定的值这里则引出一个重要参数:max_predictions_per_seq,代码定义为:

flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")
  • 1
  • 2

  最多不能超过max_predictions_per_seq,指代的是不能超过上限,那么如果没有达到上限呢?简单来说,取的原有序列token数量的一小部分,这里则引出一个重要参数:masked_lm_prob,代码定义为:

flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
  • 1

  原文(P4 right)中masked_lm_prob为0.15,如下图所示:

在这里插入图片描述

  那么对于完形填空任务而言,最终需要预测的token的数量为

num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
  • 1
  • 2

  留个思考题,这里使用了round函数,那么能否使用math.floor或者math.ceil来代替呢?

1.1.3.2 任务二:Next Sentence Prediction

  任务二为Next Sentence Prediction,用通俗的话就是判断两个句子是否是上下句的关系。

  如果说MLM任务对应的参数是全局参数,那么对于Next Sentence Prediction的参数即为局部参数,也就是说该参数是对于每个TrainingInstance对象而言的,即TrainingInstance中的is_random_next参数。is_random_next为True时,则上个句子不是上下句的关系,反之则为上下句的关系。

1.2 代码详解

  首先看下代码里面是否包含循环,可以看到代码里面包括了一个for的大循环(包括了2个while的小循环,两个小循环的功能是类似的)。

  • while小循环①:序列填充之post padding zeros(不够max_seq_length的往后补0),具体为input_ids、input_mask和segment_ids。
  • while小循环②:序列填充之post padding zeros(不够max_predictions_per_seq的往后补0),具体为masked_lm_positions、masked_lm_ids、masked_lm_weights。

  tf.train.Example相比于TrainingInstance而言,有5个成员变量是一一对应的关系,增加了2个成员变量,如下图所示(左边是TrainingInstance,右边是tf.train.Example):

tokens
input_ids
segment_ids
segment_ids
is_random_next
next_sentence_labels
masked_lm_positions
masked_lm_positions
masked_lm_labels
masked_lm_ids
input_mask
masked_lm_weights

  构建tf.train.Example和TFRecord文件,可参考之前文章:TensorFlow tf.train.Example和TFRecord的实战学习。需要说明的是,每个feature本质上它也是一种智能的字典结构,所以可使用feature.int64_list.value和feature.float_list.value来提取它的值。所谓智能指代的是,当feature中的value为int类型,则feature.float_list.value则为空列表,而不会报错。

def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):
  """Create TF example files from `TrainingInstance`s."""
  writers = []
  for output_file in output_files:
    writers.append(tf.python_io.TFRecordWriter(output_file)) # 多个输出文件

  writer_index = 0

  total_written = 0
  for (inst_index, instance) in enumerate(instances):
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = list(instance.segment_ids)
    assert len(input_ids) <= max_seq_length

    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance.masked_lm_positions)
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
    masked_lm_weights = [1.0] * len(masked_lm_ids)

    while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

    next_sentence_label = 1 if instance.is_random_next else 0

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(input_mask)
    features["segment_ids"] = create_int_feature(segment_ids)
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
    features["next_sentence_labels"] = create_int_feature([next_sentence_label])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))

    writers[writer_index].write(tf_example.SerializeToString())
    writer_index = (writer_index + 1) % len(writers) # 轮番写入

    total_written += 1

    if inst_index < 20:
      tf.logging.info("*** Example ***")
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in instance.tokens]))

      for feature_name in features.keys():
        feature = features[feature_name]
        values = []
        if feature.int64_list.value:
          values = feature.int64_list.value
        elif feature.float_list.value:
          values = feature.float_list.value
        tf.logging.info(
            "%s: %s" % (feature_name, " ".join([str(x) for x in values])))

  for writer in writers:
    writer.close()

  tf.logging.info("Wrote %d total instances", total_written)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71

2. 创建整型特征 & 创建浮点型特征

  创建整型特征(create_int_feature)和创建浮点型特征(create_float_feature)和TensorFlow tf.train.Example和TFRecord的实战学习中的_float_feature和_int64_feature是一致的,这里就不进行赘述。

3. 创建训练例(create_training_instances)

3.1 重要概念

  1. TensorFlow使用tf.gfile.GFile对原有的文件操作进行了封装。
  2. random模块对数据打乱(shuffle)属于in-place操作
import random
rng = random.Random(12345)
lst = [1, 2, 3, 4, 5]
rng.shuffle(lst)
print(lst)
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述

3.2 代码详解

  首先看下代码里面是否包含循环,可以看到代码里面包括了一个for和while的二重循环。其中对于单个文件是使用逐行读取的,读取的相关代码为:

with tf.gfile.GFile(input_file, "r") as reader:
  while True:
    line = tokenization.convert_to_unicode(reader.readline())
    if not line:
      break
    line = line.strip()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

  来个思考题,能否将下列代码中的all_documents[-1].append(tokens)修改为all_documents.append(tokens),并将if not line中的语句进行删除呢?

if not line:
  all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
  all_documents[-1].append(tokens)
  • 1
  • 2
  • 3
  • 4
  • 5

  从结果来看,完全是OK的,因为all_documents = [x for x in all_documents if x]去除了空白行。

4. 从文档中创建训练例 (create_instances_from_document)

4.1 基本概念

4.1.1 添加三个特殊的tokens

  在原有instance的基础上,在开头添加[CLS]、在两个句子中间添加[SEP]、在末尾添加[SEP],如原论文中的下图所示:
在这里插入图片描述
在这里插入图片描述

  由于添加了三个特殊的tokens,所以max_num_tokens = max_seq_length - 3=128-3=125。开头的[CLS]和结尾的[SEP]是显而易见的,但中间的[SEP]究竟在哪里添加呢?具体参见4.1.3部分的内容。

4.1.2 参数short_seq_prob

  为了扩大模型的泛化性,对原有序列进行小概率(10%)的随机化处理,即取一部分tokens作为目标tokens

target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)
  • 1
  • 2
  • 3
flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")
  • 1
  • 2
  • 3
  • 4

  需要说明的是该参数在论文中并没有进行解释。

4.1.3 NSP(document & segment & current_chunk)

  Next Sentence Prediction(NSP)任务是BERT预训练的另外一个重要任务,从论文中来看是学习两个句子之间的关系,即是否为上下句的关系,如下所示:
在这里插入图片描述
  但是在实际代码中使用的并不是两个句子,而是两个句子块,所谓句子块就是一个或者多个句子。在该函数的代码中,涉及到了上述三个重要对象,那么该三者究竟是什么样的关系呢?

  document的最小单元为segment(单个句子),而多个segment组成current_chunk。其中在current_chunk中有a_end个segment(句子)是属于句子块A。

  总结一下,document指代的是所有句子,segment指代的是单个句子,current_chunk指代的是句子块(一个或者多个句子)。

4.2 代码详解

  代码中的核心部分为一个while循环,在循环中会将每个segment放入到current_chunk中,并将current_chunk中总共包含token的个数作为current_length,当current_length >= target_seq_length时,a_end取随机值即[1, len(current_chunk)-1]的值。从而得到句子块A的所有token,如下所示:

tokens_a = []
for j in range(a_end):
  tokens_a.extend(current_chunk[j])
  • 1
  • 2
  • 3

  那么如何得到句子块B的token呢?有两种可能性,一种是当current_chunk的长度为1或者50%的概率情况下,先通过随机得到一个句子,然后进行随机截断句子前面的一部分,将余下的token放入到句子块B中,代码如下所示:

random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
    tokens_b.extend(random_document[j])
    if len(tokens_b) >= target_b_length:
        break
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

  另外一种是将current_chunk中a_end及其以后的segement放入到句子块B中。

  思考题:为什么会有两种可能呢?简单来说,就是分别构构建了正负样本。

  当得到句子块A和句子块B后,会对current_chunk进行清空操作。

  可能同学会对以下两句有所疑问:

num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
  • 1
  • 2

  为了方便大家理解,举个例子,假设len(current_chunk)=4,a_end=2,此时i=5,图中蓝色的部分是没有使用的,如下图所示:
在这里插入图片描述

num_unused_segments = len(current_chunk) - a_end = 2
i -= num_unused_segments = 3
  • 1
  • 2

  需要注意的是最后还有个i+=1,则此时i=4,即蓝色的起始位置开始。

  另外需要注意的是,[CLS]和[SEP]对应的segment_ids均为0。

5. 创建对掩码语言模型的预测 (create_masked_lm_predictions)

5.1 关键概念

  在MLM中,挑选出需要遮住的token的位置,此时并不是把对应的token都替换成[MASK],而是在80%的概率下替换成[MASK],10%的替换成随机token,10%不变。

在这里插入图片描述

5.2 代码详解

  首先对cand_indexes进行打乱操作,如果不使用全词mask,那么cand_indexes打乱后的每个元素即为包括单个元素的列表,比如[[3], [4], [1], [2]],而使用全词mask后,打乱后的每个元素也是列表,但列表中可能包括了多个元素,如[[5],[2,3,4],[1]]。

  在后续代码中包括了一个for的大循环,该循环为代码的核心部分,但该部分代码是多余的,如下所示:

for index in index_set:
  if index in covered_indexes:
    is_any_index_covered = True
    break
if is_any_index_covered:
  continue
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

  在for的大循环中,最终将每个token对应的id(如果是全词mask,则将其flatten的单个元素)再放入到covered_indexes中,直至len(masked_lms) >= num_to_predict时跳出。

  使用collections模块中的namedtuple函数创建MaskedLmInstance类,该类包括两大属性即index和label,如下所示:

MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
  • 1
  • 2

  最终得到了(output_tokens, masked_lm_positions, masked_lm_labels)的元组,其中output_tokens是对原有tokens进行遮盖位置处的tokens进行了替换,替换方法为5.1节中详细进行了介绍,masked_lm_positions指的是所有遮盖的位置,masked_lm_labels指的是遮盖前位置对应的tokens。为了方便大家理解,举个例子如下:

raw text is: As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.

prior tokens is: ['[CLS]', 'ancient', 'sage', '-', '-', 'the', 'name', 'is', 'un', '##im', '##port', '##ant', 'to', 'a', 'monk', '-', '-', 'pumped', 'water', 'nightly', 'that', 'he', 'might', 'study', 'by', 'day', ',', 'so', 'i', ',', 'the', 'guardian', 'of', 'cloak', '##s', 'and', 'para', '##sol', '##s', ',', 'at', 'the', 'sacred', 'doors', 'of', 'her', 'lecture', '-', 'room', ',', 'im', '##bib', '##e', 'celestial', 'knowledge', '.', 'from', 'my', 'youth', 'i', 'felt', 'in', 'me', 'a', '[SEP]', 'fallen', 'star', ',', 'i', 'am', ',', 'sir', '!', "'", 'continued', 'he', ',', 'pens', '##ively', ',', 'stroking', 'his', 'lean', 'stomach', '-', '-', "'", 'a', 'fallen', 'star', '!', '-', '-', 'fallen', ',', 'if', 'the', 'dignity', 'of', 'philosophy', 'will', 'allow', 'of', 'the', 'simi', '##le', ',', 'among', 'the', 'hog', '##s', 'of', 'the', 'lower', 'world', '-', '-', 'indeed', ',', 'even', 'into', 'the', 'hog', '-', 'bucket', 'itself', '.', '[SEP]']

output_tokens is: ['[CLS]', 'ancient', 'sage', '[MASK]', '[MASK]', 'the', 'name', 'kang', 'un', '##im', '[MASK]', '##ant', 'to', 'a', 'monk', '-', '-', 'pumped', 'water', 'nightly', 'that', 'he', 'might', 'study', 'by', 'day', ',', 'so', 'i', '[MASK]', 'the', '[MASK]', 'of', 'cloak', '##s', '[MASK]', 'para', '##sol', '##acies', ',', 'at', 'the', 'sacred', 'doors', 'of', 'her', '[MASK]', '-', 'room', '[MASK]', 'im', '##bib', '##e', 'celestial', 'knowledge', '.', 'from', 'my', 'youth', 'i', 'felt', 'in', 'me', 'a', '[SEP]', 'fallen', 'star', ',', 'i', 'am', ',', 'bobbie', '!', "'", 'continued', 'he', ',', '[MASK]', '##ively', ',', 'stroking', 'his', 'lean', '[MASK]', '-', '-', "'", 'a', 'fallen', 'star', '!', '-', '[MASK]', 'fallen', ',', 'if', 'the', 'dignity', '[MASK]', 'philosophy', 'will', 'allow', 'of', 'the', 'simi', '##le', ',', 'among', 'the', 'hog', '[MASK]', 'of', 'the', 'lower', 'world', '-', '[MASK]', 'indeed', ',', 'even', 'into', 'the', 'hog', '-', 'bucket', 'itself', '.', '[SEP]']
masked_lm_positions is: [3, 4, 6, 7, 10, 29, 31, 35, 38, 46, 49, 71, 77, 83, 92, 98, 110, 116, 124]
masked_lm_labels is: ['-', '-', 'name', 'is', '##port', ',', 'guardian', 'and', '##s', 'lecture', ',', 'sir', 'pens', 'stomach', '-', 'of', '##s', '-', 'bucket']
output_tokens is: ['[CLS]', 'there', 'is', 'a', 'phil', '##oso', '##phic', 'pleasure', 'in', 'opening', '[MASK]', "'", 's', 'treasures', 'to', 'the', 'modest', 'young', '.', '[SEP]', 'rain', 'had', 'only', 'ceased', 'with', '[MASK]', 'gray', 'streaks', 'of', 'morning', 'at', 'blazing', 'star', ',', '[MASK]', 'the', 'settlement', 'awoke', 'to', 'a', 'moral', 'sense', 'of', 'clean', 'akron', '16th', '[MASK]', 'the', 'finding', 'of', 'forgotten', 'knives', ',', 'tin', 'cups', ',', 'and', 'smaller', 'camp', 'ut', '##ens', '##ils', ',', 'where', 'the', '[MASK]', 'showers', 'had', 'washed', 'away', 'the', 'debris', 'and', 'dust', 'heap', '[MASK]', 'before', 'the', 'cabin', 'doors', '.', 'indeed', '[MASK]', '[MASK]', 'was', 'recorded', 'in', 'blazing', '[MASK]', 'that', 'a', 'fortunate', '[MASK]', '[MASK]', '[MASK]', 'had', 'once', 'picked', 'up', 'on', '[MASK]', 'highway', 'a', 'solid', 'chunk', '[MASK]', '[MASK]', 'quartz', 'which', 'the', '[MASK]', 'had', 'freed', 'from', 'its', 'inc', '##umber', '##ing', 'soil', ',', 'and', 'washed', 'into', 'immediate', 'and', 'glittering', 'popularity', '[SEP]']
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

6. 截断序列对 truncate_seq_pair

  主体是一个while循环,该循环是个单层循环。Python列表赋值时是传递地址,而不是传递值。举例来说:

list_value = [1, 2, 3]
a = list_value
del a[0]
print(list_value)
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

  • 第一个if条件语句说明的是如果a对应的token个数和b的个数之和小于等于max_num_tokens,则跳出该循环。
  • if else的三元操作符主要是在截断序列对时,保持a和b长度的平衡。为什么要保持平衡,假如len(b)=1,而a的token个数很多时,就使得序列对的语义就变成了a的语义。
  • 最后的if和else表明的含义是有一半的概率是从尾部截断,有另外一半概率是从头部截断。
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
  """Truncates a pair of sequences to a maximum sequence length."""
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_num_tokens:
      break

    trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    assert len(trunc_tokens) >= 1

    # We want to sometimes truncate from the front and sometimes from the
    # back to add more randomness and avoid biases.
    if rng.random() < 0.5:
      del trunc_tokens[0]
    else:
      trunc_tokens.pop()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

  那么如何加强对该函数的的印象呢?可以把tokens_a和tokens_b分别想象成两个瓶子,每个瓶子从上到下装多个球。每一次判断两个瓶子里面的球如果小于规定数,则结束循环。反之,则对球多的瓶子里面进行取球,取球要么是从最上面取要么是从最下面取。

7. 主函数

  读取得到关键参数如下:

  • tokenizer(分词器)
  • input_files(输入文件)
  • rng(随机种子)

  并最终执行训练例对象写入文件,另外需要注意的是,mark_flag_as_required指定的是必要输入参数,包括input_file、output_file、vocab_file。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/669969
推荐阅读
相关标签
  

闽ICP备14008679号