
Using LSTM for Named Entity Recognition


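The official PyTorch documentation has a tutorial on sequence models with LSTMs: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html. Its worked example tags parts of speech, but the same per-token tagging setup carries over to named entity recognition (NER).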

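Naturally, the official tutorial leaves out some pieces a beginner needs when applying an LSTM to NER. I have filled in that code here, but I won't go into the underlying theory.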

First, build the LSTM model:

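```python
import torch.nn as nn
import torch.nn.functional as F


class LSTM_Model(nn.Module):
    def __init__(self, vocabSize, embedDim, hiddenDim, tagSize):
        super(LSTM_Model, self).__init__()
        # map each character id to a dense vector
        self.embeds = nn.Embedding(vocabSize, embedDim)
        # single-layer LSTM; default input layout is (seq_len, batch, feature)
        self.lstm = nn.LSTM(embedDim, hiddenDim)
        # project each hidden state onto the tag space
        self.hidden2tag = nn.Linear(hiddenDim, tagSize)

    def forward(self, sentSeq):
        # sentSeq: 1-D LongTensor of character ids for one sentence
        embeds = self.embeds(sentSeq)
        # add a batch dimension of 1: (seq_len, 1, embedDim)
        output, hidden = self.lstm(embeds.view(len(sentSeq), 1, -1))
        # drop the batch dimension: (seq_len, hiddenDim) -> (seq_len, tagSize)
        tagSpace = self.hidden2tag(output.view(len(sentSeq), -1))
        # per-character log-probabilities over tags, to pair with nn.NLLLoss
        result = F.log_softmax(tagSpace, dim=1)
        return result
```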
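To see what shapes flow through `forward`, here is a minimal sketch that applies the same three layers and the same `view` calls by hand to one sentence. The sizes and character ids are toy values chosen only for illustration; they are not from the article.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes (assumed, for illustration only).
vocabSize, embedDim, hiddenDim, tagSize = 10, 6, 8, 4
torch.manual_seed(0)

embeds = nn.Embedding(vocabSize, embedDim)
lstm = nn.LSTM(embedDim, hiddenDim)  # expects (seq_len, batch, feature)
hidden2tag = nn.Linear(hiddenDim, tagSize)

sentSeq = torch.tensor([1, 4, 7, 2, 9])  # one sentence of 5 character ids
embedded = embeds(sentSeq)  # (5, 6)
output, (h_n, c_n) = lstm(embedded.view(len(sentSeq), 1, -1))  # output: (5, 1, 8)
tagSpace = hidden2tag(output.view(len(sentSeq), -1))  # (5, 4)
logProbs = F.log_softmax(tagSpace, dim=1)  # (5, 4)

print(embedded.shape, output.shape, logProbs.shape)
# exponentiating each row gives a probability distribution over the 4 tags
print(logProbs.exp().sum(dim=1))
```

The batch dimension of 1 exists only because `nn.LSTM` expects a 3-D input; the model processes one sentence at a time, so it is added before the LSTM and removed right after.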

With the model built, train it using the following code:

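```python
import torch
import torch.nn as nn
import torch.optim as optim

# word2id / tag2id map characters / tags to integer ids; wordLists / tagLists
# hold one list per sentence (see the data-loading code below).
# EMBEDDING_DIM and HIDDEN_DIM must be set beforehand; the article does not fix their values.
model = LSTM_Model(len(word2id), EMBEDDING_DIM, HIDDEN_DIM, len(tag2id))
lossFunction = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1)
for epoch in range(300):
    for wordList, tagList in zip(wordLists, tagLists):
        model.zero_grad()  # clear gradients accumulated by the previous step
        inputSeq = torch.tensor([word2id[word] for word in wordList])
        tagSeq = torch.tensor([tag2id[tag] for tag in tagList])
        tagScore = model(inputSeq)  # (seq_len, tagSize) log-probabilities
        loss = lossFunction(tagScore, tagSeq)
        loss.backward()
        optimizer.step()
```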
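The model ends in `log_softmax` and training uses `nn.NLLLoss`; together they compute the mean negative log-probability of the gold tag, which is the same cross-entropy that `nn.CrossEntropyLoss` computes from the raw linear scores. A small check with random scores (values are illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
tagSpace = torch.randn(5, 4)  # raw scores: 5 characters, 4 tags
tagSeq = torch.tensor([0, 2, 1, 3, 0])  # gold tag ids

nll = nn.NLLLoss()(F.log_softmax(tagSpace, dim=1), tagSeq)
ce = nn.CrossEntropyLoss()(tagSpace, tagSeq)
# manual version: mean of -log p(gold tag) over the sentence
manual = -F.log_softmax(tagSpace, dim=1)[torch.arange(5), tagSeq].mean()

print(nll.item(), ce.item(), manual.item())
```

So an equivalent design could drop `log_softmax` from `forward` and train with `nn.CrossEntropyLoss`; the predicted tags would not change, because `log_softmax` preserves the order of scores within each row.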

Then check what the trained model predicts:

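```python
with torch.no_grad():  # no gradients needed for inference
    testText = ['欧', '美', '港', '台']
    # every character must appear in word2id, otherwise this raises KeyError
    testSeq = torch.tensor([word2id[word] for word in testText]).long()
    tags_scores = model(testSeq)
    print(tags_scores)
    # index of the highest-scoring tag for each character
    _, predictId = torch.max(tags_scores, dim=1)
    id2tag = {idx: tag for tag, idx in tag2id.items()}
    tagList = [id2tag[idx] for idx in predictId.tolist()]
    printZip(testText, tagList)  # printZip is defined in pprint.py below
```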
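The decoding step picks, for each character, the tag with the largest log-probability and maps its id back to a tag name. The sketch below runs that step on a hand-made score matrix; the tag set `O` / `B-LOC` / `I-LOC` and the scores are hypothetical, since the article does not list its tags.

```python
import torch

# Hypothetical tag set for illustration; the article does not list its tags.
tag2id = {'O': 0, 'B-LOC': 1, 'I-LOC': 2}
id2tag = {idx: tag for tag, idx in tag2id.items()}

testText = ['欧', '美', '港', '台']
# Made-up log-probabilities of shape (seq_len, tagSize), standing in for model(testSeq).
tags_scores = torch.log(torch.tensor([
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.1, 0.7, 0.2],
    [0.3, 0.1, 0.6],
]))

_, predictId = torch.max(tags_scores, dim=1)  # best tag id per character
tagList = [id2tag[idx] for idx in predictId.tolist()]
print(list(zip(testText, tagList)))
# [('欧', 'B-LOC'), ('美', 'I-LOC'), ('港', 'B-LOC'), ('台', 'I-LOC')]
```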

The training data is read by the helper modules below. First, filePath.py, which locates the data directory:

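```python
import os
import sys

# sys.path[0] is the directory of the script being run; the data folder is
# expected next to it, i.e. <parent of script dir>/data.
# os.path works on both Windows and Unix-like systems.
basePath = os.path.join(os.path.dirname(sys.path[0]), 'data')
```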
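filePath.py takes the parent of the script's directory and appends `data`. Because `os.path` is `ntpath` on Windows and `posixpath` elsewhere, calling both modules directly shows what that computation yields on each platform. The directory strings are hypothetical examples of what `sys.path[0]` might hold.

```python
import ntpath
import posixpath

# Hypothetical script directories (what sys.path[0] might hold).
win_dir = 'C:\\project\\src'
unix_dir = '/home/user/project/src'

# parent directory of the script folder, then the "data" subfolder
print(ntpath.join(ntpath.dirname(win_dir), 'data'))  # C:\project\data
print(posixpath.join(posixpath.dirname(unix_dir), 'data'))  # /home/user/project/data
```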

loadText.py

```python
class loadData():
    def __init__(self):
        pass

    def loadLists(self, filename):
        """
        Read a file with one "char<TAB>tag" pair per line; a blank line ends a sentence.
        :return: wordLists, tagLists (one list per sentence)
        """
        print('loading data...')
        with open(filename, encoding='utf-8') as f:
            textLines = f.readlines()
        wordLists = []
        tagLists = []
        wordList = []
        tagList = []
        for textLine in textLines:
            if textLine.strip():
                word, tag = textLine.strip().split('\t')
                wordList.append(word)
                tagList.append(tag)
            elif wordList:  # blank line: close the current sentence
                wordLists.append(wordList)
                tagLists.append(tagList)
                wordList = []
                tagList = []
        if wordList:  # last sentence when the file has no trailing blank line
            wordLists.append(wordList)
            tagLists.append(tagList)
        print('loading done.')
        return wordLists, tagLists

    def getVocab(self, sentence):
        """Give each distinct item an integer id in order of first appearance."""
        vocab = {}
        for word in sentence:
            if word not in vocab:
                vocab[word] = len(vocab)
        return vocab

    def text2sentences(self, filename):
        """Split raw text into lines, and each line into characters (spaces dropped)."""
        with open(filename, encoding='utf-8') as f:
            textList = f.read().split('\n')
        sentences = []
        for text in textList:
            sentence = []
            for word in text:
                if word != ' ':
                    sentence.append(word)
            sentences.append(sentence)
        return sentences

    def loadList(self, filename):
        """
        Read the same format as loadLists, but without sentence boundaries.
        :return: wordList, tagList
        """
        wordList = []
        tagList = []
        with open(filename, encoding='utf-8') as f:
            textLines = f.readlines()
        for textLine in textLines:
            if textLine.strip():
                text_list = textLine.strip().split('\t')
                wordList.append(text_list[0])
                tagList.append(text_list[1])
        return wordList, tagList
```
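The training loop needs `wordLists`, `tagLists`, `word2id` and `tag2id`, but the article does not show how they are built. The sketch below is one way to do it, under stated assumptions: it writes a tiny file in the layout `loadLists` expects (tags and filename are made up), reads it with a condensed stand-alone copy of the loader (`load_lists` and `get_vocab` are hypothetical names), and flattens the sentences before building the vocabularies. The helper keeps a final sentence even when the file has no trailing blank line and skips empty sentences. Flattening matters because passing `wordLists` directly to a `getVocab`-style function would try to use a list as a dictionary key and fail.

```python
import os
import tempfile

# A tiny file in the layout loadLists expects: "char<TAB>tag" per line,
# blank line between sentences. Tags are hypothetical; no trailing blank line.
sample = "欧\tB-LOC\n美\tI-LOC\n\n港\tB-LOC\n台\tI-LOC\n好\tO\n"


def load_lists(filename):
    """Condensed stand-alone version of loadData.loadLists (hypothetical name)."""
    wordLists, tagLists, wordList, tagList = [], [], [], []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            if line.strip():
                word, tag = line.strip().split('\t')
                wordList.append(word)
                tagList.append(tag)
            elif wordList:  # blank line closes the current sentence
                wordLists.append(wordList)
                tagLists.append(tagList)
                wordList, tagList = [], []
    if wordList:  # final sentence without a trailing blank line
        wordLists.append(wordList)
        tagLists.append(tagList)
    return wordLists, tagLists


def get_vocab(items):
    """Same idea as loadData.getVocab: ids in order of first appearance."""
    vocab = {}
    for item in items:
        if item not in vocab:
            vocab[item] = len(vocab)
    return vocab


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'train.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write(sample)
    wordLists, tagLists = load_lists(path)

# flatten all sentences so every character / tag receives an id
word2id = get_vocab(w for sent in wordLists for w in sent)
tag2id = get_vocab(t for sent in tagLists for t in sent)
print(wordLists)
print(word2id)
print(tag2id)
```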

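pprint.py (note that a local file with this name shadows Python's standard-library pprint module; rename it if that causes import conflicts)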

```python
def printZip(list1, list2):
    # interleave two lists, e.g. characters and their predicted tags:
    # [w1, t1, w2, t2, ...]
    pairs = []
    for node1, node2 in zip(list1, list2):
        pairs.append(node1)
        pairs.append(node2)
    print(pairs)
```

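That is all the code. Judging from loadLists, the training data file holds one character and its tag per line, separated by a tab, with a blank line between sentences.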

Put these files into your PyCharm project and you can run named entity recognition with the LSTM model.

Source: https://www.wpsshop.cn/w/我家小花儿/article/detail/353029 (copyright belongs to the original author).