当前位置:   article > 正文

Bert源码分析_bert扩展张量的维度

bert扩展张量的维度

Bert源码分析

本文主要从bert源码文件的run_classifier.py入手,观察研究整个fine-turning的数据流动过程。

在main函数下,程序首先检查各种输入参数是否正确合理,已经程序运行环境。然后开始构建模型

1.微调之前的数据加载

这部分后面补充

2.分类模型构建

run_classifier.py通过调用函数model_fn_builder()中的create_model()函数来构建微调模型,所谓的微调其实就是在Bert大模型后面又接了一层线性层。构建代码如下:

"""Creates a classification model."""
model=modeling.BertModel(config=bert_config,is_training=is_training,input_ids=input_ids,input_mask=input_mask,token_type_ids=segment_ids,use_one_hot_embeddings=use_one_hot_embeddings)
# 得到Bert模型的最后一层输出
output_layer = model.get_pooled_output()
# 获取隐藏层的大小
hidden_size = output_layer.shape[-1].value
# 初始化参数
output_weights = tf.get_variable("output_weights", [num_labels,hidden_size],initializer=tf.truncated_normal_initializer(stddev=0.02))
# 初始化偏置项
output_bias = tf.get_variable("output_bias",[num_labels],initializer=tf.zeros_initializer())
# 定义分类模型
with tf.variable_scope("loss"):  
    if is_training:    
        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)  
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)  
        logits = tf.nn.bias_add(logits, output_bias)  
        probabilities = tf.nn.softmax(logits, axis=-1)  
        log_probs = tf.nn.log_softmax(logits, axis=-1)  
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)  
        loss = tf.reduce_mean(per_example_loss)  
        return (loss, per_example_loss, logits, probabilities)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

3.Bert模型的构建

从上述代码中可见看到,后面的分类层是一个很简单的线性层,Bert的主要性能贡献全部都在预训练层,接下来详解代码:

model=modeling.BertModel(config=bert_config,is_training=is_training,input_ids=input_ids,input_mask=input_mask,token_type_ids=segment_ids,use_one_hot_embeddings=use_one_hot_embeddings)
  • 1

这就需要把目光移到Bert源码文件中的modeling文件中,BertModel是该文件下的一个类,在类实例化的过程中,有一个初始化操作,就是modeling文件中

BertModel.__int__(self,config,is_training,input_ids,input_mask=None,token_type_ids=None,
use_one_hot_embeddings=False,scope=None):
  • 1
  • 2

初始化函数构建模型,当然再次之前,Bert模型的超参数通过BertConfig类已经加载到内存中了,从这个初始化函数中就可以慢慢还原出Bert模型的全貌。

3.1Bert模型超参数解读

以Bert-base模型为例:

参数设置 表示含义
attention_probs_dropout_prob: 0.1 attention层drop_out概率为0.1
hidden_act: gelu 隐藏层激活函数为gelu
hidden_dropout_prob: 0.1 隐藏层drop_out概率为0.1
hidden_size: 768 隐藏层节点说为768,这个好像就是768个神经元节点
initializer_range: 0.02 参数全部初始化为标准差为0.02的随机数
intermediate_size: 3072 transformer中前馈层的大小(feed_forward)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/348682
推荐阅读
相关标签
  

闽ICP备14008679号