当前位置:   article > 正文

【Bert】(十三)简易问答系统--源码解析(测试)_问答系统源码

问答系统源码

上一篇博客介绍的损失部分就涉及训练的过程。

本篇介绍一下测试。按照上一篇博客介绍损失时,start_logits选取最大的概率值作为起始位置与真实起始位置比较,end_logits选取最大的概率值作为终止位置与真实终止位置比较。那么直观观念上测试只需要分别选取start_logits和end_logits的最大值,就能得到起始位置和终止位置。

但是会碰到如下几个问题

(1)很多时候句子达不到设定的seq_length的长度,假如设定输入模型的整个句子的向量长度为384,但是实际问题+段落的长度才181,而预测的起始位置为230,预测得到的起始位置大于句子的实际长度,这显然不合理。

(2)预测得到的终止位置大于句子的实际长度,这显然不合理。

(3)由于输入模型的句子向量中,第一部分为问题,第二部分为段落。预测得到的起始位置落在第一部分问题的区域间断,这显然不合理。

(4)由于输入模型的句子向量中,第一部分为问题,第二部分为段落。预测得到的终止位置落在第一部分问题的区域间断,这显然不合理。

(5) 由于在【Bert】(十)简易问答系统--数据解析_mjiansun的专栏-CSDN博客中介绍过滑动窗切割段落的情况,是的部分重叠单词有了归属,也就是部分单词属于切割后的第一个句子,有些单词属于切割后的第二个句子,所以预测的起始位置的单词如果不属于本句话,那么也应该判定为不合理。

(6)如果start_logits最大概率所在位置A,end_logits最大概率所在位置B,A的位置在B的后面,这就不符合逻辑,起始位置怎么能在终止位置后面

(7)一般的答案都会有一个长度,一般都不会太长,如果出现一个答案特别长,这显然也不合理

针对上述问题,bert是这样解决的

选取一定数量的候选起始位置和终止位置,让每一个起始位置和终止位置进行排列组合,然后跳过出现上问题的组合,保留满足条件的组合即可。

  1. for (feature_index, feature) in enumerate(features):
  2. result = unique_id_to_result[feature.unique_id]
  3. start_indexes = _get_best_indexes(result.start_logits, n_best_size)#按照得分选取前n_best_size个对应的候选起始位置
  4. end_indexes = _get_best_indexes(result.end_logits, n_best_size)#按照得分选取前n_best_size个对应的候选终止位置
  5. for start_index in start_indexes:
  6. for end_index in end_indexes:
  7. # We could hypothetically create invalid predictions, e.g., predict
  8. # that the start of the span is in the question. We throw out all
  9. # invalid predictions.
  10. if start_index >= len(feature.tokens):
  11. continue
  12. if end_index >= len(feature.tokens):
  13. continue
  14. if start_index not in feature.token_to_orig_map:
  15. continue
  16. if end_index not in feature.token_to_orig_map:
  17. continue
  18. if not feature.token_is_max_context.get(start_index, False):
  19. continue
  20. if end_index < start_index:
  21. continue
  22. length = end_index - start_index + 1
  23. if length > max_answer_length:
  24. continue
  25. prelim_predictions.append(
  26. _PrelimPrediction(
  27. feature_index=feature_index,
  28. start_index=start_index,
  29. end_index=end_index,
  30. start_logit=result.start_logits[start_index],
  31. end_logit=result.end_logits[end_index]))

剔除完一些不合理情况后,如何选出唯一的一个组合

(1)综合起始位置和终止位置的综合得分排序

  1. prelim_predictions = sorted(
  2. prelim_predictions,
  3. key=lambda x: (x.start_logit + x.end_logit),
  4. reverse=True)

将起始位置和终止位置的得分综合起来作为该对组合的最终得分,并按照最终得分排序。

(2)根据位置获取预测的字符串

根据tokens和起始终止位置,可以得到tokens拼接出来的字符串tok_text。根据examples和起始终止位置,可以得到examples拼接出来的字符串orig_text。

通过get_final_text函数综合得到最终字符串final_text,在下面我们会介绍该函数,这里先讲逻辑。

通过n_best_size限制保存的数量,通过seen_predictions防止重复答案。

  1. for pred in prelim_predictions:
  2. if len(nbest) >= n_best_size:
  3. break
  4. feature = features[pred.feature_index]
  5. if pred.start_index > 0: # this is a non-null prediction
  6. tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]#获取经过tokenizer处理过后的tokens的答案字符串
  7. orig_doc_start = feature.token_to_orig_map[pred.start_index]
  8. orig_doc_end = feature.token_to_orig_map[pred.end_index]
  9. orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] #获取原始句子example中的答案字符串
  10. tok_text = " ".join(tok_tokens)
  11. # De-tokenize WordPieces that have been split off.之前分词添加的##需要还原到原来的词
  12. tok_text = tok_text.replace(" ##", "")
  13. tok_text = tok_text.replace("##", "")
  14. # Clean whitespace
  15. tok_text = tok_text.strip()
  16. tok_text = " ".join(tok_text.split())
  17. orig_text = " ".join(orig_tokens)
  18. final_text = get_final_text(tok_text, orig_text, do_lower_case)
  19. if final_text in seen_predictions:
  20. continue
  21. seen_predictions[final_text] = True
  22. else:
  23. final_text = ""
  24. seen_predictions[final_text] = True
  25. nbest.append(
  26. _NbestPrediction(
  27. text=final_text,
  28. start_logit=pred.start_logit,
  29. end_logit=pred.end_logit))

