当前位置:   article > 正文

transformer环境配置与文本分类实战应用快速上手_importerror: using the `trainer` with `pytorch` re

importerror: using the `trainer` with `pytorch` requires `accelerate>=0.21.0

 环境配置

使用colab运行代码的对应环境

  1. pip install transformers==4.21.0 datasets evaluate
  2. pip install transformers

环境报错问题

  调用transformers中的TrainingArguments报错:ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`

  1. from transformers import TrainingArguments
  2. args = TrainingArguments(learning_rate=2e-5)

报错信息如下:

 解决方法:降低transformers的版本

pip install transformers==4.24.0

或尝试安装提示指向的包:

  1. pip install transformers[torch]
  2. pip install accelerate -U

安装好后记得重启一下内核

transformer文本分类实战

简介

文本分类任务的输入为一段文本,输出为该段文本的标签。

根据文本内容与目标的不同,也可以划分为多种任务,例如:新闻分类、情感识别、意图识别等等,这些任务在模型的使用上都是类似的。

数据集准备

训练的第一步自然是完成数据集的准备,没有合适的数据集的话这里提供一个中文文本分类数据集参考,本次实践采用的数据集为今日头条新闻数据集

数据集介绍

数据来源:

今日头条客户端

数据规模:

共382688条,分布于15个分类中。

数据格式:

6554645369685278979_!_101_!_news_culture_!_隶书的间架是扁形吗?_!_
6554465412992467459_!_102_!_news_entertainment_!_封神英雄榜:张馨予_!_张馨予,张馨予张馨予
6554582914875523341_!_102_!_news_entertainment_!_能不能用一句话证明你很穷?_!_
6554604499497910787_!_102_!_news_entertainment_!_为什么六小龄童是中国唯一拥有2张身份证的人?原因竟然是这样!_!_六小龄童,小六龄童,六龄童,孙悟空,章金莱
6554635528996651272_!_103_!_news_sports_!_组三巨威少+詹姆斯+子母哥冲击力如何?你希望这三人在雷霆或雄鹿还是骑士?_!_
6554486857797730574_!_103_!_news_sports_!_kpl联赛hero继惨败黑凤梨后再次遭遇estar零封,是故意放水还是实力不行?_!_

每行为一条数据,以_!_分割的个字段,从前往后分别是 新闻ID,分类code(见下文),分类名称(见下文),新闻字符串(仅含标题),新闻关键词

分类code与名称:

100 民生 故事 news_story
101 文化 文化 news_culture
102 娱乐 娱乐 news_entertainment
103 体育 体育 news_sports
104 财经 财经 news_finance
106 房产 房产 news_house
107 汽车 汽车 news_car
108 教育 教育 news_edu 
109 科技 科技 news_tech
110 军事 军事 news_military
112 旅游 旅游 news_travel
113 国际 国际 news_world
114 证券 股票 stock
115 农业 三农 news_agriculture
116 电竞 游戏 news_game

数据处理

数据集加载

将数据集下载到本地后,使用datasets进行加载,data_files文件指定路径

  1. from datasets import load_dataset
  2. dataset = load_dataset("text", data_files=r"./toutiao_cat_data.txt")

 查看数据可以看到一共有382688条数据,字段为text

  1. >>>dataset
  2. DatasetDict({
  3. train: Dataset({
  4. features: ['text'],
  5. num_rows: 382688
  6. })
  7. })

选择一条查看数据样式,可以看到与先前txt文件中数据一样

  1. >>>dataset["train"][0]
  2. {'text': '6551700932705387022_!_101_!_news_culture_!_京城最值得你来场文化之旅的博物馆_!_保利集团,马未都,中国科学技术馆,博物馆,新中国'}

划分数据集

由于加载数据后只有一个train字段的数据集,所以我们还需要对数据集进行划分,这里需要使用train_test_split,并按照8:2的比例进行数据集的划分。

