
NLP: Binary Text Classification with DistilBERT (PyTorch)

The project comes from a Kaggle competition: Natural Language Processing with Disaster Tweets | Kaggle

This post is mainly a summary and record of my own process of learning NLP, so that I can review it later; of course, if it also helps whoever reads it, I will be glad.

The project builds a machine learning model that predicts which tweets are about real disasters and which are not.

First, read the data and have a look at it; the data can be downloaded directly from Kaggle. The useful columns are text and target, and these two columns are all we need for training: target = 1 means the text is about a real disaster, otherwise it is not.

  import pandas as pd

  train_df = pd.read_csv('/kaggle/input/nlpdataset/train.csv')
  test_df = pd.read_csv('/kaggle/input/nlpdataset/test.csv')
  train_df

Next, define a dictionary that maps contractions to their expanded forms, so that contractions in the text can be replaced with their full spellings. In NLP preprocessing, expanding English contractions makes a contraction and its full form equivalent representations, which is easier for the algorithm to handle. For example, "I'm" is expanded to "I am", "don't" to "do not", and "it's" to "it is" or "it has".

The main purpose of expanding abbreviations is to remove ambiguity: the same abbreviation can have several possible readings, e.g. "AI" can stand for "Artificial Intelligence" or "Air India". Expanding abbreviations to their full forms removes this ambiguity and makes sure the model interprets the data correctly.

If you ever need contraction expansion again, you can copy this dictionary directly and then run the replacement.

  # contractions -> expanded forms
  contractions = {
      "ain't": "am not",
      "aren't": "are not",
      "can't": "cannot",
      "can't've": "cannot have",
      "'cause": "because",
      "could've": "could have",
      "couldn't": "could not",
      "couldn't've": "could not have",
      "didn't": "did not",
      "doesn't": "does not",
      "don't": "do not",
      "hadn't": "had not",
      "hadn't've": "had not have",
      "hasn't": "has not",
      "haven't": "have not",
      "he'd": "he would",
      "he'd've": "he would have",
      "he'll": "he will",
      "he'll've": "he will have",
      "he's": "he is",
      "how'd": "how did",
      "how'd'y": "how do you",
      "how'll": "how will",
      "how's": "how is",
      "i'd": "I would",
      "i'd've": "I would have",
      "i'll": "I will",
      "i'll've": "I will have",
      "i'm": "I am",
      "i've": "I have",
      "isn't": "is not",
      "it'd": "it would",
      "it'd've": "it would have",
      "it'll": "it will",
      "it'll've": "it will have",
      "it's": "it is",
      "let's": "let us",
      "ma'am": "madam",
      "mayn't": "may not",
      "might've": "might have",
      "mightn't": "might not",
      "mightn't've": "might not have",
      "must've": "must have",
      "mustn't": "must not",
      "mustn't've": "must not have",
      "needn't": "need not",
      "needn't've": "need not have",
      "o'clock": "of the clock",
      "oughtn't": "ought not",
      "oughtn't've": "ought not have",
      "shan't": "shall not",
      "sha'n't": "shall not",
      "shan't've": "shall not have",
      "she'd": "she would",
      "she'd've": "she would have",
      "she'll": "she will",
      "she'll've": "she will have",
      "she's": "she is",
      "should've": "should have",
      "shouldn't": "should not",
      "shouldn't've": "should not have",
      "so've": "so have",
      "so's": "so is",
      "that'd": "that would",
      "that'd've": "that would have",
      "that's": "that is",
      "there'd": "there would",
      "there's": "there is",
      "they'd": "they would",
      "they'd've": "they would have",
      "they'll": "they will",
      "they'll've": "they will have",
      "they're": "they are",
      "they've": "they have",
      "to've": "to have",
      "wasn't": "was not",
      "we'd": "we would",
      "we'd've": "we would have",
      "we'll": "we will",
      "we'll've": "we will have",
      "we're": "we are",
      "we've": "we have",
      "weren't": "were not",
      "what'll": "what will",
      "what'll've": "what will have",
      "what're": "what are",
      "what's": "what is",
      "what've": "what have",
      "when's": "when is",
      "when've": "when have",
      "where'd": "where did",
      "where's": "where is",
      "where've": "where have",
      "who'll": "who will",
      "who'll've": "who will have",
      "who's": "who is",
      "who've": "who have",
      "why's": "why is",
      "why've": "why have",
      "will've": "will have",
      "won't": "will not",
      "won't've": "will not have",
      "would've": "would have",
      "wouldn't": "would not",
      "wouldn't've": "would not have",
      "y'all": "you all",
      "y'all'd": "you all would",
      "y'all'd've": "you all would have",
      "y'all're": "you all are",
      "y'all've": "you all have",
      "you'd": "you would",
      "you'd've": "you would have",
      "you'll": "you will",
      "you'll've": "you will have",
      "you're": "you are",
      "you've": "you have"
  }

  country_contractions = {
      'u.s': 'united states',
      'u.s.': 'united states',
      'u.s.a': 'united states',
      'u.k': 'united kingdom',
      'u.k.': 'united kingdom',
      'u.a.e': 'united arab emirates',
      'u.a.e.': 'united arab emirates',
      's.korea': 'south korea',
      'n.korea': 'north korea',
      'czech rep.': 'czech republic',
      'dominican rep.': 'dominican republic',
      'costa rica': 'republic of costa rica',
      'el salvador': 'republic of el salvador',
      'guinea-bissau': 'republic of guinea-bissau',
      'cote d\'ivoire': 'republic of cote d\'ivoire',
      'trinidad & tobago': 'republic of trinidad and tobago',
      'congo-brazzaville': 'republic of the congo',
      'congo-kinshasa': 'democratic republic of the congo',
      'sri lanka': 'democratic socialist republic of sri lanka',
      'central african rep.': 'central african republic',
      'san marino': 'republic of san marino',
      'são tomé & príncipe': 'democratic republic of são tomé and príncipe',
      'timor-leste': 'democratic republic of timor-leste'
  }

Define the contraction-expansion function; the points to watch out for are noted in the comments.

  # expand contractions word by word
  def expand_contractions(text, contractions):
      words = text.split()
      expanded_words = []
      for word in words:
          if word.lower() in contractions:  # check whether word appears as a key of the contractions dict
              expanded_words.append(contractions[word.lower()])
          else:
              expanded_words.append(word)
      return ' '.join(expanded_words)  # join concatenates the list back into one space-separated string
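A quick sanity check on a made-up sentence:

  print(expand_contractions("I'm sure they'll come, don't worry", contractions))
  # -> 'I am sure they will come, do not worry'

Note that a contraction with punctuation attached (e.g. "don't!") is not matched, because the lookup works on whole whitespace-separated words.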

Now clean the text: expand the contractions and replace URLs, digits and punctuation with spaces. URLs, punctuation and digits are noise in natural language text, and replacing them with spaces removes that noise so the model can learn the useful features more easily.

  import re

  # apply expand_contractions, then use re.sub to replace URLs, punctuation, digits and extra whitespace
  def clean_text(text):
      # expand ordinary contractions and country abbreviations
      text = expand_contractions(text, contractions)
      text = expand_contractions(text, country_contractions)
      # \S matches any non-whitespace character and + means one or more,
      # so \S+ matches a run of non-whitespace characters; | means "or".
      # Together the pattern matches strings starting with http, https or www.
      # flags=re.MULTILINE is redundant here and could be dropped.
      text = re.sub(r'http\S+|https\S+|www\S+', '', text, flags=re.MULTILINE)
      text = re.sub(r'\W', ' ', text)  # non-alphanumeric characters: punctuation and the like
      text = re.sub(r'\d', ' ', text)  # digits
      text = re.sub(r'\s+', ' ', text).strip()  # collapse runs of whitespace to one space; strip removes leading/trailing spaces
      return text

  train_df['clean_text'] = train_df['text'].apply(clean_text)
  test_df['clean_text'] = test_df['text'].apply(clean_text)
  train_df
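To see what the cleaning does, here is a made-up tweet (not from the dataset) before and after:

  example = "We're evacuating now!! See http://example.com for 2 updates"
  print(clean_text(example))
  # -> 'we are evacuating now See for updates'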

Split off a validation set from the training data; note that the original data already comes with a separate test set.

  import math

  split_count = math.floor(len(train_df) / 10 * 9)
  val_df = train_df[split_count:]
  train_df = train_df[:split_count]
Next, define the tokenizer and model of DistilBERT through the Transformers library. DistilBERT is a lightweight version of BERT: by shrinking BERT's size and computation and applying knowledge distillation, it keeps most of the performance while being smaller and faster at inference.

  Knowledge distillation: a model-compression technique that transfers the knowledge of a larger, more complex model into a smaller, simpler one. Concretely, the knowledge of a "teacher" model (usually the larger, more complex one) is "distilled" into a "student" model (usually the smaller, simpler one), so that the student learns the teacher's key features and knowledge and gains better generalization and performance.

Transformers is a natural language processing (NLP) library developed by Hugging Face. It provides a range of deep-learning pretrained models, including BERT, GPT, RoBERTa, DistilBERT and others, and makes it easy to fine-tune them for different tasks. The library is used for all kinds of NLP tasks, such as sentiment analysis, machine translation, text classification and question answering.

  1. Pretrained model: in NLP, a pretrained model is usually trained on a large amount of text in an unsupervised way, so that it learns the features and structure of language, such as syntax, semantics and context.
  2. Fine-tuning: further training an already pretrained model to adapt it to a specific task. Fine-tuning is usually supervised, uses a task-specific dataset, and adjusts the weights and parameters of the pretrained model so that it fits the particular NLP task.

If you want to dig deeper into the Transformers library, here is the PyTorch hub page about it: PyTorch-Transformers | PyTorch

  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
  import torch

  MODEL_NAME = 'distilbert-base-uncased'
  tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
  model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

Next, convert the data into the torch format the model expects. The tokenizer splits the input text into tokens and does the necessary processing; the padding and truncation arguments control the padding and truncation behaviour. padding pads all inputs to the same length so they can be batched: if a sample is shorter than the target length it is filled with pad_token_id, and if it is longer it is truncated according to the truncation setting.
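To make the padding behaviour concrete, a minimal sketch on two made-up sentences:

  enc = tokenizer(["the storm knocked out power in the city", "all good here"],
                  padding=True, truncation=True)
  print(enc['input_ids'])       # token ids; the shorter sentence is padded to the longer one's length
  print(enc['attention_mask'])  # 1 for real tokens, 0 for padding positions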

  from transformers import TrainingArguments, Trainer
  from datasets import Dataset

  def tokenize(batch):
      return tokenizer(batch['text'], padding=True, truncation=True)

  # the test set has no labels, so fill target with a placeholder 0 to keep the column
  test_df["target"] = 0

  # take out the text and the label
  train_dataset = train_df[['clean_text', 'target']]
  val_dataset = val_df[['clean_text', 'target']]
  test_dataset = test_df[['clean_text', 'target']]

  # rename the columns
  train_dataset = train_dataset.rename(columns={'clean_text': 'text', 'target': 'label'})
  val_dataset = val_dataset.rename(columns={'clean_text': 'text', 'target': 'label'})
  test_dataset = test_dataset.rename(columns={'clean_text': 'text', 'target': 'label'})

  # 'records' means every row becomes a dict, collected in a list
  train_dataset = train_dataset.to_dict('records')
  val_dataset = val_dataset.to_dict('records')
  test_dataset = test_dataset.to_dict('records')

  # (the detour through dicts is unnecessary; the DataFrames could go to Dataset.from_pandas directly)
  train_dataset = Dataset.from_pandas(pd.DataFrame(train_dataset))
  val_dataset = Dataset.from_pandas(pd.DataFrame(val_dataset))
  test_dataset = Dataset.from_pandas(pd.DataFrame(test_dataset))

  # setting batch_size to the dataset size means the whole dataset is tokenized in one call
  train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
  val_dataset = val_dataset.map(tokenize, batched=True, batch_size=len(val_dataset))
  test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))

  # convert the datasets to torch format
  train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
  val_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
  test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
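As the comment above says, the round trip through a list of dicts can be skipped; a simpler equivalent sketch for the training split:

  train_dataset = Dataset.from_pandas(
      train_df[['clean_text', 'target']]
          .rename(columns={'clean_text': 'text', 'target': 'label'}),
      preserve_index=False)  # drop the pandas index so no extra column is added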

Set the training arguments and the trainer, train, and save the model.

  import os
  os.environ["WANDB_DISABLED"] = "true"

  # training arguments
  training_args = TrainingArguments(
      output_dir='./results',
      num_train_epochs=10,
      per_device_train_batch_size=16,
      per_device_eval_batch_size=16,
      warmup_steps=500,
      weight_decay=0.01,
      logging_dir='./logs',
      logging_steps=10,
      load_best_model_at_end=True,
      evaluation_strategy="epoch",
      save_strategy="epoch"
  )

  # initialize the trainer
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=train_dataset,
      eval_dataset=val_dataset,
      tokenizer=tokenizer,
  )

  # train the model
  trainer.train()

  # save the model weights
  model.save_pretrained("./model")

Load the saved model and run prediction on the test set.

  # load the model
  loaded_model = DistilBertForSequenceClassification.from_pretrained("./model")
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  loaded_model.to(device)  # move the loaded model to the same device as the inputs

  def predict(text, loaded_model, tokenizer):
      inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
      inputs = inputs.to(device)
      outputs = loaded_model(**inputs)
      logits = outputs.logits
      probabilities = torch.softmax(logits, dim=-1)
      predictions = torch.argmax(probabilities, dim=-1)
      return predictions.item()

  test_df["predicted_target"] = test_df["clean_text"].apply(lambda x: predict(x, loaded_model, tokenizer))
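To actually submit to the competition, the predictions still have to be written out; a minimal sketch, assuming the test CSV keeps its id column as in the Kaggle data:

  submission = test_df[['id']].copy()
  submission['target'] = test_df['predicted_target']
  submission.to_csv('submission.csv', index=False)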
