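The official PyTorch tutorials include an LSTM sequence-tagging example (its demo task is part-of-speech tagging) that carries over directly to named entity recognition: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
The official tutorial leaves out some pieces that a beginner needs when using an LSTM for named entity recognition. I have filled in that code here, but I will not go into the underlying theory.
First, build the LSTM model: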
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class LSTM_Model(nn.Module):
-     def __init__(self, vocabSize, embedDim, hiddenDim, tagSize):
-         super().__init__()
-         self.embeds = nn.Embedding(vocabSize, embedDim)  # character id -> vector
-         self.lstm = nn.LSTM(embedDim, hiddenDim)
-         self.hidden2tag = nn.Linear(hiddenDim, tagSize)  # hidden state -> tag scores
-
-     def forward(self, sentSeq):
-         embeds = self.embeds(sentSeq)
-         # nn.LSTM expects (seq_len, batch, input_size); here the batch size is 1.
-         output, hidden = self.lstm(embeds.view(len(sentSeq), 1, -1))
-         tagSpace = self.hidden2tag(output.view(len(sentSeq), -1))
-         # Log-probabilities over tags for each position, as expected by nn.NLLLoss.
-         result = F.log_softmax(tagSpace, dim=1)
-         return result
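The model returns log-probabilities from `log_softmax`, and the training code pairs them with `nn.NLLLoss`. To make that arithmetic concrete, here is a minimal pure-Python sketch of the same two steps. The scores and gold tags are made-up illustrative values, and `log_softmax` / `nll_loss` here are hand-written helpers, not the PyTorch functions themselves.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one row of raw tag scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_prob_rows, gold_tags):
    """Mean negative log-probability of the gold tag at each position."""
    picked = [row[tag] for row, tag in zip(log_prob_rows, gold_tags)]
    return -sum(picked) / len(picked)


# Hypothetical raw scores for a 2-character sentence and 3 possible tags.
tag_space = [[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]]
log_probs = [log_softmax(row) for row in tag_space]
loss = nll_loss(log_probs, [0, 2])  # gold tags: tag 0, then tag 2

print(log_probs)
print(loss)
```

Each row of `log_probs` exponentiates to a probability distribution over tags, and the loss shrinks as the gold tag's probability approaches 1.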
Once the model is built, train it. Here is the training code:
- import torch
- import torch.nn as nn
- import torch.optim as optim
-
- # word2id, tag2id, wordLists, tagLists, EMBEDDING_DIM and HIDDEN_DIM must be defined before this point.
- model = LSTM_Model(len(word2id), EMBEDDING_DIM, HIDDEN_DIM, len(tag2id))
- lossFunction = nn.NLLLoss()
- optimizer = optim.SGD(model.parameters(), lr=1e-1)
-
- for epoch in range(300):
-     for wordList, tagList in zip(wordLists, tagLists):
-         model.zero_grad()  # clear accumulated gradients
-         inputSeq = torch.tensor([word2id[word] for word in wordList])
-         tagSeq = torch.tensor([tag2id[tag] for tag in tagList])
-         tagScore = model(inputSeq)
-         loss = lossFunction(tagScore, tagSeq)
-         loss.backward()
-         optimizer.step()
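The training loop relies on `word2id` and `tag2id`, which the post never shows being built. One possible way to create them is sketched below. The helper name `build_vocab` and the toy sentences and tags are assumptions for illustration only; in practice the lists would come from `loadLists`, and `EMBEDDING_DIM` and `HIDDEN_DIM` still have to be chosen by the reader.

```python
def build_vocab(sequences):
    """Map each distinct item to an integer id, in order of first appearance."""
    vocab = {}
    for seq in sequences:
        for item in seq:
            if item not in vocab:
                vocab[item] = len(vocab)
    return vocab


# Toy stand-ins for the lists returned by loadLists (hypothetical data).
wordLists = [['欧', '美', '港', '台'], ['台', '北']]
tagLists = [['B-LOC', 'B-LOC', 'B-LOC', 'B-LOC'], ['B-LOC', 'I-LOC']]

word2id = build_vocab(wordLists)
tag2id = build_vocab(tagLists)

print(word2id)
print(tag2id)
```

Ids are assigned densely from 0, so `len(word2id)` and `len(tag2id)` can be passed straight to the model as the vocabulary size and tag count.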
Now check the results after training:
- with torch.no_grad():
-     testText = ['欧', '美', '港', '台']
-     # Every character must already be in word2id, otherwise this raises KeyError.
-     testSeq = torch.tensor([word2id[word] for word in testText]).long()
-     tags_scores = model(testSeq)
-     print(tags_scores)
-     # Pick the highest-scoring tag at each position.
-     _, predictId = torch.max(tags_scores, dim=1)
-     id2tag = {tagId: tag for tag, tagId in tag2id.items()}
-     tagList = [id2tag[tagId] for tagId in predictId.tolist()]
-     printZip(testText, tagList)
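The prediction code looks every test character up in `word2id` directly, so a character that never appeared in the training data raises `KeyError`. One common workaround, not part of the original code, is to reserve an id for unknown characters. The sketch below assumes a hypothetical `<UNK>` token and made-up vocabularies, and also shows the id-to-tag decoding step on its own.

```python
UNK = '<UNK>'  # hypothetical placeholder token, not in the original code


def encode(chars, word2id):
    """Convert characters to ids, falling back to the UNK id for unseen ones."""
    unk_id = word2id[UNK]
    return [word2id.get(ch, unk_id) for ch in chars]


def decode(pred_ids, tag2id):
    """Convert predicted tag ids back to tag strings."""
    id2tag = {tag_id: tag for tag, tag_id in tag2id.items()}
    return [id2tag[tag_id] for tag_id in pred_ids]


# Made-up vocabularies for illustration.
word2id = {UNK: 0, '欧': 1, '美': 2}
tag2id = {'O': 0, 'B-LOC': 1}

print(encode(['欧', '美', '港'], word2id))  # '港' is unseen here
print(decode([1, 1, 0], tag2id))
```

For this to work with the model, the `<UNK>` entry would have to be added to `word2id` before the model is constructed, so that the embedding table has a row for it.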
Below is the helper code I wrote for reading the training data.
filePath.py
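- import os
- import sys
-
- # sys.path[0] is the directory of the script being run; the data folder sits one level up from it.
- # os.path handles both Windows and POSIX separators.
- basePath = os.path.join(os.path.dirname(sys.path[0]), 'data')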
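The same "go up one folder, then into data" lookup can also be written with `pathlib`, which avoids splitting on a hard-coded separator. This is only an alternative sketch; the function name `data_dir_for` and the example path are assumptions, not part of the original project.

```python
from pathlib import Path


def data_dir_for(script_dir):
    """Return the 'data' folder that sits beside the script's own folder."""
    return Path(script_dir).parent / 'data'


# Hypothetical script location, for illustration only.
print(data_dir_for('/home/user/project/src'))
```

In the project itself this would be called as `data_dir_for(sys.path[0])`.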
loadText.py
- class loadData():
-     def __init__(self):
-         pass
-
-     def loadLists(self, filename):
-         """Read a tab-separated file into per-sentence word lists and tag lists."""
-         print('loading data...')
-         wordLists, tagLists = [], []
-         wordList, tagList = [], []
-         with open(filename, encoding='utf-8') as f:
-             for textLine in f:
-                 if textLine != '\n':
-                     word, tag = textLine.strip().split('\t')
-                     wordList.append(word)
-                     tagList.append(tag)
-                 else:
-                     # A blank line ends the current sentence.
-                     wordLists.append(wordList)
-                     tagLists.append(tagList)
-                     wordList, tagList = [], []
-         # Keep the last sentence when the file does not end with a blank line.
-         if wordList:
-             wordLists.append(wordList)
-             tagLists.append(tagList)
-         print('loading done.')
-         return wordLists, tagLists
-
-     def getVocab(self, sentence):
-         """Map each distinct item in a flat list to an id, in order of first appearance."""
-         vocab = {}
-         for word in sentence:
-             if word not in vocab:
-                 vocab[word] = len(vocab)
-         return vocab
-
-     def text2sentences(self, filename):
-         """Split a plain-text file into lines, each as a list of non-space characters."""
-         with open(filename, encoding='utf-8') as f:
-             textList = f.read().split('\n')
-         sentences = []
-         for text in textList:
-             sentence = []
-             for word in text:
-                 if word != ' ':
-                     sentence.append(word)
-             sentences.append(sentence)
-         return sentences
-
-     def loadList(self, filename):
-         """
-         Read a tab-separated file into one flat word list and one flat tag list.
-         :return: wordList,tagList
-         """
-         wordList = []
-         tagList = []
-         with open(filename, encoding='utf-8') as f:
-             for textLine in f:
-                 if textLine != '\n':
-                     text_list = textLine.strip().split('\t')
-                     wordList.append(text_list[0])
-                     tagList.append(text_list[1])
-         return wordList, tagList
pprint.py
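- # Note: a local module named pprint.py shadows Python's standard-library pprint.
- def printZip(list1, list2):
-     """Print the two lists interleaved: [word1, tag1, word2, tag2, ...]."""
-     pairs = []
-     for node1, node2 in zip(list1, list2):
-         pairs.append(node1)
-         pairs.append(node2)
-     print(pairs)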
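That is all the code. The training data is formatted as follows: each line holds one character and its tag separated by a tab, and a blank line separates sentences.
Put the files into your PyCharm project and you can run named entity recognition with the LSTM model.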
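As a rough illustration of the layout that `loadLists` expects, the sketch below parses a tiny in-memory sample. The sample text, the `B-LOC` / `I-LOC` tag names, and the helper `parse_tagged` are all hypothetical; the real tag set depends on the training data used.

```python
import io

# Hypothetical sample: one character and its tag per line, tab-separated,
# with a blank line between sentences.
SAMPLE = "欧\tB-LOC\n美\tB-LOC\n\n台\tB-LOC\n北\tI-LOC\n"


def parse_tagged(stream):
    """Group tab-separated 'character<TAB>tag' lines into sentences."""
    sentences, words, tags = [], [], []
    for line in stream:
        line = line.rstrip('\n')
        if line:
            word, tag = line.split('\t')
            words.append(word)
            tags.append(tag)
        elif words:
            # A blank line closes the current sentence.
            sentences.append((words, tags))
            words, tags = [], []
    if words:
        # Keep the final sentence even without a trailing blank line.
        sentences.append((words, tags))
    return sentences


parsed = parse_tagged(io.StringIO(SAMPLE))
for words, tags in parsed:
    print(words, tags)
```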