
BertForSequenceClassification


Source file (inside the local conda environment): Anaconda3\envs\env_name\Lib\site-packages\transformers\models\bert\modeling_bert.py
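
If the environment name or install path differs, the same file can be located from Python itself; a minimal sketch:

    import inspect
    from transformers import BertForSequenceClassification

    # Prints the path of the file that defines BertForSequenceClassification
    # (modeling_bert.py inside the installed transformers package).
    print(inspect.getsourcefile(BertForSequenceClassification))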

    # Excerpt from modeling_bert.py; torch, nn, the loss classes and the docstring
    # helpers used below are imported at the top of that file.
    class BertForSequenceClassification(BertPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            self.config = config

            self.bert = BertModel(config)
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

            self.init_weights()

        @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
        @add_code_sample_docstrings(
            tokenizer_class=_TOKENIZER_FOR_DOC,
            checkpoint=_CHECKPOINT_FOR_DOC,
            output_type=SequenceClassifierOutput,
            config_class=_CONFIG_FOR_DOC,
        )
        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        ):
            r"""
            labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
                Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
                config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
                If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
            """
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            # [CLS] pooled output -> dropout -> linear classification head
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)

            loss = None
            if labels is not None:
                # Infer the problem type from num_labels and the label dtype if it was not set explicitly
                if self.config.problem_type is None:
                    if self.num_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"

                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_labels == 1:
                        loss = loss_fct(logits.squeeze(), labels.squeeze())
                    else:
                        loss = loss_fct(logits, labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss = loss_fct(logits, labels)

            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output

            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )


    # The decorator below belongs to the next class in modeling_bert.py
    # (BertForMultipleChoice); its body is not reproduced here.
    @add_start_docstrings(
        """
        Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
        softmax) e.g. for RocStories/SWAG tasks.
        """,
        BERT_START_DOCSTRING,
    )
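
A short usage sketch of the class above; the checkpoint name and label count are illustrative assumptions, not something from the original post:

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification

    # Load a pretrained checkpoint and attach a fresh 2-way classification head.
    # "bert-base-uncased" and num_labels=2 are assumptions for illustration.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model.eval()

    inputs = tokenizer("The movie was great!", return_tensors="pt")
    labels = torch.tensor([1])  # long dtype -> the "single_label_classification" branch above

    with torch.no_grad():
        outputs = model(**inputs, labels=labels)

    print(outputs.loss)          # CrossEntropyLoss over the logits
    print(outputs.logits.shape)  # (batch_size, num_labels) == (1, 2)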

The configuration that supplies hidden_size, num_labels, hidden_dropout_prob and so on is BertConfig (transformers\models\bert\configuration_bert.py); its __init__ is reproduced below:

    class BertConfig(PretrainedConfig):
        def __init__(
            self,
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            gradient_checkpointing=False,
            position_embedding_type="absolute",
            use_cache=True,
            **kwargs
        ):
            super().__init__(pad_token_id=pad_token_id, **kwargs)

            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
            self.gradient_checkpointing = gradient_checkpointing
            self.position_embedding_type = position_embedding_type
            self.use_cache = use_cache
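
A minimal sketch of building a randomly initialized classifier from such a config; num_labels=3 is just an illustrative choice:

    from transformers import BertConfig, BertForSequenceClassification

    # num_labels controls the size of the classification head
    # (nn.Linear(hidden_size, num_labels)) and, together with the label dtype,
    # which loss branch is taken in forward().
    config = BertConfig(
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_labels=3,  # assumption: a 3-class task, for illustration
    )

    # Randomly initialized model with the head sized from the config.
    model = BertForSequenceClassification(config)
    print(model.classifier)  # Linear(in_features=768, out_features=3, bias=True)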
