
Fine-tuning deberta-v3-large for text classification on the emotion dataset

The emotion dataset

Overview

Dataset:

1. SetFit/emotion at main

Model:

1. microsoft/deberta-v3-large · Hugging Face

(Note: I downloaded both the dataset and the model to a local directory beforehand.)
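One way to fetch the model and dataset ahead of time is the following minimal sketch, assuming huggingface_hub and the datasets library are installed; the local directory name is just an example chosen to match the paths used later in this post.

from huggingface_hub import snapshot_download
from datasets import load_dataset

# Download every file of the model repo into ./deberta-v3-large
snapshot_download(repo_id="microsoft/deberta-v3-large", local_dir="deberta-v3-large")

# Download the emotion dataset once; it is cached locally afterwards
emotions = load_dataset("emotion")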

Fine-tuning

Code:

from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the locally downloaded tokenizer and model; the emotion dataset has 6 labels
tokenizer = AutoTokenizer.from_pretrained("deberta-v3-large")
model = AutoModelForSequenceClassification.from_pretrained("deberta-v3-large", num_labels=6)

emotions = load_dataset("emotion")

# Tokenize every split in one pass (batch_size=None feeds each split as a single batch)
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

tokenized_emotions = emotions.map(tokenize, batched=True, batch_size=None)
tokenized_emotions.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Accuracy and weighted F1, computed on the validation split during evaluation
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

training_args = TrainingArguments(output_dir="result")
trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=tokenized_emotions["train"],
                  eval_dataset=tokenized_emotions["validation"])
trainer.train()
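The run above keeps the default TrainingArguments (3 epochs, per-device batch size 8, learning rate 5e-5). A hedged sketch of spelling out the main hyperparameters, in case you want per-epoch evaluation or a smaller batch to fit GPU memory (argument names follow transformers 4.x; the values are illustrative):

training_args = TrainingArguments(
    output_dir="result",
    num_train_epochs=3,                 # same as the default
    per_device_train_batch_size=8,      # deberta-v3-large is memory hungry; lower if needed
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",        # evaluate on the validation split every epoch
    save_strategy="epoch",
    logging_steps=50,
)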

Result:

A checkpoint is generated under the result directory:

[ipa@comm-agi checkpoint-500]$ pwd
/data2/result/checkpoint-500
[ipa@comm-agi checkpoint-500]$ ll
total 5098944
-rw-rw-r-- 1 ipa ipa       1141 Sep  5 14:59 config.json
-rw-rw-r-- 1 ipa ipa 3480872507 Sep  5 14:59 optimizer.pt
-rw-rw-r-- 1 ipa ipa 1740408181 Sep  5 14:59 pytorch_model.bin
-rw-rw-r-- 1 ipa ipa      14575 Sep  5 14:59 rng_state.pth
-rw-rw-r-- 1 ipa ipa        627 Sep  5 14:59 scheduler.pt
-rw-rw-r-- 1 ipa ipa        533 Sep  5 14:59 trainer_state.json
-rw-rw-r-- 1 ipa ipa       4027 Sep  5 14:59 training_args.bin
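If a run is interrupted, the Trainer can continue from this checkpoint instead of starting over; a minimal sketch, reusing the trainer object from the training script above:

# Resume training from the saved checkpoint
trainer.train(resume_from_checkpoint="result/checkpoint-500")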

Evaluation:

Code:

from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer comes from the base model; the weights come from the fine-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained("deberta-v3-large")
model = AutoModelForSequenceClassification.from_pretrained("result/checkpoint-500", num_labels=6)

emotions = load_dataset("emotion")

def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

tokenized_emotions = emotions.map(tokenize, batched=True, batch_size=None)
tokenized_emotions.set_format("torch", columns=["input_ids", "attention_mask", "label"])

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

training_args = TrainingArguments(output_dir="result")
trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=tokenized_emotions["train"],
                  eval_dataset=tokenized_emotions["validation"])

# Evaluate the fine-tuned checkpoint on the validation split
results = trainer.evaluate()
print(results)

Result:

{
  'eval_loss': 0.24527376890182495,
  'eval_accuracy': 0.923,
  'eval_f1': 0.9251701895610343,
  'eval_runtime': 26.781,
  'eval_samples_per_second': 74.68,
  'eval_steps_per_second': 1.195
}
  • eval_loss: the loss computed on the validation set.
  • eval_accuracy: the accuracy on the validation set.
  • eval_f1: the F1 score on the validation set.
  • eval_runtime: total time the evaluation run took.
  • eval_samples_per_second: number of validation samples processed per second.
  • eval_steps_per_second: number of evaluation steps processed per second.

F1 score

The F1 score ranges from 0 to 1, and higher values indicate better model performance. As a rough guide, the score can be read in the following bands (see the toy example after this list):

  • F1 < 0.1: very poor performance; the predictions are essentially not usable.
  • 0.1 < F1 < 0.3: poor performance; the predictions have some accuracy but leave a lot of room for improvement.
  • 0.3 < F1 < 0.5: mediocre performance; the predictions are somewhat reliable but can still improve.
  • 0.5 < F1 < 0.7: fairly good performance; the predictions are reasonably accurate, with room to improve.
  • 0.7 < F1 < 0.9: good performance; the predictions are highly accurate, though further gains are possible.
  • F1 > 0.9: excellent performance; the predictions are close to perfect.
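For intuition on the weighted F1 used in compute_metrics, here is a toy example with made-up labels (an illustration only, unrelated to the emotion data):

from sklearn.metrics import f1_score

# Imbalanced toy example with 3 classes
labels = [0, 0, 0, 1, 1, 2]
preds  = [0, 0, 1, 1, 1, 0]
# average="weighted" averages the per-class F1 scores weighted by each class's support,
# so the frequent class 0 counts more than the rare class 2
print(f1_score(labels, preds, average="weighted"))  # ≈ 0.6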

Prediction:

Code:

from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer from the base model, weights from the fine-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained("deberta-v3-large")
model = AutoModelForSequenceClassification.from_pretrained("result/checkpoint-500", num_labels=6)

emotions = load_dataset("emotion")

def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

tokenized_emotions = emotions.map(tokenize, batched=True, batch_size=None)
tokenized_emotions.set_format("torch", columns=["input_ids", "attention_mask", "label"])

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

training_args = TrainingArguments(output_dir="result")
trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=tokenized_emotions["train"],
                  eval_dataset=tokenized_emotions["validation"])

# These example sentences are not used yet; predicting on console input
# is planned for a later post (see the summary below)
texts = ["This is a good movie.", "I do not like this product."]
max_length = 512
predictor = lambda text: tokenizer(text, padding=True, truncation=True)
inputs = list(map(predictor, texts))

# Predict on the test split and print the predicted label ids
predictions = trainer.predict(tokenized_emotions["test"])
predicted_labels = predictions.predictions.argmax(-1).tolist()
print(predicted_labels)

Result:

[0, 0, 0, 1, 0, 4, 4, 2, 1, 3, 4, 0, 4, 1, 2, 0, 1, 0, 3, 1, 2, 1, 1, 4, 0, 4, 3, 0, 4, 3, 4, 3, 0, 3, 0, 1, 1, 0, 1, 1, 3, 0, 1, 0, 1, 3, 1, 1, 4, 4, 0, 4, 1, 0, 1, 0, 0, 1, 0, 3, 0, 0, 1, 1, 0, 5, 0, 4, 4, 5, 1, 2, 5, 2, 2, 3, 1, 0, 1, 2, 1, 3, 0, 1, 0, 0, 2, 1, 1, 0, 1, 4, 3, 4, 3, 3, 2, 0, 4, 0, 0, 0, 0, 4, 3, 3, 1, 1, 5, 0, 1, 2, 4, 1, 0, 1, 1, 4, 0, 1, 0, 2, 0, 3, 0, 2, 0, 4, 0, 0, 1, 2, 0, 3, 3, 1, 4, 4, 0, 1, 1, 0, 4, 1, 1, 0, 1, 4, 4, 2, 4, 2, 4, 0, 1, 0, 1, 1, 3, 0, 3, 3, 1, 4, 4, 1, 2, 2, 2, 0, 2, 3, 1, 1, 0, 3, 1, 1, 0, 0, 4, 1, 0, 2, 4, 0, 1, 1, 5, 3, 1, 0, 1, 3, 0, 0, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 2, 4, 4, 1, 3, 0, 1, 4, 1, 0, 3, 0, 3, 3, 1, 4, 5, 1, 1, 1, 3, 1, 2, 4, 0, 0, 5, 1, 4, 1, 0, 1, 0, 1, 0, 2, 4, 1, 0, 0, 0, 3, 1, 5, 0, 4, 0, 3, 2, 0, 1, 1, 0, 3, 3, 3, 2, 3, 3, 2, 1, 1, 3, 0, 3, 3, 4, 0, 3, 0, 1, 4, 3, 0, 1, 0, 0, 1, 1, 1, 2, 0, 1, 1, 5, 4, 2, 0, 2, 5, 3, 0, 0, 3, 1, 3, 2, 2, 2, 4, 4, 2, 4, 4, 1, 0, 4, 1, 3, 3, 2, 5, 4, 4, 0, 0, 4, 1, 0, 0, 2, 3, 0, 0, 4, 0, 0, 2, 0, 2, 4, 2, 1, 0, 3, 2, 5, 1, 1, 1, 1, 4, 4, 1, 0, 0, 0, 2, 1, 2, 2, 0, 0, 1, 3, 0, 0, 3, 1, 1, 2, 0, 3, 2, 0, 1, 1, 3, 1, 0, 2, 0, 3, 0, 4, 3, 5, 4, 1, 3, 3, 4, 1, 0, 0, 2, 0, 0, 4, 0, 3, 1, 0, 4, 4, 1, 5, 0, 2, 1, 1, 0, 2, 0, 4, 3, 1, 1, 3, 5, 4, 0, 4, 4, 1, 5, 4, 1, 0, 1, 0, 1, 3, 4, 3, 0, 4, 4, 4, 0, 2, 3, 3, 0, 1, 5, 3, 0, 0, 1, 0, 0, 1, 3, 0, 3, 0, 0, 4, 0, 2, 5, 1, 3, 4, 1, 0, 0, 0, 1, 2, 0, 1, 5, 4, 0, 1, 1, 3, 1, 4, 3, 1, 4, 0, 0, 3, 1, 1, 0, 4, 3, 1, 2, 1, 0, 1, 4, 1, 1, 2, 0, 1, 4, 3, 4, 4, 1, 1, 1, 0, 3, 0, 0, 1, 1, 1, 2, 4, 1, 1, 1, 1, 2, 1, 3, 0, 3, 1, 0, 4, 0, 3, 2, 1, 2, 1, 1, 1, 2, 4, 4, 1, 0, 2, 1, 2, 5, 0, 3, 2, 1, 1, 0, 1, 0, 5, 0, 1, 1, 0, 3, 1, 1, 1, 3, 0, 1, 4, 0, 0, 0, 0, 1, 1, 0, 1, 0, 2, 2, 2, 3, 4, 1, 1, 1, 0, 1, 0, 4, 4, 0, 0, 1, 4, 0, 0, 2, 0, 2, 5, 1, 0, 4, 0, 3, 0, 3, 3, 0, 1, 1, 2, 1, 3, 0, 1, 5, 2, 2, 0, 1, 1, 1, 0, 2, 1, 4, 3, 4, 4, 2, 1, 2, 2, 1, 4, 1, 0, 1, 1, 2, 1, 2, 1, 0, 1, 2, 1, 1, 1, 1, 3, 0, 2, 1, 0, 1, 3, 1, 0, 3, 4, 1, 1, 0, 3, 0, 3, 3, 4, 0, 4, 0, 2, 5, 1, 1, 1, 1, 3, 0, 1, 0, 0, 5, 1, 0, 0, 5, 0, 4, 3, 0, 0, 0, 4, 0, 4, 1, 0, 4, 3, 1, 3, 1, 5, 0, 0, 4, 1, 1, 1, 0, 0, 3, 0, 2, 4, 2, 1, 4, 3, 1, 1, 4, 1, 1, 3, 0, 4, 0, 1, 2, 1, 0, 0, 1, 0, 2, 4, 4, 3, 3, 1, 1, 1, 0, 1, 1, 2, 1, 0, 0, 0, 1, 1, 0, 4, 0, 3, 3, 1, 2, 1, 1, 4, 1, 0, 2, 2, 2, 0, 3, 1, 1, 3, 1, 0, 1, 0, 4, 1, 1, 4, 1, 1, 3, 0, 1, 1, 1, 1, 2, 1, 0, 4, 1, 1, 5, 0, 1, 0, 1, 1, 1, 2, 0, 0, 1, 0, 0, 1, 0, 3, 0, 1, 5, 1, 2, 1, 0, 0, 1, 4, 2, 4, 3, 4, 1, 0, 0, 0, 0, 4, 1, 5, 2, 4, 0, 3, 0, 0, 0, 1, 0, 1, 1, 5, 2, 0, 0, 4, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 4, 0, 0, 1, 0, 0, 1, 0, 2, 1, 4, 0, 0, 1, 5, 0, 1, 1, 5, 4, 3, 0, 4, 0, 1, 2, 0, 2, 1, 2, 1, 3, 1, 2, 2, 1, 5, 0, 1, 3, 2, 5, 1, 0, 3, 0, 3, 0, 2, 0, 1, 0, 1, 1, 0, 4, 2, 0, 0, 1, 0, 4, 2, 3, 1, 0, 4, 1, 1, 3, 1, 0, 3, 1, 0, 2, 3, 1, 1, 1, 4, 4, 0, 0, 5, 0, 0, 1, 0, 0, 1, 1, 0, 4, 1, 1, 2, 1, 2, 2, 0, 0, 1, 1, 0, 1, 3, 3, 4, 3, 1, 1, 3, 0, 1, 3, 4, 2, 3, 1, 0, 2, 0, 3, 2, 1, 1, 1, 0, 0, 0, 2, 1, 2, 4, 5, 2, 1, 0, 2, 1, 3, 5, 3, 1, 4, 3, 0, 1, 4, 1, 1, 2, 0, 2, 0, 2, 4, 1, 0, 3, 0, 1, 4, 4, 4, 4, 1, 0, 1, 3, 0, 3, 0, 1, 1, 4, 4, 1, 4, 1, 1, 0, 1, 0, 3, 1, 5, 5, 1, 5, 1, 3, 5, 1, 0, 1, 0, 2, 0, 2, 5, 1, 2, 0, 1, 3, 1, 1, 5, 0, 2, 1, 3, 0, 0, 1, 3, 3, 0, 2, 3, 2, 0, 3, 0, 2, 0, 1, 4, 3, 2, 4, 0, 4, 0, 3, 0, 3, 5, 3, 0, 0, 4, 1, 0, 1, 4, 3, 0, 2, 4, 4, 1, 0, 1, 2, 0, 3, 0, 3, 1, 2, 0, 0, 1, 0, 1, 1, 3, 0, 1, 4, 0, 0, 1, 0, 3, 1, 1, 3, 1, 0, 3, 1, 1, 1, 1, 3, 3, 0, 1, 0, 2, 0, 1, 1, 3, 1, 1, 1, 4, 1, 3, 2, 0, 0, 4, 1, 1, 1, 4, 4, 1, 0, 1, 3, 2, 1, 4, 5, 1, 1, 
4, 4, 0, 5, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 4, 1, 3, 0, 5, 0, 0, 5, 2, 2, 0, 1, 1, 0, 0, 0, 0, 1, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 3, 1, 5, 3, 0, 0, 0, 1, 5, 0, 0, 0, 1, 2, 0, 1, 1, 0, 0, 1, 3, 4, 0, 1, 4, 3, 4, 0, 1, 4, 1, 0, 1, 1, 4, 1, 1, 1, 0, 1, 1, 5, 1, 5, 1, 1, 4, 4, 1, 3, 4, 1, 0, 0, 1, 4, 0, 1, 0, 1, 0, 0, 4, 0, 4, 0, 0, 3, 3, 2, 1, 1, 0, 4, 5, 1, 5, 1, 1, 1, 4, 1, 5, 1, 1, 1, 0, 3, 3, 4, 4, 1, 1, 4, 0, 3, 1, 0, 2, 1, 1, 2, 0, 0, 1, 0, 0, 0, 2, 0, 1, 1, 0, 1, 2, 2, 1, 4, 1, 1, 3, 1, 3, 4, 1, 1, 2, 1, 2, 1, 1, 1, 4, 0, 0, 4, 0, 1, 2, 3, 0, 1, 0, 0, 3, 2, 0, 1, 1, 2, 4, 4, 4, 1, 2, 2, 1, 1, 1, 1, 1, 0, 5, 2, 0, 3, 4, 3, 1, 3, 2, 4, 0, 3, 1, 1, 1, 4, 1, 1, 4, 3, 0, 0, 1, 2, 1, 3, 4, 4, 0, 0, 3, 4, 3, 1, 1, 1, 0, 0, 1, 0, 1, 0, 2, 2, 0, 0, 1, 4, 4, 0, 1, 5, 0, 2, 1, 1, 1, 5, 3, 0, 1, 0, 0, 1, 2, 1, 1, 2, 0, 2, 4, 5, 1, 4, 1, 4, 3, 0, 1, 1, 0, 0, 2, 4, 2, 1, 1, 1, 0, 0, 3, 5, 1, 1, 2, 0, 1, 3, 4, 1, 0, 1, 4, 2, 1, 3, 1, 1, 1, 0, 0, 0, 1, 3, 4, 4, 5, 0, 1, 1, 1, 1, 1, 4, 1, 0, 1, 0, 0, 3, 2, 1, 1, 0, 4, 1, 3, 4, 1, 0, 4, 3, 1, 1, 1, 1, 0, 2, 3, 3, 1, 2, 0, 0, 0, 1, 1, 0, 5, 1, 0, 0, 3, 0, 2, 1, 4, 4, 1, 4, 0, 3, 1, 0, 1, 0, 3, 2, 4, 1, 0, 4, 4, 4, 0, 1, 0, 0, 0, 2, 3, 2, 3, 0, 2, 2, 1, 3, 2, 4, 1, 0, 0, 1, 3, 5, 0, 1, 3, 0, 1, 0, 4, 2, 0, 4, 4, 1, 3, 3, 1, 0, 4, 1, 4, 1, 1, 0, 0, 1, 4, 3, 3, 1, 3, 3, 1, 0, 4, 4, 5, 1, 1, 5, 1, 1, 1, 3, 2, 4, 0, 0, 1, 1, 1, 3, 2, 0, 1, 4, 3, 4, 1, 4, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 1, 1, 1, 4, 1, 3, 0, 5, 3, 3, 2, 3, 0, 1, 2, 2, 1, 1, 1, 1, 0, 3, 1, 1, 1, 3, 0, 0, 0, 3, 1, 2, 0, 1, 0, 1, 0, 1, 2, 2, 3, 1, 3, 0, 2, 0, 4, 1, 4, 1, 0, 5, 1, 3, 1, 3, 1, 1, 0, 2, 3, 2, 2, 1, 5, 4, 4, 3, 3, 2, 4, 3, 0, 4, 4, 2, 1, 5, 4, 0, 1, 0, 0, 1, 1, 0, 2, 1, 1, 0, 4, 1, 4, 1, 2, 2, 1, 1, 0, 1, 1, 0, 2, 1, 5, 5, 1, 1, 0, 0, 4, 0, 1, 3, 4, 4, 1, 0, 4, 0, 0, 5, 0, 1, 3, 0, 3, 5, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 4, 3, 4, 4, 0, 0, 0, 1, 3, 1, 2, 3, 0, 0, 1, 3, 1, 1, 5, 0, 0, 4, 1, 1, 3, 2, 1, 3, 2, 4, 0, 1, 1, 0, 5, 1, 2, 1, 0, 1, 1, 2, 1, 2, 1, 4, 1, 0, 4, 1, 3, 1, 0, 2, 3, 1, 0, 0, 1, 0, 3, 1, 1, 4, 1, 1, 4, 1, 4, 4, 0, 0, 1, 0, 2, 1, 0, 1, 0, 2, 1, 2, 3, 1, 3, 0, 2, 0, 0, 0, 1, 2, 0, 0, 2, 0, 1, 0, 0, 3, 0, 3, 4, 5, 4, 3, 2, 1, 1, 3, 0, 0, 1, 1, 0, 0, 1, 5, 0, 0, 0, 3, 0, 1, 0, 1, 1, 4, 2, 0, 3, 1, 0, 1, 4, 3, 1, 2, 3, 0, 0, 1, 0, 1, 4, 1, 0, 0, 5, 5, 1, 2, 0, 2, 1, 1, 0, 0, 1, 1, 0, 2, 2, 1, 1, 1, 3, 4, 1, 1, 3, 2, 0, 3, 3, 0, 3, 1, 4, 0, 0, 0, 2, 2, 4, 3, 0, 3, 3, 1, 1, 5]

These are the predicted label ids for the test split; combined with the validation accuracy of 0.923 above, the predictions look quite accurate.
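The printed values are label ids from 0 to 5. The corresponding emotion names can be read back from the dataset itself; a small sketch reusing the emotions and predicted_labels variables from the prediction script above:

# Map predicted ids back to emotion names without hard-coding the order
label_names = emotions["test"].features["label"].names
# e.g. ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
predicted_names = [label_names[i] for i in predicted_labels]
print(predicted_names[:10])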

Summary

1. The Trainer API from transformers makes it quick to fine-tune a large model.

2. load_dataset from the datasets library makes loading the dataset very convenient. One open point: the dataset is still downloaded from the Hub here, even though load_dataset can also load from local files; the next post will cover load_dataset in detail (a minimal sketch is given after this list).

3. After fine-tuning, the model is loaded directly from the checkpoint.

4. Prediction currently runs on the test split; extending it to read input from the console needs extra code, which a later post will provide (a sketch is also given after this list).

5. One warning remains to be resolved and is left as an open issue:

python3.9/site-packages/transformers/convert_slow_tokenizer.py:470: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.
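As mentioned in item 2, load_dataset does not have to hit the Hub. A minimal sketch of loading from local files, assuming the splits were exported as JSON Lines files (the file names are examples; the next post will cover this properly):

from datasets import load_dataset

# Point load_dataset at local data files instead of the Hub
emotions = load_dataset(
    "json",
    data_files={
        "train": "emotion/train.jsonl",
        "validation": "emotion/validation.jsonl",
        "test": "emotion/test.jsonl",
    },
)

# Alternatively, a dataset saved earlier with save_to_disk can be reloaded offline:
# from datasets import load_from_disk
# emotions = load_from_disk("emotion_local")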
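Likewise for item 4, a hedged sketch of classifying a sentence typed at the console with the fine-tuned checkpoint; it reuses the tokenizer and model loaded from result/checkpoint-500 in the prediction script, and the hard-coded label names assume the emotion dataset's usual label order:

import torch

# Classify one sentence typed at the console
label_names = ["sadness", "joy", "love", "anger", "fear", "surprise"]
model.eval()
text = input("Enter a sentence: ")
inputs = tokenizer(text, return_tensors="pt", truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs).logits
print(label_names[logits.argmax(-1).item()])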
