当前位置:   article > 正文

中文命名实体识别NER

中文命名实体识别

命名实体识别(英语:Named Entity Recognition),简称NER,是指识别文本中具有特定意义的实体,主要包括人名、地名、机构名、专有名词等,以及时间、数量、货币、比例数值等文字。目前在NER上表现较好的模型都是基于深度学习或者是统计学习的方法的,这些方法共同的特点都是需要大量的数据来进行学习,本文使用的数据集是2018ACL论文中新浪财经收集的简历数据。

数据集链接:https://github.com/jiesutd/LatticeLSTM

标注集采用BIOES(B表示实体开头,E表示实体结尾,I表示在实体内部,O表示非实体,S表示单个实体),句子之间用一个空行隔开。

对于命名实体识别其他方法举例

 

 

常用的模型以及涉及到的主要代码

 1、隐马尔可夫模型(HMM)

隐马尔可夫模型描述由一个隐藏的马尔科夫链随机生成不可观测的状态随机序列,再由各个状态生成一个观测而产生观测随机序列的过程(李航 统计学习方法)。隐马尔可夫模型由初始状态分布,状态转移概率矩阵以及观测概率矩阵所确定。上面的定义太过学术看不懂没关系,我们只需要知道,NER本质上可以看成是一种序列标注问题(预测每个字的BIOES标记),在使用HMM解决NER这种序列标注问题的时候,我们所能观测到的是字组成的序列(观测序列),观测不到的是每个字对应的标注(状态序列)。对应的,HMM的三个要素可以解释为,初始状态分布就是每一个标注作为句子第一个字的标注的概率,状态转移概率矩阵就是由某一个标注转移到下一个标注的概率(设状态转移矩阵为  ,那么若前一个词的标注为  ,则下一个词的标注为  的概率为  ),观测概率矩阵就是指在某个标注下,生成某个词的概率。根据HMM的三个要素,我们可以定义如下的HMM模型:

  1. class HMM(object):
  2. def __init__(self, N, M):
  3. """Args:
  4. N: 状态数,这里对应存在的标注的种类
  5. M: 观测数,这里对应有多少不同的字
  6. """
  7. self.N = N
  8. self.M = M
  9. # 状态转移概率矩阵 A[i][j]表示从i状态转移到j状态的概率
  10. self.A = torch.zeros(N, N)
  11. # 观测概率矩阵, B[i][j]表示i状态下生成j观测的概率
  12. self.B = torch.zeros(N, M)
  13. # 初始状态概率 Pi[i]表示初始时刻为状态i的概率
  14. self.Pi = torch.zeros(N)

有了模型定义,接下来的问题就是训练模型了。HMM模型的训练过程对应隐马尔可夫模型的学习问题(李航 统计学习方法),实际上就是根据训练数据根据最大似然的方法估计模型的三个要素,即上文提到的初始状态分布、状态转移概率矩阵以及观测概率矩阵。举个例子帮助理解,在估计初始状态分布的时候,假如某个标记在数据集中作为句子第一个字的标记的次数为k,句子的总数为N,那么该标记作为句子第一个字的概率可以近似估计为k/N,很简单对吧,使用这种方法,我们近似估计HMM的三个要素,代码如下(出现过的函数将用省略号代替):

  1. class HMM(object):
  2. def __init__(self, N, M):
  3. ....
  4. def train(self, word_lists, tag_lists, word2id, tag2id):
  5. """HMM的训练,即根据训练语料对模型参数进行估计,
  6. 因为我们有观测序列以及其对应的状态序列,所以我们
  7. 可以使用极大似然估计的方法来估计隐马尔可夫模型的参数
  8. 参数:
  9. word_lists: 列表,其中每个元素由字组成的列表,如 ['担','任','科','员']
  10. tag_lists: 列表,其中每个元素是由对应的标注组成的列表,如 ['O','O','B-TITLE', 'E-TITLE']
  11. word2id: 将字映射为ID
  12. tag2id: 字典,将标注映射为ID
  13. """
  14. assert len(tag_lists) == len(word_lists)
  15. # 估计转移概率矩阵
  16. for tag_list in tag_lists:
  17. seq_len = len(tag_list)
  18. for i in range(seq_len - 1):
  19. current_tagid = tag2id[tag_list[i]]
  20. next_tagid = tag2id[tag_list[i+1]]
  21. self.A[current_tagid][next_tagid] += 1
  22. # 一个重要的问题:如果某元素没有出现过,该位置为0,这在后续的计算中是不允许的
  23. # 解决方法:我们将等于0的概率加上很小的数
  24. self.A[self.A == 0.] = 1e-10
  25. self.A = self.A / self.A.sum(dim=1, keepdim=True)
  26. # 估计观测概率矩阵
  27. for tag_list, word_list in zip(tag_lists, word_lists):
  28. assert len(tag_list) == len(word_list)
  29. for tag, word in zip(tag_list, word_list):
  30. tag_id = tag2id[tag]
  31. word_id = word2id[word]
  32. self.B[tag_id][word_id] += 1
  33. self.B[self.B == 0.] = 1e-10
  34. self.B = self.B / self.B.sum(dim=1, keepdim=True)
  35. # 估计初始状态概率
  36. for tag_list in tag_lists:
  37. init_tagid = tag2id[tag_list[0]]
  38. self.Pi[init_tagid] += 1
  39. self.Pi[self.Pi == 0.] = 1e-10
  40. self.Pi = self.Pi / self.Pi.sum()

