当前位置:   article > 正文

pytorch bert文本分类_一起读Bert文本分类代码 (pytorch篇 五)

pytorch_pretrained_bert.modeling

bbf595021088df7c73eca73b859c1fba.png

Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了。这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码examples里的文本分类任务run_classifier。

关于源代码可以在huggingface的github中找到。

huggingface/pytorch-pretrained-BERT​github.com
2b4b57c491ca5e3e78c292173f50c2a6.png

在前四篇文章中我分别介绍了数据预处理部分和部分的模型:

周剑:一起读Bert文本分类代码 (pytorch篇 一)​zhuanlan.zhihu.com
596fda2a9089759e0a51b75ebda903c2.png
周剑:一起读Bert文本分类代码 (pytorch篇 二)​zhuanlan.zhihu.com
596fda2a9089759e0a51b75ebda903c2.png
周剑:一起读Bert文本分类代码 (pytorch篇 三)​zhuanlan.zhihu.com
596fda2a9089759e0a51b75ebda903c2.png
周剑:一起读Bert文本分类代码 (pytorch篇 四)​zhuanlan.zhihu.com
596fda2a9089759e0a51b75ebda903c2.png

我们可以看到BertForSequenceClassification类中调用关系如下图所示。本篇文章中我会带着大家继续读BertLayer类中的BertIntermediate和BertOutput类。

0bfe5f36bc666ad162e7a3dd064bacca.png

打开pytorch_pretrained_bert.modeling.py,找到BertIntermediate类,代码如下:

  1. class BertIntermediate(nn.Module):
  2. def __init__(self, config):
  3. super(BertIntermediate, self).__init__()
  4. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  5. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  6. if isinstance(config.hidden_act, str) else config.hidden_act
  7. def forward(self, hidden_states):
  8. hidden_states = self.dense(hidden_states)
  9. hidden_states = self.intermediate_act_fn(hidden_states)
  10. return hidden_states

我们可以看到dense是一个线形Linear层,输入size是config.hidden_size,输出size是config.intermediate_size。ACT2FN是激活函数的字典,它的代码如下:

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

所以BertIntermediate是一个线形Linear层加激活函数。

再找到BertOutput类,代码如下:

  1. class BertOutput(nn.Module):
  2. def __init__(self, config):
  3. super(BertOutput, self).__init__()
  4. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  5. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  6. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  7. def forward(self, hidden_states, input_tensor):
  8. hidden_states = self.dense(hidden_states)
  9. hidden_states = self.dropout(hidden_states)
  10. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  11. return hidden_states

可以看到BertOutput是一个输入size是config.intermediate_size,输出size是config.hidden_size。又把size从BertIntermediate中的config.intermediate_size变回config.hidden_size。然后又接了一个Dropout和一个归一化。

到此为止,我就已经和大家一起读了model里调用的所有函数。我们再从数据的forward向来总结一下代码中的BertForSequenceClassification模型。

在BertForSequenceClassification这个model中,输入的数据首先经过BertEmbeddings类。在BertEmbeddings中将每个单词变为words_embeddings + position_embeddings +token_type_embeddings三项embeddings的和。

然后,把已经变为词向量的数据输入BertSelfAttention类中。BertSelfAttention类中是一个Multi-Head Attention(少一个Linear层), 也就是说数据流入这个少一个Linear层的Multi-Head Attention。

之后,数据流入BertSelfOutput类。BertSelfOutput是一个Linear+Dropout+LayerNorm。补齐了BertSelfAttention中少的那个Linear层,并且进行一次LayerNorm。这样就完成了Transformer中前半的任务,即下图的红框部分。

29ef188439c38cf85be1164f72fa57a2.png

再之后,数据经过BertIntermediate和BertOutput。他们分别是今天介绍的Linear层+激活函数和Linear+Dropout+LayerNorm。这样整个Transformer的部分就算完成了。

最后,数据再流回BertForSequenceClassification这个类中,经过一个Linear层分类,输出变为一个和标签size大小一致的列表。这就是整个BertForSequenceClassification模型。

在下一篇文章中,我会重回run_classifier.py的主函数。和大家一起读代码中优化器,训练和预测部分。

周剑:一起读Bert文本分类代码 (pytorch篇 六)​zhuanlan.zhihu.com
596fda2a9089759e0a51b75ebda903c2.png
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/312386
推荐阅读
相关标签
  

闽ICP备14008679号