赞
踩
看了好多 Bert 的介绍文章,少有从源码层面理解Bert模型的,本文章将根据一行一行源码来深入理解Bert模型
BertBertModel 类的源码如下(删除注释)
class BertModel(BertPreTrainedModel): def __init__(self, config): super(BertModel, self).__init__(config) self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 embedding_output = self.embeddings(input_ids, token_type_ids) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return encoded_layers, pooled_output
首先分析BertBertModel 类的__init__函数(构造函数),__init__函数分别定义了embeddings 、encoder 、pooler三个模块,这三个模块分别是词嵌入模块,encoder模块和分类层模块;然后对模型的参数进行了初始化,通过__init__函数我们也能了解到,Bert主要就是由这三个模块组成,关于这三个模块的详细解析请参考以下文章:
forward函数是Bert模型的向前传播函数,作用是将数据从模型的输入传送到输出;
forward函数有四个参数,分别是 input_ids、token_type_ids、attention_mask 和 output_all_encoded_layers,四个参数的解释如下:
参数 | 含义 | 维度 |
---|---|---|
input_ids | token在词汇表中索引组成的数组 | [batch_size, sequence_length] |
token_type_ids | 用于标识当前token属于哪一个句向量(0属于第一句,1属于第二句) | [batch_size, sequence_length] |
attention_mask | 如果输入序列长度小于当前批次中的最大输入序列长度,则使用此掩码,用于指示序列的那些输入需要被Mask,当前位置是小于等于真实长度值为1 大于为0 | [batch_size,sequence_length] |
output_all_encoded_layers | True:输出全部12层encoder的输出,False:只输出最后一层encoder的值 | / |
在 forward 函数中,首先对 token_type_ids
和 attention_mask
参数为None值的情况进行了处理;当 token_type_ids
为 None 时,生成一个 [batch_size, sequence_length]
形状的数组赋值给token_type_ids
并将 token_type_ids
所有位置置为0,表示每个序列中只包含一个句子;当attention_mask
为None
时,生成一个[batch_size, sequence_length]
形状的数组赋值给attention_mask
并将attention_mask
中的所有值置为1,表示当前序列的所有数据都为有效数据;
然后 extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
这句代码的含义是将extended_attention_mask 维度连续扩充两次,依次是 [batch_size,sequence_length]->[batch_size,1,sequence_length] ->[batch_size,1,1,sequence_length]
关于 Pytorch 中 unsqueeze()和 squeeze()函数的详细介绍,请参考这篇博客->#彻底理解# pytorch 中的 squeeze() 和 unsqueeze()函数
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
(上一行是格式转换,此处略过)这行代码的作用是将 extended_attention_mask 中的 0 变为 -10000,(0表示此位置,不是有效数据,当-10000进入self attention后,权重会变得非常小乃至可以这些token)
embedding_output = self.embeddings(input_ids, token_type_ids)
这行代码的作用是根据 input_ids,token_type_ids 两个矩阵生成 token 对应的 word embedding,( 熟悉 Bert 原理的朋友都知道,除了这两个输入矩阵,还有一个位置编码矩阵,这个矩阵会在 embeddings 模块内部自动生成)
encoded_layers = self.encoder...
这句代码的意思是将 embedding 层的输出输入到 encoder
模块 并将 encoder 模块的输出赋值给 encoded_layers
,其中encoded_layers是一个长度为12的数组,保存着12层encoder每层对应的输出
sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output)
这两句代码的作用是生成 Bert 的pooled_output
输出;做法是:将encoder
最后一层的输出赋值给 Bert 的 pooler
模块(输入768,输出768,tanh激活函数的全连接层)并将结果返回;虽然pool模块的输入为768,但只使用第一个 token([CLS])的对应的输入,因此 pooled_output
输出一般用于解决 语句级别的任务。
根据 output_all_encoded_layers
参数的值返回所有12层encoder
的输出或仅输出最后一层encoder
的输出;同时输出 pooled_output
想深入了解Bert模型原理的朋友可以阅读这5篇文章:
接下来将进一步分析 Bert 模型中的 embeddings、encoder、pooler 等模块,链接如下:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。