模型训练完毕之后,要利用训练好的模型进行解码,就是对给定的模型未见过的句子,求句子中的每个字对应的标注,针对这个解码问题,我们使用的是维特比(viterbi)算法。关于该算法的数学推导,可以查阅一下李航统计学习方法。

HMM存在两个缺陷:1)观察值之间严格独立,观测到的句子中每个字相互独立

2)状态转移过程中当前状态只与前一状态有关,没有关注到后一时刻的状态

HMM代码实现的主要模型部分如下:

  1. import torch
  2. class HMM(object):
  3. def __init__(self, N, M):
  4. """Args:
  5. N: 状态数,这里对应存在的标注的种类
  6. M: 观测数,这里对应有多少不同的字
  7. """
  8. self.N = N
  9. self.M = M
  10. # 状态转移概率矩阵 A[i][j]表示从i状态转移到j状态的概率
  11. self.A = torch.zeros(N, N)
  12. # 观测概率矩阵, B[i][j]表示i状态下生成j观测的概率
  13. self.B = torch.zeros(N, M)
  14. # 初始状态概率 Pi[i]表示初始时刻为状态i的概率
  15. self.Pi = torch.zeros(N)
  16. def train(self, word_lists, tag_lists, word2id, tag2id):
  17. """HMM的训练,即根据训练语料对模型参数进行估计,
  18. 因为我们有观测序列以及其对应的状态序列,所以我们
  19. 可以使用极大似然估计的方法来估计隐马尔可夫模型的参数
  20. 参数:
  21. word_lists: 列表,其中每个元素由字组成的列表,如 ['担','任','科','员']
  22. tag_lists: 列表,其中每个元素是由对应的标注组成的列表,如 ['O','O','B-TITLE', 'E-TITLE']
  23. word2id: 将字映射为ID
  24. tag2id: 字典,将标注映射为ID
  25. """
  26. assert len(tag_lists) == len(word_lists)
  27. # 估计转移概率矩阵
  28. for tag_list in tag_lists:
  29. seq_len = len(tag_list)
  30. for i in range(seq_len - 1):
  31. current_tagid = tag2id[tag_list[i]]
  32. next_tagid = tag2id[tag_list[i+1]]
  33. self.A[current_tagid][next_tagid] += 1
  34. # 问题:如果某元素没有出现过,该位置为0,这在后续的计算中是不允许的
  35. # 解决方法:我们将等于0的概率加上很小的数
  36. self.A[self.A == 0.] = 1e-10
  37. self.A = self.A / self.A.sum(dim=1, keepdim=True)
  38. # 估计观测概率矩阵
  39. for tag_list, word_list in zip(tag_lists, word_lists):
  40. assert len(tag_list) == len(word_list)
  41. for tag, word in zip(tag_list, word_list):
  42. tag_id = tag2id[tag]
  43. word_id = word2id[word]
  44. self.B[tag_id][word_id] += 1
  45. self.B[self.B == 0.] = 1e-10
  46. self.B = self.B / self.B.sum(dim=1, keepdim=True)
  47. # 估计初始状态概率
  48. for tag_list in tag_lists:
  49. init_tagid = tag2id[tag_list[0]]
  50. self.Pi[init_tagid] += 1
  51. self.Pi[self.Pi == 0.] = 1e-10
  52. self.Pi = self.Pi / self.Pi.sum()
  53. def test(self, word_lists, word2id, tag2id):
  54. pred_tag_lists = []
  55. for word_list in word_lists:
  56. pred_tag_list = self.decoding(word_list, word2id, tag2id)
  57. pred_tag_lists.append(pred_tag_list)
  58. return pred_tag_lists
  59. def decoding(self, word_list, word2id, tag2id):
  60. """
  61. 使用维特比算法对给定观测序列求状态序列, 这里就是对字组成的序列,求其对应的标注。
  62. 维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划求概率最大路径(最优路径)
  63. 这时一条路径对应着一个状态序列
  64. """
  65. # 问题:整条链很长的情况下,十分多的小概率相乘,最后可能造成下溢
  66. # 解决办法:采用对数概率,这样源空间中的很小概率,就被映射到对数空间的大的负数
  67. # 同时相乘操作也变成简单的相加操作
  68. A = torch.log(self.A)
  69. B = torch.log(self.B)
  70. Pi = torch.log(self.Pi)
  71. # 初始化 维比特矩阵viterbi 它的维度为[状态数, 序列长度]
  72. # 其中viterbi[i, j]表示标注序列的第j个标注为i的所有单个序列(i_1, i_2, ..i_j)出现的概率最大值
  73. seq_len = len(word_list)
  74. viterbi = torch.zeros(self.N, seq_len)
  75. # backpointer是跟viterbi一样大小的矩阵
  76. # backpointer[i, j]存储的是 标注序列的第j个标注为i时,第j-1个标注的id
  77. # 等解码的时候,我们用backpointer进行回溯,以求出最优路径
  78. backpointer = torch.zeros(self.N, seq_len).long()
  79. # self.Pi[i] 表示第一个字的标记为i的概率
  80. # Bt[word_id]表示字为word_id的时候,对应各个标记的概率
  81. # self.A.t()[tag_id]表示各个状态转移到tag_id对应的概率
  82. # 所以第一步为
  83. start_wordid = word2id.get(word_list[0], None)
  84. Bt = B.t()
  85. if start_wordid is None:
  86. # 如果字不再字典里,则假设状态的概率分布是均匀的
  87. bt = torch.log(torch.ones(self.N) / self.N)
  88. else:
  89. bt = Bt[start_wordid]
  90. viterbi[:, 0] = Pi + bt
  91. backpointer[:, 0] = -1
  92. # 递推公式:
  93. # viterbi[tag_id, step] = max(viterbi[:, step-1]* self.A.t()[tag_id] * Bt[word])
  94. # 其中word是step时刻对应的字
  95. # 由上述递推公式求后续各步
  96. for step in range(1, seq_len):
  97. wordid = word2id.get(word_list[step], None)
  98. # 处理字不在字典中的情况
  99. # bt是在t时刻字为wordid时,状态的概率分布
  100. if wordid is None:
  101. # 如果字不再字典里,则假设状态的概率分布是均匀的
  102. bt = torch.log(torch.ones(self.N) / self.N)
  103. else:
  104. bt = Bt[wordid] # 否则从观测概率矩阵中取bt
  105. for tag_id in range(len(tag2id)):
  106. max_prob, max_id = torch.max(
  107. viterbi[:, step-1] + A[:, tag_id],
  108. dim=0
  109. )
  110. viterbi[tag_id, step] = max_prob + bt[tag_id]
  111. backpointer[tag_id, step] = max_id
  112. # 终止, t=seq_len 即 viterbi[:, seq_len]中的最大概率,就是最优路径的概率
  113. best_path_prob, best_path_pointer = torch.max(
  114. viterbi[:, seq_len-1], dim=0
  115. )
  116. # 回溯,求最优路径
  117. best_path_pointer = best_path_pointer.item()
  118. best_path = [best_path_pointer]
  119. for back_step in range(seq_len-1, 0, -1):
  120. best_path_pointer = backpointer[best_path_pointer, back_step]
  121. best_path_pointer = best_path_pointer.item()
  122. best_path.append(best_path_pointer)
  123. # 将tag_id组成的序列转化为tag
  124. assert len(best_path) == len(word_list)
  125. id2tag = dict((id_, tag) for tag, id_ in tag2id.items())
  126. tag_list = [id2tag[id_] for id_ in reversed(best_path)]
  127. return tag_list

