赞
踩
上一篇博客介绍的损失部分就涉及训练的过程。
本篇介绍一下测试。按照上一篇博客介绍损失时,start_logits选取最大的概率值作为起始位置与真实起始位置比较,end_logits选取最大的概率值作为终止位置与真实终止位置比较。那么直观观念上测试只需要分别选取start_logits和end_logits的最大值,就能得到起始位置和终止位置。
(1)很多时候句子达不到设定的seq_length的长度,假如设定输入模型的整个句子的向量长度为384,但是实际问题+段落的长度才181,而预测的起始位置为230,预测得到的起始位置大于句子的实际长度,这显然不合理。
(2)预测得到的终止位置大于句子的实际长度,这显然不合理。
(3)由于输入模型的句子向量中,第一部分为问题,第二部分为段落。预测得到的起始位置落在第一部分问题的区域间断,这显然不合理。
(4)由于输入模型的句子向量中,第一部分为问题,第二部分为段落。预测得到的终止位置落在第一部分问题的区域间断,这显然不合理。
(5) 由于在【Bert】(十)简易问答系统--数据解析_mjiansun的专栏-CSDN博客中介绍过滑动窗切割段落的情况,是的部分重叠单词有了归属,也就是部分单词属于切割后的第一个句子,有些单词属于切割后的第二个句子,所以预测的起始位置的单词如果不属于本句话,那么也应该判定为不合理。
(6)如果start_logits最大概率所在位置A,end_logits最大概率所在位置B,A的位置在B的后面,这就不符合逻辑,起始位置怎么能在终止位置后面?
(7)一般的答案都会有一个长度,一般都不会太长,如果出现一个答案特别长,这显然也不合理。
选取一定数量的候选起始位置和终止位置,让每一个起始位置和终止位置进行排列组合,然后跳过出现上问题的组合,保留满足条件的组合即可。
- for (feature_index, feature) in enumerate(features):
- result = unique_id_to_result[feature.unique_id]
- start_indexes = _get_best_indexes(result.start_logits, n_best_size)#按照得分选取前n_best_size个对应的候选起始位置
- end_indexes = _get_best_indexes(result.end_logits, n_best_size)#按照得分选取前n_best_size个对应的候选终止位置
-
- for start_index in start_indexes:
- for end_index in end_indexes:
- # We could hypothetically create invalid predictions, e.g., predict
- # that the start of the span is in the question. We throw out all
- # invalid predictions.
- if start_index >= len(feature.tokens):
- continue
- if end_index >= len(feature.tokens):
- continue
- if start_index not in feature.token_to_orig_map:
- continue
- if end_index not in feature.token_to_orig_map:
- continue
- if not feature.token_is_max_context.get(start_index, False):
- continue
- if end_index < start_index:
- continue
- length = end_index - start_index + 1
- if length > max_answer_length:
- continue
- prelim_predictions.append(
- _PrelimPrediction(
- feature_index=feature_index,
- start_index=start_index,
- end_index=end_index,
- start_logit=result.start_logits[start_index],
- end_logit=result.end_logits[end_index]))
- prelim_predictions = sorted(
- prelim_predictions,
- key=lambda x: (x.start_logit + x.end_logit),
- reverse=True)
将起始位置和终止位置的得分综合起来作为该对组合的最终得分,并按照最终得分排序。
根据tokens和起始终止位置,可以得到tokens拼接出来的字符串tok_text。根据examples和起始终止位置,可以得到examples拼接出来的字符串orig_text。
通过get_final_text函数综合得到最终字符串final_text,在下面我们会介绍该函数,这里先讲逻辑。
通过n_best_size限制保存的数量,通过seen_predictions防止重复答案。
- for pred in prelim_predictions:
- if len(nbest) >= n_best_size:
- break
- feature = features[pred.feature_index]
- if pred.start_index > 0: # this is a non-null prediction
- tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]#获取经过tokenizer处理过后的tokens的答案字符串
- orig_doc_start = feature.token_to_orig_map[pred.start_index]
- orig_doc_end = feature.token_to_orig_map[pred.end_index]
- orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] #获取原始句子example中的答案字符串
- tok_text = " ".join(tok_tokens)
-
- # De-tokenize WordPieces that have been split off.之前分词添加的##需要还原到原来的词
- tok_text = tok_text.replace(" ##", "")
- tok_text = tok_text.replace("##", "")
-
- # Clean whitespace
- tok_text = tok_text.strip()
- tok_text = " ".join(tok_text.split())
- orig_text = " ".join(orig_tokens)
-
- final_text = get_final_text(tok_text, orig_text, do_lower_case)
- if final_text in seen_predictions:
- continue
-
- seen_predictions[final_text] = True
- else:
- final_text = ""
- seen_predictions[final_text] = True
-
- nbest.append(
- _NbestPrediction(
- text=final_text,
- start_logit=pred.start_logit,
- end_logit=pred.end_logit))
这里我讲解下get_final_text函数, 首先他注释中的例子我看懂了,但是与代码似乎不对应(是我理解错了?)。
1)找出起始和终止位置
将orig_text使用tokenizer进行处理,然后比较找出pred_text在orig_text的起始位置。
再根据pred_text的长度得到终止位置。
- tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
-
- tok_text = " ".join(tokenizer.tokenize(orig_text))
-
- start_position = tok_text.find(pred_text)
- if start_position == -1:
- if FLAGS.verbose_logging:
- tf.logging.info(
- "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
- return orig_text
- end_position = start_position + len(pred_text) - 1
2)剔除空格的影响
orig_ns_text:剔除空格后的字符串
orig_ns_to_s_map:字典,键表示字符在orig_ns_text中的位置,值表示字符在未剔除空格字符串的位置
tok_ns_text:剔除空格后的字符串
tok_ns_to_s_map:字典,键表示字符在orig_ns_text中的位置,值表示字符在未剔除空格字符串的位置
- def _strip_spaces(text):
- ns_chars = []
- ns_to_s_map = collections.OrderedDict()
- for (i, c) in enumerate(text):
- if c == " ":
- continue
- ns_to_s_map[len(ns_chars)] = i
- ns_chars.append(c)
- ns_text = "".join(ns_chars)
- return (ns_text, ns_to_s_map)
- (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
- (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
例如orig_ns_text:'DenverBroncosdefeatedtheNationalFootballConference(NFC)championCarolinaPanthers'
orig_ns_to_s_map:
OrderedDict([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 15), (14, 16), (15, 17), (16, 18), (17, 19), (18, 20), (19, 21), (20, 22), (21, 24), (22, 25), (23, 26), (24, 28), (25, 29), (26, 30), (27, 31), (28, 32), (29, 33), (30, 34), (31, 35), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 46), (41, 47), (42, 48), (43, 49), (44, 50), (45, 51), (46, 52), (47, 53), (48, 54), (49, 55), (50, 57), (51, 58), (52, 59), (53, 60), (54, 61), (55, 63), (56, 64), (57, 65), (58, 66), (59, 67), (60, 68), (61, 69), (62, 70), (63, 72), (64, 73), (65, 74), (66, 75), (67, 76), (68, 77), (69, 78), (70, 79), (71, 81), (72, 82), (73, 83), (74, 84), (75, 85), (76, 86), (77, 87), (78, 88)])
tok_ns_text:'denverbroncosdefeatedthenationalfootballconference(nfc)championcarolinapanthers'
tok_ns_to_s_map:
OrderedDict([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 15), (14, 16), (15, 17), (16, 18), (17, 19), (18, 20), (19, 21), (20, 22), (21, 24), (22, 25), (23, 26), (24, 28), (25, 29), (26, 30), (27, 31), (28, 32), (29, 33), (30, 34), (31, 35), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 46), (41, 47), (42, 48), (43, 49), (44, 50), (45, 51), (46, 52), (47, 53), (48, 54), (49, 55), (50, 57), (51, 59), (52, 60), (53, 61), (54, 63), (55, 65), (56, 66), (57, 67), (58, 68), (59, 69), (60, 70), (61, 71), (62, 72), (63, 74), (64, 75), (65, 76), (66, 77), (67, 78), (68, 79), (69, 80), (70, 81), (71, 83), (72, 84), (73, 85), (74, 86), (75, 87), (76, 88), (77, 89), (78, 90)])
3)从tok的起始终止位置转成orig中的起始终止位置
从tok_text的起始位置找到tok_ns_text的位置,然后根据tok_ns_text位置和orig_ns_text位置一一对应的规则,得出了起始位置在orig_ns_text中的位置,再根据orig_ns_to_s_map的映射规则得到起始位置在orig_text的位置。
- # We then project the characters in `pred_text` back to `orig_text` using
- # the character-to-character alignment.
- tok_s_to_ns_map = {}
- for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
- tok_s_to_ns_map[tok_index] = i
-
- orig_start_position = None
- if start_position in tok_s_to_ns_map:
- ns_start_position = tok_s_to_ns_map[start_position]
- if ns_start_position in orig_ns_to_s_map:
- orig_start_position = orig_ns_to_s_map[ns_start_position]
-
- if orig_start_position is None:
- if FLAGS.verbose_logging:
- tf.logging.info("Couldn't map start position")
- return orig_text
-
- orig_end_position = None
- if end_position in tok_s_to_ns_map:
- ns_end_position = tok_s_to_ns_map[end_position]
- if ns_end_position in orig_ns_to_s_map:
- orig_end_position = orig_ns_to_s_map[ns_end_position]
-
- if orig_end_position is None:
- if FLAGS.verbose_logging:
- tf.logging.info("Couldn't map end position")
- return orig_text
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。