这里我讲解下get_final_text函数, 首先他注释中的例子我看懂了,但是与代码似乎不对应(是我理解错了?)

1)找出起始和终止位置

将orig_text使用tokenizer进行处理,然后比较找出pred_text在orig_text的起始位置。

再根据pred_text的长度得到终止位置。

  1. tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
  2. tok_text = " ".join(tokenizer.tokenize(orig_text))
  3. start_position = tok_text.find(pred_text)
  4. if start_position == -1:
  5. if FLAGS.verbose_logging:
  6. tf.logging.info(
  7. "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
  8. return orig_text
  9. end_position = start_position + len(pred_text) - 1

2)剔除空格的影响

orig_ns_text:剔除空格后的字符串

orig_ns_to_s_map:字典,键表示字符在orig_ns_text中的位置,值表示字符在剔除空格字符串的位置

tok_ns_text:剔除空格后的字符串

tok_ns_to_s_map:字典,键表示字符在orig_ns_text中的位置,值表示字符在剔除空格字符串的位置

  1. def _strip_spaces(text):
  2. ns_chars = []
  3. ns_to_s_map = collections.OrderedDict()
  4. for (i, c) in enumerate(text):
  5. if c == " ":
  6. continue
  7. ns_to_s_map[len(ns_chars)] = i
  8. ns_chars.append(c)
  9. ns_text = "".join(ns_chars)
  10. return (ns_text, ns_to_s_map)
  11. (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  12. (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

例如orig_ns_text:'DenverBroncosdefeatedtheNationalFootballConference(NFC)championCarolinaPanthers'

orig_ns_to_s_map:

OrderedDict([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 15), (14, 16), (15, 17), (16, 18), (17, 19), (18, 20), (19, 21), (20, 22), (21, 24), (22, 25), (23, 26), (24, 28), (25, 29), (26, 30), (27, 31), (28, 32), (29, 33), (30, 34), (31, 35), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 46), (41, 47), (42, 48), (43, 49), (44, 50), (45, 51), (46, 52), (47, 53), (48, 54), (49, 55), (50, 57), (51, 58), (52, 59), (53, 60), (54, 61), (55, 63), (56, 64), (57, 65), (58, 66), (59, 67), (60, 68), (61, 69), (62, 70), (63, 72), (64, 73), (65, 74), (66, 75), (67, 76), (68, 77), (69, 78), (70, 79), (71, 81), (72, 82), (73, 83), (74, 84), (75, 85), (76, 86), (77, 87), (78, 88)])

tok_ns_text:'denverbroncosdefeatedthenationalfootballconference(nfc)championcarolinapanthers'

tok_ns_to_s_map:

OrderedDict([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 15), (14, 16), (15, 17), (16, 18), (17, 19), (18, 20), (19, 21), (20, 22), (21, 24), (22, 25), (23, 26), (24, 28), (25, 29), (26, 30), (27, 31), (28, 32), (29, 33), (30, 34), (31, 35), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 46), (41, 47), (42, 48), (43, 49), (44, 50), (45, 51), (46, 52), (47, 53), (48, 54), (49, 55), (50, 57), (51, 59), (52, 60), (53, 61), (54, 63), (55, 65), (56, 66), (57, 67), (58, 68), (59, 69), (60, 70), (61, 71), (62, 72), (63, 74), (64, 75), (65, 76), (66, 77), (67, 78), (68, 79), (69, 80), (70, 81), (71, 83), (72, 84), (73, 85), (74, 86), (75, 87), (76, 88), (77, 89), (78, 90)])

3)从tok的起始终止位置转成orig中的起始终止位置

tok_text的起始位置找到tok_ns_text的位置,然后根据tok_ns_text位置和orig_ns_text位置一一对应的规则,得出了起始位置在orig_ns_text中的位置,再根据orig_ns_to_s_map的映射规则得到起始位置在orig_text的位置。

  1. # We then project the characters in `pred_text` back to `orig_text` using
  2. # the character-to-character alignment.
  3. tok_s_to_ns_map = {}
  4. for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
  5. tok_s_to_ns_map[tok_index] = i
  6. orig_start_position = None
  7. if start_position in tok_s_to_ns_map:
  8. ns_start_position = tok_s_to_ns_map[start_position]
  9. if ns_start_position in orig_ns_to_s_map:
  10. orig_start_position = orig_ns_to_s_map[ns_start_position]
  11. if orig_start_position is None:
  12. if FLAGS.verbose_logging:
  13. tf.logging.info("Couldn't map start position")
  14. return orig_text
  15. orig_end_position = None
  16. if end_position in tok_s_to_ns_map:
  17. ns_end_position = tok_s_to_ns_map[end_position]
  18. if ns_end_position in orig_ns_to_s_map:
  19. orig_end_position = orig_ns_to_s_map[ns_end_position]
  20. if orig_end_position is None:
  21. if FLAGS.verbose_logging:
  22. tf.logging.info("Couldn't map end position")
  23. return orig_text

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号