GitHub源代码:Befineyou/bert-bilstm-in-Sentiment-classification: The author applies BERT+BILSTM to emotion classification (github.com)
本文使用的是中文BERT(bert-base-chinese),需要提前从 Hugging Face 官网将预训练好的模型文件下载到本地,下载地址:bert-base-chinese at main (huggingface.co)
数据集来源:疫情期间网民情绪识别 竞赛 - DataFountain
from transformers import BertTokenizer

# Tokenizer for the 'bert-base-chinese' checkpoint; printed once so its
# configuration shows up in the run log.
token = BertTokenizer.from_pretrained('bert-base-chinese')
print(token)

# Read the cleaned training CSV and wrap the frame via `Dataset.from_pandas`.
# NOTE(review): neither `pd` nor a `Dataset` class exposing `from_pandas` is
# imported in the visible code -- presumably pandas and `datasets.Dataset`;
# confirm against the full script.
train_data = pd.read_csv('data/train_clean.csv')
train_dataset = Dataset.from_pandas(train_data)
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(text, label)`` pairs.

    Args:
        dataset: Indexable collection whose rows support ``row['text']`` and
            ``row['label']``. When omitted, the module-level
            ``train_dataset`` is wrapped, which is exactly what the
            zero-argument ``Dataset()`` call has always done.
    """

    def __init__(self, dataset=None):
        # The global is looked up only when no dataset is passed, and at call
        # time, so existing ``Dataset()`` call sites keep working unchanged.
        self.dataset = train_dataset if dataset is None else dataset

    def __len__(self):
        """Return the number of rows in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return the ``(text, label)`` pair stored at index ``item``."""
        # Fetch the row once instead of indexing the backing store twice.
        row = self.dataset[item]
        return row['text'], row['label']
# Rebind `train_dataset` to the (text, label) wrapper. `Dataset()` reads the
# current value of `train_dataset` while it is being constructed, so this line
# has to run after the raw training data has been loaded into that name.
train_dataset = Dataset()
input_ids 代表句子中每个字的词典编号
attention_mask 只有0或1,1代表真实的字(包括特殊符号),0代表空,也就是PAD
token_type_ids 只有0或1,0代表第一个句子和特殊符号,1代表第二个句子
def collate_fn(data):
    """Convert a list of ``(text, label)`` samples into model-ready tensors.

    Args:
        data: Sequence of two-item samples; item 0 is the sentence, item 1 the
            label.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)``. The
        first three are taken from the tokenizer output; ``labels`` is a
        ``long`` tensor.
    """
    texts = [sample[0] for sample in data]
    raw_labels = [sample[1] for sample in data]

    # Every sentence is truncated / padded to exactly 500 tokens.
    encoded = token.batch_encode_plus(
        batch_text_or_text_pairs=texts,
        truncation=True,
        padding='max_length',
        max_length=500,
        return_tensors='pt',
        return_length=True,
    )

    # NOTE(review): a label of -1 is folded into class 0, so it becomes
    # indistinguishable from samples already labelled 0 -- confirm this
    # merge is intended rather than shifting all labels by one.
    remapped = [0 if label == -1 else label for label in raw_labels]
    labels = torch.tensor(remapped).long()

    return (
        encoded['input_ids'],
        encoded['attention_mask'],
        encoded['token_type_ids'],
        labels,
    )
# Shuffled training batches of 16 samples; a trailing batch with fewer than
# 16 samples is dropped, and `collate_fn` tokenizes each batch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)
from transformers import BertModel

# BERT is used purely as a frozen feature extractor: gradient tracking is
# switched off for every one of its parameters.
pretrained = BertModel.from_pretrained('bert-base-chinese')
pretrained.requires_grad_(False)
class BertBiLSTMClassifier(nn.Module):
    """Sentence classifier: frozen BERT features -> BiLSTM -> linear layer.

    BERT is the module-level ``pretrained`` model. It is called inside
    ``forward`` but is not registered as a submodule, so only the LSTM and
    the linear layer show up in ``parameters()`` and ``state_dict()``.

    Args:
        num_classes: Size of the output logit vector.
        hidden_size: Feature size of the token vectors fed to the LSTM.
        lstm_hidden_size: Hidden size of the LSTM, per direction.
        lstm_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=768, lstm_hidden_size=128, lstm_layers=1):
        super().__init__()
        # BiLSTM over the BERT token vectors.
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Classification layer; the factor 2 covers forward + backward states.
        self.fc = nn.Linear(lstm_hidden_size * 2, num_classes)

    def _pool_sequence(self, token_states, attention_mask):
        """Return one vector per sentence from the BiLSTM, ignoring padding.

        Reading ``lstm_out[:, -1, :]`` on a padded batch is wrong twice over:
        the forward direction has already run across the padding vectors, and
        the backward direction has only seen a single position. Packing the
        batch with the true lengths and taking the final hidden state of each
        direction avoids both problems.

        Assumes padding sits at the end of each row (``attention_mask`` is a
        run of ones followed by zeros) -- confirm if the tokenizer is ever
        configured for left padding.
        """
        # pack_padded_sequence wants int64 lengths on the CPU; a fully masked
        # row is clamped to length 1 so packing never sees a zero length.
        lengths = attention_mask.sum(dim=1).clamp(min=1).to(torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            token_states, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)
        # The last two entries of h_n are the top layer's forward and
        # backward final states.
        return torch.cat((h_n[-2], h_n[-1]), dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return classification logits for a tokenized batch."""
        # BERT forward pass, without building an autograd graph.
        with torch.no_grad():
            outputs = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        sentence_vectors = self._pool_sequence(outputs.last_hidden_state, attention_mask)
        # Linear layer produces the class logits.
        return self.fc(sentence_vectors)
# Build the classifier that train() and test() operate on.
# NOTE(review): `num_classes` is not defined anywhere in the visible code; it
# has to be assigned before this line and should equal the number of distinct
# label ids produced by `collate_fn` -- confirm the intended value.
model = BertBiLSTMClassifier(num_classes)
def train(loader=None):
    """Run one pass over the training batches, updating ``model`` in place.

    Every 10 steps the step index, loss, learning rate and batch accuracy are
    printed; every 90 steps ``model.state_dict()`` is written to disk.

    Args:
        loader: Sized iterable of ``(input_ids, attention_mask,
            token_type_ids, labels)`` batches. Defaults to the module-level
            ``train_loader``; the body previously referred to a name
            ``loader`` that is not defined in the visible code.

    NOTE(review): ``AdamW`` and ``get_scheduler`` are not imported in the
    visible code and must be provided by the surrounding script.
    """
    if loader is None:
        loader = train_loader

    optimizer = AdamW(model.parameters(), lr=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    # The linear schedule spans exactly one pass over `loader`.
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if i % 10 == 0:
            predictions = out.argmax(dim=1)
            accuracy = (predictions == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)

        if i % 90 == 0:
            # `i` is the step index, not an epoch; the file name is left
            # untouched so existing checkpoints keep their expected names.
            torch.save(model.state_dict(), f'bert_cnn_model_epoch_{i}.pth')
def test():
    """Print the accuracy of ``model`` on up to five shuffled test batches.

    Batches of 32 are drawn from ``test_dataset`` with shuffling, and any
    incomplete final batch is dropped, so the printed figure is a spot check
    on at most 160 samples rather than a full evaluation.

    NOTE(review): ``test_dataset`` is not defined in the visible code and must
    be provided by the surrounding script.
    """
    loader_test = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    model.eval()
    correct = 0
    total = 0
    # One no_grad scope around the whole loop instead of one per batch.
    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
            if i == 5:
                break
            print(i)
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
            predictions = torch.argmax(out, dim=1)
            correct += (predictions == labels).sum().item()
            total += len(labels)

    if total == 0:
        # With drop_last=True, fewer than 32 test rows yields no batch at all;
        # report that instead of dividing by zero.
        print('no complete test batch was evaluated')
        return
    print(correct / total)


test()