2、条件随机场

上面讲的HMM模型中存在两个假设,一是输出观察值之间严格独立,二是状态转移过程中当前状态只与前一状态有关。也就是说,在命名实体识别的场景下,HMM认为观测到的句子中的每个字都是相互独立的,而且当前时刻的标注只与前一时刻的标注相关。但实际上,命名实体识别往往需要更多的特征,比如词性,词的上下文等等,同时当前时刻的标注应该与前一时刻以及后一时刻的标注都相关联。由于这两个假设的存在,显然HMM模型在解决命名实体识别的问题上是存在缺陷的。

而条件随机场就没有这种问题,它通过引入自定义的特征函数,不仅可以表达观测之间的依赖,还可表示当前观测与前后多个状态之间的复杂依赖,可以有效克服HMM模型面临的问题。条件随机场数学公式不在此讲述了。其解码也是采用维特比算法。

  1. from sklearn_crfsuite import CRF # CRF的具体实现太过复杂,这里我们借助一个外部的库
  2. def word2features(sent, i):
  3. """抽取单个字的特征"""
  4. word = sent[i]
  5. prev_word = "<s>" if i == 0 else sent[i-1]
  6. next_word = "</s>" if i == (len(sent)-1) else sent[i+1]
  7. # 因为每个词相邻的词会影响这个词的标记
  8. # 所以我们使用:
  9. # 前一个词,当前词,后一个词,
  10. # 前一个词+当前词, 当前词+后一个词
  11. # 作为特征
  12. features = {
  13. 'w': word,
  14. 'w-1': prev_word,
  15. 'w+1': next_word,
  16. 'w-1:w': prev_word+word,
  17. 'w:w+1': word+next_word,
  18. 'bias': 1
  19. }
  20. return features
  21. def sent2features(sent):
  22. """抽取序列特征"""
  23. return [word2features(sent, i) for i in range(len(sent))]
  24. class CRFModel(object):
  25. def __init__(self,
  26. algorithm='lbfgs',
  27. c1=0.1,
  28. c2=0.1,
  29. max_iterations=100,
  30. all_possible_transitions=False
  31. ):
  32. self.model = CRF(algorithm=algorithm,
  33. c1=c1,
  34. c2=c2,
  35. max_iterations=max_iterations,
  36. all_possible_transitions=all_possible_transitions)
  37. def train(self, sentences, tag_lists):
  38. """训练模型"""
  39. features = [sent2features(s) for s in sentences]
  40. self.model.fit(features, tag_lists)
  41. def test(self, sentences):
  42. """解码,对给定句子预测其标注"""
  43. features = [sent2features(s) for s in sentences]
  44. pred_tag_lists = self.model.predict(features)
  45. return pred_tag_lists

 

