当前位置:   article > 正文

Bert模型实现中文新闻文本分类_中文文本分类模型

中文文本分类模型

        Bert基于Transformer架构是解决自然语言处理的深度学习模型,常使用在文本分类、情感分析、词性标注等场合。

        本文将使用Bert模型对中文文本进行分类,其中训练集数据18W条,验证集数据1W条,包含10个类别的文本数据,数据可以自己从Kaggel上下载。

        

中文新闻标题类别标签类别名
锌价难续去年辉煌0金融
金科西府 名墅天成1房地产
同步A股首秀:港股缩量回调2经济
状元心经:考前一周重点是回顾和整理3教育
一年网事扫荡10年纷扰开心网李鬼之争和平落幕4科技
60年铁树开花形状似玉米芯(组图)5社会
发改委治理涉企收费每年为企业减负超百亿6国际
布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击7体育
体验2D巅峰 倚天屠龙记十大创新概览8游戏
Rain入伍前最后开唱 本周六“雨”润京城(图)9娱乐

分类模型的结构比较简单,示意图如下:

Dataset是我们用的数据集的库,是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示。其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数。

DataLoader是PyTorch提供的一个数据加载器,它可以将数据分成小批次进行加载,并自动完成数据的批量加载、随机洗牌、并发预取等操作。在神经网络的训练过程中,我们通常需要处理大量的数据。如果一次性将所有数据加载到内存中,不仅会消耗大量的内存资源,还可能导致程序运行缓慢甚至崩溃。因此,我们需要一种机制来将数据分成小批次进行加载,而DataLoader正是为了满足这一需求而诞生的。

  1. #首先导入需要用到的数据包
  2. from transformers import BertModel, BertTokenizer
  3. import torch.nn as nn
  4. import torch
  5. from torch.utils.data import Dataset, DataLoader
  6. from torch import optim
  7. import os
  8. class BertClassifier(nn.Module):
  9. def __init__(self, bert_model, output_size):
  10. super(BertClassifier, self).__init__()
  11. self.bert = bert_model
  12. self.classifier = nn.Linear(bert_model.config.hidden_size, output_size)
  13. def forward(self, input_ids, attention_mask):
  14. # 获取BERT模型的CLS输出
  15. text_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
  16. #得到线性层的结果
  17. logits=self.classifier(text_output.pooler_output)
  18. return logits
  19. #读取数据
  20. class data_load(Dataset):
  21. def __init__(self,path):
  22. self.data=list()
  23. file=open(path,'r',encoding='utf-8')
  24. for line in file:
  25. text,label=line.strip().split('\t')
  26. self.data.append((text,int(label)))
  27. file.close()
  28. def __len__(self):
  29. return len(self.data)
  30. def __getitem__(self, index):
  31. return self.data[index]
  32. #用于dataloader,对于每个小批量的数据,进行分词和填充
  33. def collate_fn(batch,tokenizer):
  34. texts=[text[0] for text in batch]
  35. labels=[text[1] for text in batch]
  36. labels=torch.tensor(labels,dtype=torch.long)
  37. tokens=tokenizer(
  38. texts,
  39. add_special_tokens=True,
  40. max_length=512,
  41. padding=True,
  42. truncation=True,
  43. return_tensors='pt',
  44. )
  45. return tokens['input_ids'],tokens['attention_mask'],labels
  46. if __name__=="__main__":
  47. dataset=data_load('./train.txt')
  48. print(len(dataset))
  49. output :180000
  50. #加载模型,生成分词器
  51. tokenizer=BertTokenizer.from_pretrained('bert-base-chinese')
  52. bert_model = BertModel.from_pretrained('bert-base-chinese')
  53. #dataset:要加载的数据集对象,必须是实现了len()和getitem()方法的对象
  54. data_loader=DataLoader(dataset,
  55. batch_size=128,
  56. shuffle=True,
  57. collate_fn=lambda x:collate_fn(x,tokenizer))
  58. # 指定机器
  59. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  60. #打印分词器支持的最大长度,输入的中文数据不能超过512
  61. #如果进行长文本分类,需要进行文本截断或分块处理
  62. # print(tokenizer.model_max_length)
  63. #定义bertclassifier模型为10分类
  64. model=BertClassifier(bert_model,output_size=10).to(device)
  65. model.train()
  66. #优化器
  67. optimizer=optim.AdamW(model.parameters(),lr=5e-5)
  68. #交叉熵损失误差
  69. criterion=nn.CrossEntropyLoss()
  70. #存放模型
  71. os.makedirs('output_models',exist_ok=True)
  72. epoch_n=10
  73. for epoch in range(1,epoch_n+1):
  74. for batch_index,data in enumerate(data_loader):
  75. input_ids=data[0].to(device)
  76. attention_mask=data[1].to(device)
  77. label=data[2].to(device)
  78. #清空梯度
  79. optimizer.zero_grad()
  80. #前向传播
  81. output=model(input_ids,attention_mask)
  82. loss=criterion(output,label)
  83. loss.backward() #计算梯度
  84. optimizer.step() #更新模型参数
  85. #计算正确率,用于观察模型结果
  86. predict=torch.argmax(output,dim=1)
  87. correct=(predict==label).sum().item()
  88. acc=correct/output.size(0)
  89. print(f"Epoch {epoch}/{epoch_n}") #迭代轮数
  90. print(f"Batch {batch_index+1}/{len(data_loader)}")
  91. print(f"Loss: {loss.item():.4f}") #损失
  92. print((f"Acc {correct}/{output.size(0)}=={acc:.3f}")) #正确率
  93. #每一次迭代都保存一次模型结果
  94. model_name=f'./output_models/chinese_news_classify{epoch}.pth'
  95. print("saved model: %s" % (model_name))
  96. torch.save(model.state_dict(),model_name)

