
Hands-On Entity Recognition with BERT+BiLSTM+CRF in PyTorch


First, we need to import the required libraries:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from torchcrf import CRF  # pip install pytorch-crf
```

Then we define some hyperparameters and the model structure:

```python
# Hyperparameters
MAX_LEN = 128
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
bert_model = BertModel.from_pretrained('bert-base-chinese')

class EntityModel(nn.Module):
    def __init__(self, bert_model, hidden_size, num_tags):
        super(EntityModel, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.1)
        # BiLSTM over the BERT token embeddings; the two directions
        # concatenate back to hidden_size.
        self.bilstm = nn.LSTM(bidirectional=True, input_size=hidden_size,
                              hidden_size=hidden_size // 2, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_tags)
        # batch_first=True so the CRF accepts (batch, seq_len, num_tags) emissions
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # (batch, seq_len, hidden_size)
        sequence_output = self.dropout(sequence_output)
        lstm_output, _ = self.bilstm(sequence_output)
        logits = self.fc(lstm_output)  # per-token tag emissions
        if labels is not None:
            # Negative CRF log-likelihood as the training loss
            loss = -self.crf(logits, labels, mask=attention_mask.bool())
            return loss
        else:
            # Viterbi decoding of the best tag sequence
            tags = self.crf.decode(logits, mask=attention_mask.bool())
            return tags
```

Here, BERT and a BiLSTM layer extract contextual features for each token, a fully connected layer maps those features into the tag space, and a CRF layer models the dependencies in the tag sequence.
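Before wiring up real data, it can help to sanity-check the forward pass. Below is a minimal smoke test; `num_tags=5` is a placeholder (the real value, `len(label2id)`, is built during data loading later), and it assumes `torchcrf` is installed:

```python
# Smoke test: run one sentence through the untrained model and decode tags.
# num_tags=5 is a placeholder tag-set size for illustration only.
demo_model = EntityModel(bert_model, hidden_size=768, num_tags=5)
enc = tokenizer('今天是个好日子', return_tensors='pt')
with torch.no_grad():
    pred = demo_model(enc['input_ids'], enc['attention_mask'])
print(pred)  # e.g. [[2, 0, 4, ...]] -- one tag id per token (random until trained)
```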

Next, we need to define some helper functions:

```python
def tokenize_and_preserve_labels(text, labels):
    """Tokenize word by word, repeating each word's label for its subwords."""
    tokenized_text = []
    token_labels = []
    for word, label in zip(text, labels):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)
        tokenized_text.extend(tokenized_word)
        token_labels.extend([label] * n_subwords)
    return tokenized_text, token_labels

def pad_sequences(sequences, max_len, padding_value=0):
    # Fill with padding_value (the original used zeros, silently ignoring it)
    padded_sequences = torch.full((len(sequences), max_len), padding_value,
                                  dtype=torch.long)
    for i, seq in enumerate(sequences):
        seq_len = len(seq)
        if seq_len <= max_len:
            padded_sequences[i, :seq_len] = torch.tensor(seq)
        else:
            padded_sequences[i, :] = torch.tensor(seq[:max_len])
    return padded_sequences

def train(model, optimizer, train_dataloader):
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_dataloader):
        # TensorDataset yields tuples, so unpack positionally
        input_ids, attention_mask, labels = [t.to(device) for t in batch]
        loss = model(input_ids, attention_mask, labels)
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_train_loss = total_loss / len(train_dataloader)
    return avg_train_loss

def evaluate(model, eval_dataloader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for step, batch in enumerate(eval_dataloader):
            input_ids, attention_mask, labels = [t.to(device) for t in batch]
            loss = model(input_ids, attention_mask, labels)
            total_loss += loss.item()
    avg_eval_loss = total_loss / len(eval_dataloader)
    return avg_eval_loss

def predict(model, text):
    model.eval()
    tokenized_text = tokenizer.tokenize(text)
    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_text)])
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        tags = model(input_ids.to(device), attention_mask.to(device))
    tag_labels = [id2label[tag] for tag in tags[0]]
    return list(zip(tokenized_text, tag_labels))
```

Here we define a tokenization function that converts raw words and labels into tokenized text with an aligned label sequence, a padding function so that sequences can be batched, and the training, evaluation, and prediction routines. Note that `predict` relies on an `id2label` mapping, which we build from the tag vocabulary during data loading below; a quick illustration of the label alignment follows.
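As that illustration, here is a tiny hypothetical call (bert-base-chinese tokenizes Chinese text character by character, so a two-character word contributes its label twice):

```python
tokens, tags = tokenize_and_preserve_labels(['北京', '欢迎', '你'], ['B-LOC', 'O', 'O'])
print(tokens)  # ['北', '京', '欢', '迎', '你']
print(tags)    # ['B-LOC', 'B-LOC', 'O', 'O', 'O']
```

Note that the `B-` prefix is simply copied to every subword; a stricter scheme would relabel trailing subwords as `I-LOC`, but the simple copy is enough for this walkthrough.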

Next, we need to load the dataset and convert it into the format the model expects:

```python
# Load the dataset: one "word label" pair per line, blank line between sentences
train_data = []
with open('train.txt', 'r', encoding='utf-8') as f:
    words = []
    labels = []
    for line in f:
        line = line.strip()
        if line == '':
            train_data.append((words, labels))
            words = []
            labels = []
        else:
            word, label = line.split()
            words.append(word)
            labels.append(label)
    if len(words) > 0:
        train_data.append((words, labels))

# Build the tag vocabulary and id mappings from the training data
all_tags = sorted({label for _, labels in train_data for label in labels})
label2id = {label: i for i, label in enumerate(all_tags)}
id2label = {i: label for label, i in label2id.items()}

# Convert the dataset into the format the model expects
train_input_ids = []
train_attention_masks = []
train_labels = []
for words, labels in train_data:
    tokenized_text, token_labels = tokenize_and_preserve_labels(words, labels)
    input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
    attention_mask = [1] * len(input_ids)
    train_input_ids.append(input_ids)
    train_attention_masks.append(attention_mask)
    train_labels.append([label2id[label] for label in token_labels])
train_input_ids = pad_sequences(train_input_ids, MAX_LEN)
train_attention_masks = pad_sequences(train_attention_masks, MAX_LEN)
# Padded label positions are ignored thanks to the CRF mask, so 0 is a safe fill
train_labels = pad_sequences(train_labels, MAX_LEN, padding_value=0)
train_dataset = torch.utils.data.TensorDataset(train_input_ids, train_attention_masks, train_labels)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# The validation set (eval_dataloader) and test set are loaded and converted the same way
```

Here we read a file of training data and convert it into the model's input format using the tokenization and padding helpers, building the tag-to-id mappings from the data along the way. A sample of the expected file layout is shown below.
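For concreteness, here is a hypothetical fragment of `train.txt` in the format the parsing loop expects: one token and its tag per line, separated by whitespace, with a blank line between sentences (the BIO tags are just an example tag set):

```text
北 B-LOC
京 I-LOC
欢 O
迎 O
你 O

我 O
爱 O
北 B-LOC
京 I-LOC
```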

Finally, we can use the helper functions and datasets above to train, evaluate, and test the model:

```python
# Train the model
model = EntityModel(bert_model, hidden_size=768, num_tags=len(label2id))
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
for epoch in range(EPOCHS):
    avg_train_loss = train(model, optimizer, train_dataloader)
    avg_eval_loss = evaluate(model, eval_dataloader)
    print(f'Epoch {epoch + 1}: train_loss={avg_train_loss:.4f}, eval_loss={avg_eval_loss:.4f}')

# Test the model on new sentences
test_sentences = ['今天是个好日子', '我喜欢中国菜', '巴黎是一座美丽的城市']
for sentence in test_sentences:
    tags = predict(model, sentence)
    print(tags)
```

Here we train the model with the Adam optimizer, using the negative CRF log-likelihood as the loss (the CRF scores entire tag sequences, so no per-token cross-entropy is needed). After each epoch we report the loss on the validation set, and finally we use the trained model to predict entities in a few new sentences.
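The walkthrough ends at prediction, but if you want to reuse the trained model later, the standard PyTorch save/restore pattern applies; a minimal sketch (the file name is arbitrary):

```python
# Save only the weights; the architecture is rebuilt in code when loading.
torch.save(model.state_dict(), 'bert_bilstm_crf.pt')

# Later: recreate the same architecture, then load the weights for inference.
model = EntityModel(bert_model, hidden_size=768, num_tags=len(label2id))
model.load_state_dict(torch.load('bert_bilstm_crf.pt', map_location=device))
model.to(device)
model.eval()
```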
