
A Detailed Walkthrough of BERT's Main Network Code

BERT (Bidirectional Encoder Representations from Transformers) was released by Google AI Language in 2018 (and published at NAACL 2019), and in a sense it opened a new era for NLP. It uses the encoder stack of the Transformer and is pre-trained on two tasks, masked language modeling (MLM) and next sentence prediction (NSP), over large amounts of unlabeled data; the pre-trained model is then fine-tuned on the datasets of downstream tasks.
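To make the MLM objective concrete, below is a toy, self-contained sketch of the masking rule described in the paper; the vocabulary and the mask_tokens helper are invented here purely for illustration and are not part of the repository. Roughly 15% of the input tokens are selected, and of those 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged; the model is then trained to recover the original tokens at the selected positions.

import random

# Toy vocabulary, used only for the "replace with a random token" branch.
VOCAB = ["the", "cat", "sat", "on", "mat", "dog", "ran"]

def mask_tokens(tokens, mask_prob=0.15):
    """Toy illustration of the MLM corruption rule (15% selected; of those
    80% -> [MASK], 10% -> random token, 10% -> unchanged)."""
    inputs, labels = [], []
    for tok in tokens:
        if random.random() < mask_prob:
            labels.append(tok)          # the model must recover the original token here
            r = random.random()
            if r < 0.8:
                inputs.append("[MASK]")
            elif r < 0.9:
                inputs.append(random.choice(VOCAB))
            else:
                inputs.append(tok)
        else:
            labels.append(None)         # this position is not scored by the MLM loss
            inputs.append(tok)
    return inputs, labels

print(mask_tokens("the cat sat on the mat".split()))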

Its counterpart is OpenAI's GPT series, which instead uses the decoder stack of the Transformer. In terms of citations and adoption, however, BERT's influence has been at least an order of magnitude greater than that of the GPT series, and it has been widely used by researchers.

Here I share my walkthrough of BERT's main network code. Since my time was limited this was written somewhat hastily, so mistakes are inevitable; corrections and feedback are very welcome.

For the code I use a third-party PyTorch re-implementation.

paper:https://arxiv.org/pdf/1810.04805.pdf

code:https://github.com/codertimo/BERT-pytorch

1. The main BERT network code. I have not included the internal details of the Transformer (multi-head attention, sublayer connection, feed-forward network); readers who need them can find them in the repository above.

import torch.nn as nn
import torch
import math

from .attention import MultiHeadedAttention
from .utils import SublayerConnection, PositionwiseFeedForward


#-----------------------------------------------------------------#
# BERT: Bidirectional Encoder Representations from Transformers
#-----------------------------------------------------------------#
class BERT(nn.Module):
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab size of total words
        :param hidden: BERT model hidden size
        :param n_layers: number of Transformer blocks (encoder layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super(BERT, self).__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        #----------------------------------------------#
        # For the feed-forward network, the paper uses
        # feed_forward_hidden = 4 * hidden_size
        #----------------------------------------------#
        self.feed_forward_hidden = hidden * 4
        #----------------------------------------------#
        # The BERT input embedding is the sum of the
        # token, segment and positional embeddings
        #----------------------------------------------#
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
        #-------------------------------------------------#
        # A stack of n_layers Transformer encoder blocks
        #-------------------------------------------------#
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention mask for padded tokens (token id 0),
        # shape: torch.ByteTensor [batch_size, 1, seq_len, seq_len]
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        #---------------------------------------------------#
        # Embed the token-id sequence into a sequence of vectors
        #---------------------------------------------------#
        x = self.embedding(x, segment_info)
        #---------------------------------------------------#
        # Pass through the Transformer encoder layers
        #---------------------------------------------------#
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x
#---------------------------------------------------------------------------------#
# Definition of the BERTEmbedding class.
# The BERT embedding consists of the following features:
#   1. TokenEmbedding      : normal embedding matrix
#   2. PositionalEmbedding : positional information using sin/cos
#   3. SegmentEmbedding    : sentence segment info (sent_A: 1, sent_B: 2)
#---------------------------------------------------------------------------------#
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param dropout: dropout rate
        """
        super(BERTEmbedding, self).__init__()
        self.embed_size = embed_size
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)

    #------------------------------------------------------------------------#
    # Sum the token, positional and segment embeddings, then apply dropout to
    # reduce the risk of overfitting; this is the output of BERTEmbedding.
    #------------------------------------------------------------------------#
    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)


#-----------------------------------------------#
# TokenEmbedding and SegmentEmbedding:
# both inherit from nn.Embedding
#-----------------------------------------------#
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512):
        super(TokenEmbedding, self).__init__(vocab_size, embed_size, padding_idx=0)


class SegmentEmbedding(nn.Embedding):
    def __init__(self, embed_size=512):
        super(SegmentEmbedding, self).__init__(3, embed_size, padding_idx=0)


#-----------------------------------#
# Definition of PositionalEmbedding
#-----------------------------------#
class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEmbedding, self).__init__()
        #--------------------------------------------#
        # Pre-compute the positional encoding table
        #--------------------------------------------#
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        #-----------------------------------------#
        # sin fills the even embedding dimensions,
        # cos fills the odd embedding dimensions
        #-----------------------------------------#
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        #------------------------------------------#
        # Register pe as a buffer so that it is saved
        # with the weights but never updated by gradients
        #------------------------------------------#
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]
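#------------------------------------------------------------------#
# For reference, the table built above is the sinusoidal encoding
# from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# where div_term = 10000^(-2i / d_model), computed as
# exp(-(2i) * ln(10000) / d_model).
#------------------------------------------------------------------#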
#-----------------------------------------#
# Definition of the Transformer block.
# BERT only uses the encoder part of the
# Transformer; positional and token embeddings
# are handled in BERTEmbedding above.
#-----------------------------------------#
class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: number of heads of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    #-------------------------------------#
    # 1. multi-head self-attention
    # 2. input_sublayer  (add & norm)
    # 3. position-wise feed-forward
    # 4. output_sublayer (add & norm)
    # 5. dropout
    #-------------------------------------#
    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)
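To make the tensor shapes concrete, here is a minimal usage sketch of the BERT class above. It is my own illustration rather than code from the repository, and it assumes the repository's MultiHeadedAttention, SublayerConnection and PositionwiseFeedForward modules are importable so that TransformerBlock can be built. Token id 0 is treated as padding, which is why the attention mask is derived from x > 0.

import torch

# Toy configuration; the class defaults (hidden=768, n_layers=12, attn_heads=12)
# correspond to BERT-base.
model = BERT(vocab_size=1000, hidden=256, n_layers=2, attn_heads=4)

batch_size, seq_len = 2, 16
x = torch.randint(1, 1000, (batch_size, seq_len))       # token ids; 0 is reserved for padding
x[:, -4:] = 0                                           # pretend the last 4 positions are padding
segment_info = torch.ones(batch_size, seq_len).long()   # 1 = sentence A, 2 = sentence B

out = model(x, segment_info)
print(out.shape)   # torch.Size([2, 16, 256]): one hidden vector per input position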

2. The models for BERT's two pre-training tasks: 1. MLM (Masked Language Model); 2. the Next Sentence Prediction model.

import torch.nn as nn

from .bert import BERT


#----------------------------------------------------------------#
# Definition of the BERT Language Model:
# Next Sentence Prediction Model + Masked Language Model.
# The annotation "bert: BERT" is a type hint: bert is the parameter
# name used inside this class, BERT is the class being passed in.
#----------------------------------------------------------------#
class BERTLM(nn.Module):
    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """
        super(BERTLM, self).__init__()
        self.bert = bert
        #-----------------------------------------#
        # Heads for the two pre-training tasks
        #-----------------------------------------#
        self.next_sentence = NextSentencePrediction(self.bert.hidden)
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    #-----------------------------------------------#
    # Returns the log-probability outputs of both heads
    #-----------------------------------------------#
    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)


#-----------------------------------------------------#
# Next Sentence Prediction model:
# 2-class classification model: is_next, is_not_next
#-----------------------------------------------------#
class NextSentencePrediction(nn.Module):
    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super(NextSentencePrediction, self).__init__()
        #-------------------------------------------------#
        # The linear layer maps the BERT output size -> 2,
        # corresponding to the classes is_next / is_not_next.
        # dim=-1 applies LogSoftmax along the last dimension.
        #-------------------------------------------------#
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    #-----------------------------------------------#
    # x[:, 0] takes the hidden state of the first
    # token ([CLS]) of every sequence in the batch
    #-----------------------------------------------#
    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))


#------------------------------------------------------#
# Masked Language Model head:
# predicts the original token at each masked position,
# an n-class classification problem with n = vocab_size
#------------------------------------------------------#
class MaskedLanguageModel(nn.Module):
    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super(MaskedLanguageModel, self).__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    #-----------------------------------------------#
    # Apply the linear layer to x, then output the
    # log-probability distribution over the vocabulary
    #-----------------------------------------------#
    def forward(self, x):
        return self.softmax(self.linear(x))
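As a final illustration (again my own sketch, not code from the repository): both heads end in LogSoftmax, so the natural training criterion is NLLLoss. The MLM loss is typically computed with ignore_index set to 0 so that padded and un-masked positions, whose label is 0, do not contribute to the loss; the label tensors below are dummy values just to show the expected shapes.

import torch
import torch.nn as nn

bert = BERT(vocab_size=1000, hidden=256, n_layers=2, attn_heads=4)
model = BERTLM(bert, vocab_size=1000)

x = torch.randint(1, 1000, (2, 16))          # token ids; 0 = padding
segment_label = torch.ones(2, 16).long()     # 1 = sentence A, 2 = sentence B

nsp_out, mlm_out = model(x, segment_label)
print(nsp_out.shape)   # [2, 2]: is_next / is_not_next log-probabilities (from the [CLS] position)
print(mlm_out.shape)   # [2, 16, 1000]: per-position log-probabilities over the vocabulary

# NLLLoss expects log-probabilities, which is exactly what the two heads return.
nsp_criterion = nn.NLLLoss()
mlm_criterion = nn.NLLLoss(ignore_index=0)   # 0 marks positions that were not masked

is_next = torch.randint(0, 2, (2,))          # dummy NSP labels
mlm_labels = torch.randint(0, 1000, (2, 16)) # dummy MLM labels (0 where not masked)
loss = nsp_criterion(nsp_out, is_next) + mlm_criterion(mlm_out.transpose(1, 2), mlm_labels)
loss.backward()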
