
BERT series: a BERT binary classification demo, explained (BertForSequenceClassification.from_pretrained)


Main content:

A binary classification demo written with PyTorch and the Hugging Face transformers library.

1. Define the classes. Store the texts in one list and the labels in another. This is a binary classification example, so the labels are [0, 1]; for multi-class classification they would be [0, 1, 2, ..., n]. Each text's position in 'text' must match the position of its label in 'target'.

2. Encode the input. Raw Chinese text cannot be fed to the model directly; each text is converted into token ids using the tokenizer's vocabulary, and those ids are what the model actually receives.

3. Pair the texts with their labels, wrap them in a dataset, and serve them through a DataLoader, then start training the model.

4. Use the AdamW optimizer to update the parameters, iterating until training stops.

# -*- encoding:utf-8 -*-
import random
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import BertTokenizer, BertConfig
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, accuracy_score
import numpy as np

# The tokenizer encodes raw text into token ids
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Training data
train = {
    'text': [' 测试good',
             '美团 学习',
             ' 测试good',
             '美团 学习',
             ' 测试good',
             '美团 学习',
             ' 测试good',
             '美团 学习'],
    'target': [0, 1, 0, 1, 0, 1, 0, 1],
}

# Get text values and labels
text_values = train['text']
labels = train['target']

print('Original Text : ', text_values[0])
print('Tokenized Ids: ', tokenizer.encode(text_values[0], add_special_tokens=True))
print('Tokenized Text: ', tokenizer.decode(tokenizer.encode(text_values[0], add_special_tokens=True)))
print('Token IDs : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text_values[0])))

# Function to get token ids for a list of texts
def encode_fn(text_list):
    all_input_ids = []
    for text in text_list:
        input_ids = tokenizer.encode(
            text,
            add_special_tokens=True,   # add the special tokens [CLS] and [SEP]
            max_length=160,            # maximum sequence length
            padding='max_length',      # pad to max_length (replaces the deprecated pad_to_max_length=True)
            truncation=True,           # truncate sequences longer than max_length
            return_tensors='pt'        # return PyTorch tensors
        )
        all_input_ids.append(input_ids)
    all_input_ids = torch.cat(all_input_ids, dim=0)
    return all_input_ids

# Encode the training data
all_input_ids = encode_fn(text_values)
labels = torch.tensor(labels)

# Training hyperparameters
epochs = 1
batch_size = 1

# Split data into train and validation
dataset = TensorDataset(all_input_ids, labels)
train_size = int(0.75 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create train and validation dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained BERT model; num_labels=2 means two classes
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2, output_attentions=False, output_hidden_states=True)
print(model)
# model.cuda()

# Create optimizer and learning rate schedule
# (recent transformers versions deprecate transformers.AdamW; torch.optim.AdamW works the same way here)
optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * epochs
# Warm the learning rate up for num_warmup_steps steps, then decay it linearly to zero over the remaining steps
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

flag = False
total_batch, last_improve = 0, 0
require_improvement = 1000
dev_best_loss = float('inf')   # best validation loss seen so far (missing in the original listing)

for epoch in range(epochs):
    model.train()
    total_loss, total_val_loss = 0, 0
    # Training loop
    for step, batch in enumerate(train_dataloader):
        # Zero the gradients
        model.zero_grad()
        # Forward pass: with labels provided, outputs[0] is the loss and outputs[1] the logits
        # (hidden states are also returned but unused here)
        outputs = model(batch[0], token_type_ids=None, attention_mask=(batch[0] > 0),
                        labels=batch[1])
        loss, logits = outputs[0], outputs[1]
        total_loss += loss.item()
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Update parameters and learning rate
        optimizer.step()
        scheduler.step()

    # model.eval() switches the model to evaluation mode (dropout disabled);
    # together with torch.no_grad() below, no parameters are updated during validation
    model.eval()
    for i, batch in enumerate(val_dataloader):
        with torch.no_grad():
            outputs = model(batch[0], token_type_ids=None, attention_mask=(batch[0] > 0),
                            labels=batch[1])
            loss, logits = outputs[0], outputs[1]
            print(loss, logits)
            total_val_loss += loss.item()
            logits = logits.detach().cpu().numpy()
            label_ids = batch[1].to('cpu').numpy()

    avg_val_loss = total_val_loss / len(val_dataloader)
    if avg_val_loss < dev_best_loss:
        dev_best_loss = avg_val_loss
        last_improve = total_batch
    if total_batch - last_improve > require_improvement:
        # Stop training if the validation loss has not improved for a long time
        print("No optimization for a long time, auto-stopping...")
        flag = True
        break
    if flag:
        break
    total_batch = total_batch + 1
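
The script imports f1_score and accuracy_score but never calls them. Below is a minimal sketch, not part of the original demo, of how one might report accuracy and macro-F1 on the validation split; evaluate is a hypothetical helper that assumes the same (input_ids, labels) batches produced by val_dataloader above.

import torch
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, dataloader):
    # Collect predictions and gold labels over the whole dataloader,
    # then report accuracy and macro-F1 with the sklearn metrics imported above.
    model.eval()
    preds, golds = [], []
    with torch.no_grad():
        for input_ids, batch_labels in dataloader:
            outputs = model(input_ids, attention_mask=(input_ids > 0))
            logits = outputs[0]   # outputs[0] is the logits when no labels are passed
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            golds.extend(batch_labels.cpu().numpy())
    print('accuracy:', accuracy_score(golds, preds))
    print('macro-F1:', f1_score(golds, preds, average='macro'))

# evaluate(model, val_dataloader)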

 
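The demo also stops at training and never shows how to score a new sentence. Here is a minimal inference sketch, assuming the model and tokenizer objects from the script above are still in memory; predict is a hypothetical helper, not part of the original code.

import torch

def predict(text, model, tokenizer, max_length=160):
    # Encode a single sentence the same way encode_fn does, run a forward pass
    # without labels, and return the arg-max class id (0 or 1 in this demo).
    model.eval()
    input_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=(input_ids > 0))
    return int(outputs[0].argmax(dim=1).item())

# print(predict('美团 学习', model, tokenizer))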