datasets = dataset["train"].train_test_split(0.2)
  1. >>>datasets
  2. DatasetDict({
  3. train: Dataset({
  4. features: ['text'],
  5. num_rows: 306150
  6. })
  7. test: Dataset({
  8. features: ['text'],
  9. num_rows: 76538
  10. })
  11. })

数据处理模型加载

加载分词器与预训练模型,这里直接调用官网的bert中文模型,这里除了要指定模型名称外,还要指定num_labels参数,值为label值的个数,如该数据集的新闻类别为15。

  1. from transformers import BertTokenizer, BertForSequenceClassification
  2. model_name = 'bert-base-chinese'
  3. tokenizer = BertTokenizer.from_pretrained(model_name)
  4. model = BertForSequenceClassification.from_pretrained(model_name, num_labels=15)
  1. >>>tokenizer
  2. PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

文本数据预处理

说明:

由于每一行数据是包括文本信息、标签信息等在内的文本,并不是直接可输入模型的数据,因此需要根据不同数据集的特征对文本进行一些简单的预处理,比如对数据

'6552260930879619591_!_102_!_news_entertainment_!_当心!男性也会患上乳腺癌!_!_肝硬化,乳腺癌,男性乳腺癌,普通外科,中华生物医学工程杂志,雄激素,男人四十要出嫁'

我们希望标签为'news_entertainment',文本信息为'当心!男性也会患上乳腺癌!肝硬化乳腺癌男性乳腺癌普通外科中华生物医学工程杂志雄激素男人四十要出嫁',使用split函数对文本进行简单的分割,同时去除文本中无关的标点符号。

  1. translator = str.maketrans('', '', string.punctuation)
  2. word = datasets["train"]['text'][0].split('_!_', 3)[3].translate(translator)
  3. label = datasets["train"]['text'][0].split('_!_')[2]
  1. >>>word
  2. '当心!男性也会患上乳腺癌!肝硬化乳腺癌男性乳腺癌普通外科中华生物医学工程杂志雄激素男人四十要出嫁'
  3. >>>label
  4. 'news_entertainment'

对其他数据集则自行定义合适的方法(如评论情感文本分类中,评论本身为文本数据,代表情感的0或1作为标签)

另外,由于标签内容为单词,需要对标签进行编码:

  1. from sklearn.preprocessing import LabelEncoder
  2. # 创建LabelEncoder对象
  3. label_encoder = LabelEncoder()
  4. # 创建标签
  5. labels = ['news_story', 'news_culture', 'news_entertainment', 'news_sports', 'news_finance',
  6. 'news_house', 'news_car', 'news_edu', 'news_tech', 'news_military', 'news_travel',
  7. 'news_world', 'stock', 'news_agriculture', 'news_game']
  8. # 先进行fit操作,学习标签的编码规则
  9. label_encoder.fit(labels)
  10. # 使用transform操作将标签编码为整数
  11. encoded_labels = label_encoder.transform(labels)
  12. encoded_labels

定义数据处理函数

有了前面这些准备就可以开始定义数据处理函数了

  1. def process_function(examples):
  2. lists = []
  3. data_labels = []
  4. translator = str.maketrans('', '', string.punctuation)
  5. for text in examples["text"]:
  6. words = text.split('_!_', 3)[3].translate(translator)
  7. label = text.split('_!_')[2]
  8. lists.append(words)
  9. data_labels.append(label)
  10. tokenized_examples = tokenizer(lists)
  11. # 将标签转换为数字编码
  12. tokenized_examples["labels"] = label_encoder.transform(data_labels)
  13. return tokenized_examples

使用map方法对数据集进行处理,后续操作不需要用到最初的数据,因此通过remove_columns将多余的字段删除

  1. tokenized_datasets = datasets.map(process_function, batched=True, load_from_cache_file=False,remove_columns=['text'])
  2. tokenized_datasets

训练数据集

构建评估函数

文本分类任务的评估指标有很多,这里采用最简单的准确率作为评价指标,直接通过evaluate加载并对其进行封装,输入参数包括着模型预测的结果与真实标签,模型预测的结果需要进一步argmax获取标签值,最终返回结果为字典,该方法会在模型训练中用到。

  1. accuracy_metric = evaluate.load("accuracy")
  2. def compute_metrics(eval_pred):
  3. predictions, labels = eval_pred
  4. predictions = predictions.argmax(axis=-1)
  5. return accuracy_metric.compute(predictions=predictions, references=labels)

