赞
踩
一、前期工作
二、模型加载
三、模型训练
四、模型测试
大家好,我是微学AI,今天给大家带来一个基于BERT模型做文本分类的实战案例,在BERT模型基础上做微调,训练自己的数据集,相信之前大家很多都是套用别人的模型直接训练,或者直接用于预训练模型进行预测,没有训练和微调过大模型,因为像BERT这种大模型一般人是训练不了的,我们只能在大模型的基础上进行微调,或者做下游任务改造。
下面来介绍一下BERT模型,BERT是基于transfomer的预训练语言模型,它利用了transfomer中的编码器,进行数据编码,将文本数据转化为词向量。BERT核心内容是利用transfomer中的多头自注意力机制进行编码,关于transfomer的多头自注意力机制详细可以观看网络上的资料。
BERT模型是以两个NLP任务进行训练的,第一个任务是文本中词的预测,将已知训练文本隐掉词的信息,用MASK进行隐码,让模型去预测。第二个任务是在训练数据中随机抽取上下文关系句子或非上下文关系句子,让机器判断是否为上下文关系。BERT模型训练优势是无需进行标注数据。
我们可以利用BERT预训练模型进行下游任务改造,做自己相关任务,比如中文分词、文本分类,命名实体识别,阅读理解,情感分析,文本相似度、信息抽取等任务。
- import torch
- from datasets import load_dataset
- import torch.nn.functional as F
- from transformers import BertTokenizer
-
- #加载字典和分词工具
- token = BertTokenizer.from_pretrained('bert-base-chinese')
- #定义数据集
- class Dataset(torch.utils.data.Dataset):
- def __init__(self, split):
- self.dataset = load_dataset(path='data', split=split)
-
- def __len__(self):
- return len(self.dataset)
-
- def __getitem__(self, i):
- text = self.dataset[i]['text']
- label = self.dataset[i]['label']
-
- return text, label
-
- dataset = Dataset('train')
- print(len(dataset), dataset[0])
-
- def collate_fn(data):
- sents = [i[0] for i in data]
- labels = [i[1] for i in data]
-
- #编码
- data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
- truncation=True,
- padding='max_length',
- max_length=500,
- return_tensors='pt',
- return_length=True)
-
- #input_ids:编码之后的数字
- #attention_mask:是补零的位置是0,其他位置是1
- input_ids = data['input_ids']
- attention_mask = data['attention_mask']
- token_type_ids = data['token_type_ids']
- labels = torch.LongTensor(labels)
-
- #print(data['length'], data['length'].max())
- return input_ids, attention_mask, token_type_ids, labels
-
- #数据加载器
- loader = torch.utils.data.DataLoader(dataset=dataset,
- batch_size=10,
- collate_fn=collate_fn,
- shuffle=True,
- drop_last=True)
-
- for i, (input_ids, attention_mask, token_type_ids,
- labels) in enumerate(loader):
- break
-
- print(len(loader))
- print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)
这里代码需要在同级文件夹下创建data 文件夹, 放入train.csv、test.csv数据集。
数据集格式如下:
我们可以在BERT输出端接入一个全连接层,输出2分类问题,也可加入CNN卷积层,这些可以自行操作。
- from transformers import BertModel
-
- #加载预训练模型
- pretrained = BertModel.from_pretrained('bert-base-chinese')
-
- #不训练,不需要计算梯度
- for param in pretrained.parameters():
- param.requires_grad_(False)
-
- #模型试算
- out = pretrained(input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids)
-
- print(out.last_hidden_state.shape)
-
-
- #定义下游任务模型
- class Model(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = torch.nn.Linear(768, 2)
- # 可加入CNN卷积层,可以自行操作
- # self.conv1D = torch.nn.Conv1d(in_channels=500, out_channels=500, kernel_size=1)
- # self.MaxPool1D = torch.nn.MaxPool1d(4, stride=2)
- # self.Dropout = torch.nn.Dropout(p=0.5, inplace=False)
-
- def forward(self, input_ids, attention_mask, token_type_ids):
- with torch.no_grad():
- out = pretrained(input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids)
- out = self.fc(out.last_hidden_state[:, 0])
- out = out.softmax(dim=1)
- print(out.shape)
- return out
- model = Model()
- print(model)
- #model.summary()
- model(input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids).shape
-
- from transformers import AdamW
- #训练
- optimizer = AdamW(model.parameters(), lr=5e-4)
- criterion = torch.nn.CrossEntropyLoss()
-
- model.train()
- epochs = 30
-
- for i, (input_ids, attention_mask, token_type_ids,
- labels) in enumerate(loader):
- out = model(input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids)
-
- loss = criterion(out, labels)
- loss.backward()
- optimizer.step()
- optimizer.zero_grad()
-
- if i % 1 == 0:
- out = out.argmax(dim=1)
- accuracy = (out == labels).sum().item() / len(labels)
-
- print('epochs:',i, 'loss:',loss.item(),'accuracy:', accuracy)
-
- if i == epochs:
- torch.save(model, 'text_classfiy.model')
- #model_load = torch.load('model/命名实体识别_中文.model')
- break
- #测试函数
- def test():
- model.eval()
- correct = 0
- total = 0
-
- loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
- batch_size=10,
- collate_fn=collate_fn,
- shuffle=True,
- drop_last=True)
-
- for i, (input_ids, attention_mask, token_type_ids,
- labels) in enumerate(loader_test):
-
- if i == 5:
- break
-
- with torch.no_grad():
- out = model(input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids)
-
- out = out.argmax(dim=1)
- correct += (out == labels).sum().item()
- total += len(labels)
-
- print(correct / total)
可以调用测试函数进行测试,看看模型训练效果。
欢迎继续关注 深度学习实战案例,持续更新。获取数据可私聊。
往期作品:
深度学习实战项目
3.深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类
4.深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别
5.深度学习实战5-卷积神经网络(CNN)中文OCR识别项目
6.深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测
9.深度学习实战9-文本生成图像-本地电脑实现text2img
10.深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)
11.深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例
12.深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正
13.深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星
14.深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了
15.深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问
16.深度学习实战16(进阶版)-虚拟截图识别文字-可以做纸质合同和表格识别
17.深度学习实战17(进阶版)-智能辅助编辑平台系统的搭建与开发案例
18.深度学习实战18(进阶版)-NLP的15项任务大融合系统,可实现市面上你能想到的NLP任务
19.深度学习实战19(进阶版)-ChatGPT的本地实现部署测试,自己的平台就可以实现ChatGPT
...(待更新)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。