3、Bi_LSTM_CRF

简单的LSTM的优点是能够通过双向的设置学习到观测序列(输入的字)之间的依赖,在训练过程中,LSTM能够根据目标(比如识别实体)自动提取观测序列的特征,但是缺点是无法学习到状态序列(输出的标注)之间的关系,要知道,在命名实体识别任务中,标注之间是有一定的关系的,比如B类标注(表示某实体的开头)后面不会再接一个B类标注,所以LSTM在解决NER这类序列标注任务时,虽然可以省去很繁杂的特征工程,但是也存在无法学习到标注上下文的缺点。相反,CRF的优点就是能对隐含状态建模,学习状态序列的特点,但它的缺点是需要手动提取序列特征。所以一般的做法是,在LSTM后面再加一层CRF,以获得两者的优点。

下面是给Bi-LSTM加一层CRF的代码实现:

  1. from itertools import zip_longest
  2. from copy import deepcopy
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from .util import tensorized, sort_by_lengths, cal_loss, cal_lstm_crf_loss
  7. from .config import TrainingConfig, LSTMConfig
  8. from .bilstm import BiLSTM
  9. class BILSTM_Model(object):
  10. def __init__(self, vocab_size, out_size, crf=True):
  11. """功能:对LSTM的模型进行训练与测试
  12. 参数:
  13. vocab_size:词典大小
  14. out_size:标注种类
  15. crf选择是否添加CRF层"""
  16. self.device = torch.device(
  17. "cuda" if torch.cuda.is_available() else "cpu")
  18. # 加载模型参数
  19. self.emb_size = LSTMConfig.emb_size
  20. self.hidden_size = LSTMConfig.hidden_size
  21. self.crf = crf
  22. # 根据是否添加crf初始化不同的模型 选择不一样的损失计算函数
  23. if not crf:
  24. self.model = BiLSTM(vocab_size, self.emb_size,
  25. self.hidden_size, out_size).to(self.device)
  26. self.cal_loss_func = cal_loss
  27. else:
  28. self.model = BiLSTM_CRF(vocab_size, self.emb_size,
  29. self.hidden_size, out_size).to(self.device)
  30. self.cal_loss_func = cal_lstm_crf_loss
  31. # 加载训练参数:
  32. self.epoches = TrainingConfig.epoches
  33. self.print_step = TrainingConfig.print_step
  34. self.lr = TrainingConfig.lr
  35. self.batch_size = TrainingConfig.batch_size
  36. # 初始化优化器
  37. self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
  38. # 初始化其他指标
  39. self.step = 0
  40. self._best_val_loss = 1e18
  41. self.best_model = None
  42. def train(self, word_lists, tag_lists,
  43. dev_word_lists, dev_tag_lists,
  44. word2id, tag2id):
  45. # 对数据集按照长度进行排序
  46. word_lists, tag_lists, _ = sort_by_lengths(word_lists, tag_lists)
  47. dev_word_lists, dev_tag_lists, _ = sort_by_lengths(
  48. dev_word_lists, dev_tag_lists)
  49. B = self.batch_size
  50. for e in range(1, self.epoches+1):
  51. self.step = 0
  52. losses = 0.
  53. for ind in range(0, len(word_lists), B):
  54. batch_sents = word_lists[ind:ind+B]
  55. batch_tags = tag_lists[ind:ind+B]
  56. losses += self.train_step(batch_sents,
  57. batch_tags, word2id, tag2id)
  58. if self.step % TrainingConfig.print_step == 0:
  59. total_step = (len(word_lists) // B + 1)
  60. print("Epoch {}, step/total_step: {}/{} {:.2f}% Loss:{:.4f}".format(
  61. e, self.step, total_step,
  62. 100. * self.step / total_step,
  63. losses / self.print_step
  64. ))
  65. losses = 0.
  66. # 每轮结束测试在验证集上的性能,保存最好的一个
  67. val_loss = self.validate(
  68. dev_word_lists, dev_tag_lists, word2id, tag2id)
  69. print("Epoch {}, Val Loss:{:.4f}".format(e, val_loss))
  70. def train_step(self, batch_sents, batch_tags, word2id, tag2id):
  71. self.model.train()
  72. self.step += 1
  73. # 准备数据
  74. tensorized_sents, lengths = tensorized(batch_sents, word2id)
  75. tensorized_sents = tensorized_sents.to(self.device)
  76. targets, lengths = tensorized(batch_tags, tag2id)
  77. targets = targets.to(self.device)
  78. # forward
  79. scores = self.model(tensorized_sents, lengths)
  80. # 计算损失 更新参数
  81. self.optimizer.zero_grad()
  82. loss = self.cal_loss_func(scores, targets, tag2id).to(self.device)
  83. loss.backward()
  84. self.optimizer.step()
  85. return loss.item()
  86. def validate(self, dev_word_lists, dev_tag_lists, word2id, tag2id):
  87. self.model.eval()
  88. with torch.no_grad():
  89. val_losses = 0.
  90. val_step = 0
  91. for ind in range(0, len(dev_word_lists), self.batch_size):
  92. val_step += 1
  93. # 准备batch数据
  94. batch_sents = dev_word_lists[ind:ind+self.batch_size]
  95. batch_tags = dev_tag_lists[ind:ind+self.batch_size]
  96. tensorized_sents, lengths = tensorized(
  97. batch_sents, word2id)
  98. tensorized_sents = tensorized_sents.to(self.device)
  99. targets, lengths = tensorized(batch_tags, tag2id)
  100. targets = targets.to(self.device)
  101. # forward
  102. scores = self.model(tensorized_sents, lengths)
  103. # 计算损失
  104. loss = self.cal_loss_func(
  105. scores, targets, tag2id).to(self.device)
  106. val_losses += loss.item()
  107. val_loss = val_losses / val_step
  108. if val_loss < self._best_val_loss:
  109. print("保存模型...")
  110. self.best_model = deepcopy(self.model)
  111. self._best_val_loss = val_loss
  112. return val_loss
  113. def test(self, word_lists, tag_lists, word2id, tag2id):
  114. """返回最佳模型在测试集上的预测结果"""
  115. # 准备数据
  116. word_lists, tag_lists, indices = sort_by_lengths(word_lists, tag_lists)
  117. tensorized_sents, lengths = tensorized(word_lists, word2id)
  118. tensorized_sents = tensorized_sents.to(self.device)
  119. self.best_model.eval()
  120. with torch.no_grad():
  121. batch_tagids = self.best_model.test(
  122. tensorized_sents, lengths, tag2id)
  123. # 将id转化为标注
  124. pred_tag_lists = []
  125. id2tag = dict((id_, tag) for tag, id_ in tag2id.items())
  126. for i, ids in enumerate(batch_tagids):
  127. tag_list = []
  128. if self.crf:
  129. for j in range(lengths[i] - 1): # crf解码过程中,end被舍弃
  130. tag_list.append(id2tag[ids[j].item()])
  131. else:
  132. for j in range(lengths[i]):
  133. tag_list.append(id2tag[ids[j].item()])
  134. pred_tag_lists.append(tag_list)
  135. # indices存有根据长度排序后的索引映射的信息
  136. # 比如若indices = [1, 2, 0] 则说明原先索引为1的元素映射到的新的索引是0,
  137. # 索引为2的元素映射到新的索引是1...
  138. # 下面根据indices将pred_tag_lists和tag_lists转化为原来的顺序
  139. ind_maps = sorted(list(enumerate(indices)), key=lambda e: e[1])
  140. indices, _ = list(zip(*ind_maps))
  141. pred_tag_lists = [pred_tag_lists[i] for i in indices]
  142. tag_lists = [tag_lists[i] for i in indices]
  143. return pred_tag_lists, tag_lists
  144. class BiLSTM_CRF(nn.Module):
  145. def __init__(self, vocab_size, emb_size, hidden_size, out_size):
  146. """初始化参数:
  147. vocab_size:字典的大小
  148. emb_size:词向量的维数
  149. hidden_size:隐向量的维数
  150. out_size:标注的种类
  151. """
  152. super(BiLSTM_CRF, self).__init__()
  153. self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size)
  154. # CRF实际上就是多学习一个转移矩阵 [out_size, out_size] 初始化为均匀分布
  155. self.transition = nn.Parameter(
  156. torch.ones(out_size, out_size) * 1/out_size)
  157. # self.transition.data.zero_()
  158. def forward(self, sents_tensor, lengths):
  159. # [B, L, out_size]
  160. emission = self.bilstm(sents_tensor, lengths)
  161. # 计算CRF scores, 这个scores大小为[B, L, out_size, out_size]
  162. # 也就是每个字对应一个 [out_size, out_size]的矩阵
  163. # 这个矩阵第i行第j列的元素的含义是:上一时刻tag为i,这一时刻tag为j的分数
  164. batch_size, max_len, out_size = emission.size()
  165. crf_scores = emission.unsqueeze(
  166. 2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0)
  167. return crf_scores
  168. def test(self, test_sents_tensor, lengths, tag2id):
  169. """使用维特比算法进行解码"""
  170. start_id = tag2id['<start>']
  171. end_id = tag2id['<end>']
  172. pad = tag2id['<pad>']
  173. tagset_size = len(tag2id)
  174. crf_scores = self.forward(test_sents_tensor, lengths)
  175. device = crf_scores.device
  176. # B:batch_size, L:max_len, T:target set size
  177. B, L, T, _ = crf_scores.size()
  178. # viterbi[i, j, k]表示第i个句子,第j个字对应第k个标记的最大分数
  179. viterbi = torch.zeros(B, L, T).to(device)
  180. # backpointer[i, j, k]表示第i个句子,第j个字对应第k个标记时前一个标记的id,用于回溯
  181. backpointer = (torch.zeros(B, L, T).long() * end_id).to(device)
  182. lengths = torch.LongTensor(lengths).to(device)
  183. # 向前递推
  184. for step in range(L):
  185. batch_size_t = (lengths > step).sum().item()
  186. if step == 0:
  187. # 第一个字它的前一个标记只能是start_id
  188. viterbi[:batch_size_t, step,
  189. :] = crf_scores[: batch_size_t, step, start_id, :]
  190. backpointer[: batch_size_t, step, :] = start_id
  191. else:
  192. max_scores, prev_tags = torch.max(
  193. viterbi[:batch_size_t, step-1, :].unsqueeze(2) +
  194. crf_scores[:batch_size_t, step, :, :], # [B, T, T]
  195. dim=1
  196. )
  197. viterbi[:batch_size_t, step, :] = max_scores
  198. backpointer[:batch_size_t, step, :] = prev_tags
  199. # 在回溯的时候我们只需要用到backpointer矩阵
  200. backpointer = backpointer.view(B, -1) # [B, L * T]
  201. tagids = [] # 存放结果
  202. tags_t = None
  203. for step in range(L-1, 0, -1):
  204. batch_size_t = (lengths > step).sum().item()
  205. if step == L-1:
  206. index = torch.ones(batch_size_t).long() * (step * tagset_size)
  207. index = index.to(device)
  208. index += end_id
  209. else:
  210. prev_batch_size_t = len(tags_t)
  211. new_in_batch = torch.LongTensor(
  212. [end_id] * (batch_size_t - prev_batch_size_t)).to(device)
  213. offset = torch.cat(
  214. [tags_t, new_in_batch],
  215. dim=0
  216. ) # 这个offset实际上就是前一时刻的
  217. index = torch.ones(batch_size_t).long() * (step * tagset_size)
  218. index = index.to(device)
  219. index += offset.long()
  220. try:
  221. tags_t = backpointer[:batch_size_t].gather(
  222. dim=1, index=index.unsqueeze(1).long())
  223. except RuntimeError:
  224. import pdb
  225. pdb.set_trace()
  226. tags_t = tags_t.squeeze(1)
  227. tagids.append(tags_t.tolist())
  228. # tagids:[L-1](L-1是因为扣去了end_token),大小的liebiao
  229. # 其中列表内的元素是该batch在该时刻的标记
  230. # 下面修正其顺序,并将维度转换为 [B, L]
  231. tagids = list(zip_longest(*reversed(tagids), fillvalue=pad))
  232. tagids = torch.Tensor(tagids).long()
  233. # 返回解码的结果
  234. return tagids

注:关于维特比算法推荐看链接,讲解的通俗易懂如何通俗讲解维特比算法

其他学习连接:

Advanced: Making Dynamic Decisions and the Bi-LSTM CRF — PyTorch Tutorials 1.11.0+cu102 documentation

Bi-LSTM-CRF for Sequence Labeling - 知乎

https://github.com/jiesutd/LatticeLSTM

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/352676
推荐阅读
相关标签
  

闽ICP备14008679号