可以看到随着训练的进行,模型的准确率越来越高。由于数据量和机器内存原因,训练的时间比较长,就没有全部跑完。

  1. Epoch 1/10
  2. Batch 59/1407
  3. Loss: 0.4286
  4. Acc 113/128==0.883
  5. saved model: ./output_models/chinese_news_classify1.pth
  6. Epoch 1/10
  7. Batch 60/1407
  8. Loss: 0.4399
  9. Acc 114/128==0.891
  10. saved model: ./output_models/chinese_news_classify1.pth
  11. Epoch 1/10
  12. Batch 61/1407
  13. Loss: 0.5028
  14. Acc 109/128==0.852
  15. saved model: ./output_models/chinese_news_classify1.pth
  16. Epoch 1/10
  17. Batch 62/1407
  18. Loss: 0.3180
  19. Acc 120/128==0.938

使用训练好的模型预测中文文本

  1. from kaggel_chinese_text import BertClassifier
  2. from transformers import BertModel, BertTokenizer
  3. import torch
  4. test_text='铁血铸辉煌 天骄3公会战唤起新激情'
  5. bert_model = BertModel.from_pretrained('bert-base-chinese')
  6. tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
  7. model=BertClassifier(bert_model,10)
  8. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  9. model.load_state_dict(torch.load('./output_models/chinese_news_classify1.pth',map_location=device))
  10. model.to(torch.device(device))
  11. model.eval()
  12. inputs = tokenizer.encode_plus(
  13. test_text,
  14. add_special_tokens=True,
  15. max_length=128,
  16. padding='max_length',
  17. truncation=True,
  18. return_tensors='pt'
  19. )
  20. input_ids = inputs['input_ids']
  21. # print("shape of inut_ids:",input_ids.shape)
  22. attention_mask = inputs['attention_mask']
  23. with torch.no_grad():
  24. input_ids = input_ids.to(device)
  25. attention_mask = attention_mask.to(device)
  26. outputs = model(input_ids,attention_mask)
  27. _, predicted = torch.max(outputs, 1)
  28. print(predicted.item())
  29. #能正确预测文本属于游戏类型
  30. output: 8

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/887078
推荐阅读
相关标签
  

闽ICP备14008679号