配置训练器

可以使用transformers库中提供的封装好的训练方法,也可以自己重写或自定义训练函数。首先是设置训练参数,learning_rate为学习率,训练时batch大小为32,验证时为batch大小128,num_train_epochs为训练轮数,权重衰减大小为0.01,output_dir定义输出文件夹位置,logging_steps为日志记录的步长,即10个batch记录一次;评估策略为训练完一个epoch之后进行评估,设置训练完成后加载最优模型,并指定最优模型的评估指标为accuracy,这个值要和compute_metrics函数中返回值的键匹配,最后指定了半精度训练。

Trainer的参数分别为模型,训练参数,训练数据集,验证数据集,分词器,评估函数,data_collator的值为DataCollatorWithPadding的实例对象

  1. from transformers import Trainer
  2. import evaluate
  3. args = TrainingArguments(
  4. learning_rate=2e-5,
  5. per_device_train_batch_size=32,
  6. per_device_eval_batch_size=128,
  7. num_train_epochs=20,
  8. weight_decay=0.01,
  9. output_dir=save_folder,
  10. logging_steps=10,
  11. evaluation_strategy = "epoch",
  12. save_strategy = "epoch",
  13. load_best_model_at_end=True,
  14. metric_for_best_model="accuracy",
  15. fp16=True,
  16. )
  17. trainer = Trainer(
  18. model,
  19. args=args,
  20. train_dataset=tokenized_datasets["train"],
  21. eval_dataset=tokenized_datasets["test"],
  22. tokenizer=tokenizer,
  23. compute_metrics=compute_metrics,
  24. data_collator=DataCollatorWithPadding(tokenizer=tokenizer)
  25. )
trainer.train()

训练完成后的文件夹内容大致如下:checkpoint文件夹,保存着不同轮次的模型,runs文件夹中则记录着运行日志

最后可以进行评估

trainer.evaluate()

模型预测

关于如何使用自己训练好的文本分类模型进行预测,首先加载需要的包以及构建标签编码encoder对象(见上文文本预处理部分)

主要是调用训练好的模型,model_path为目录,会自动检索需要的模型以及权重,不需要直接写模型路径

  1. # path为模型所在目录(包含模型及权重等的文件夹),不是模型本身
  2. model_path = "./model_for_seqclassification/checkpoint-95680"
  3. tokenizer = BertTokenizer.from_pretrained(model_path)
  4. model = BertForSequenceClassification.from_pretrained(model_path, num_labels=15)

定义一个简单的结果可视化函数

  1. def predict(input_text):
  2. # 对文本进行标记和编码
  3. input_ids = tokenizer(input_text, truncation=True, padding=True, return_tensors="pt")
  4. '''
  5. 使用模型进行预测:
  6. 输入为一个PyTorch张量(通过上面的分词器处理)
  7. '''
  8. outputs = model(**input_ids)
  9. # 获取可能性最大值的索引(outputs.logits为各分类对应的得分数组)
  10. predicted_labels = torch.argmax(outputs.logits, dim=1).tolist()
  11. # 对预测结果文本类型进行解码
  12. for i, text in enumerate(input_text):
  13. predicted_label = label_encoder.inverse_transform([predicted_labels[i]])
  14. print("文本:", text)
  15. print("类别:", predicted_label)

预测效果测试:

  1. text1 = ["第二届山东百年品牌论坛暨全省单项冠军建设推进会在济南举行","20万如何买一辆又野又顾家的SUV,在哈弗H9、途达和mu-X中该怎么选择?"]
  2. predict(text1)

结语

        本文介绍了如何使用transformers进行文本分类任务的实践,包括对数据集的处理与训练以及预测等相关方法的使用。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/654811
推荐阅读
相关标签
  

闽ICP备14008679号