当前位置:   article > 正文

jupyter快速实现单标签及多标签多分类的文本分类BERT模型_bert模型jupyter notebook版本

bert模型jupyter notebook版本

jupyter实现pytorch版BERT(单标签分类版)

nlp-notebooks/Text classification with BERT in PyTorch.ipynb

通过改写上述代码,实现多标签分类

参考解决方案 ,我选择的解决方案是继承BertForSequenceClassification并改写,即将上述代码的ln [9] 改为以下内容:

from transformers.modeling_bert import BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

class BertForMultilabelSequenceClassification(BertForSequenceClassification):
   def __init__(self, config):
     super().__init__(config)

   def forward(self,
       input_ids=None,
       attention_mask=None,
       token_type_ids=None,
       position_ids=None,
       head_mask=None,
       inputs_embeds=None,
       labels=None,
       output_attentions=None,
       output_hidden_states=None,
       return_dict=None):
       return_dict = return_dict if return_dict is not None else self.config.use_return_dict

       outputs = self.bert(input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids,
           position_ids=position_ids,
           head_mask=head_mask,
           inputs_embeds=inputs_embeds,
           output_attentions=output_attentions,
           output_hidden_states=output_hidden_states,
           return_dict=return_dict)

       pooled_output = outputs[1]
       pooled_output = self.dropout(pooled_output)
       logits = self.classifier(pooled_output)

       loss = None
       if labels is not None:
           loss_fct = torch.nn.BCEWithLogitsLoss()
           loss = loss_fct(logits.view(-1, self.num_labels), 
                           labels.float().view(-1, self.num_labels))

       if not return_dict:
           output = (logits,) + outputs[2:]
           return ((loss,) + output) if loss is not None else output

       return SequenceClassifierOutput(loss=loss,
           logits=logits,
           hidden_states=outputs.hidden_states,
           attentions=outputs.attentions)
           
model = BertForMultilabelSequenceClassification.from_pretrained(BERT_MODEL, num_labels = len(label2idx))
model.to(device)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/447899
推荐阅读
相关标签
  

闽ICP备14008679号