当前位置:   article > 正文

NLP.TM[33] | 纠错:pycorrector的错误检测

取n元文法得分

【NLP.TM】

本人有关自然语言处理和文本挖掘方面的学习和笔记,欢迎大家关注。

往期回顾

纠错是NLP中的一个看着不是很火但其实在现实应用中非常重要的一个部分,在一个强NLP以来的项目(如搜索)发展至中期,纠错就会成为一个效果提升的新增长点,经过统计,在微博等新媒体领域中,文本出错概率在2%左右,在语音识别领域中,出错率最高可达8-10%(数据来自:https://zhuanlan.zhihu.com/p/159101860),从这个比例来看,如果能修正这些错误,对效果的提升无疑是巨大的,那么我们来看看,纠错任务是怎么做的。

文章较长,懒人目录再现:

  • pycorrector简介

  • pycorrector的纠错思路

  • 混淆词典

  • 未登录词检测

  • 语言模型

  • 结果输出

  • 小结

pycorrector简介

pycorrector是非常基础的纠错模块工具,里面已经实现了一些非常通用的纠错方法,用里面的方法来做基线其实其实非常方便。

连接先放在这里:https://github.com/shibing624/pycorrector

他的使用方法其实也比较简单:

  1. import pycorrector
  2. corrected_sent, detail = pycorrector.correct('少先队员因该为老人让坐')
  3. print(corrected_sent, detail)

这是一个非常简单的官方case,详情还是可以去github里面去看看。

pycorrect的纠错思路

其实pycorrect里面造了很多飞机,不过实质上正式使用的还是非常经典的方法,来看看它的主函数具体思路是什么样的。

  1. def correct(self, text, include_symbol=True, num_fragment=1, threshold=57, **kwargs):
  2.     """
  3.     句子改错
  4.     :param text: str, query 文本
  5.     :param include_symbol: bool, 是否包含标点符号
  6.     :param num_fragment: 纠错候选集分段数, 1 / (num_fragment + 1)
  7.     :param threshold: 语言模型纠错ppl阈值
  8.     :param kwargs: ...
  9.     :return: text (str)改正后的句子, list(wrong, right, begin_idx, end_idx)
  10.     """
  11.     text_new = ''
  12.     details = []
  13.     self.check_corrector_initialized()
  14.     # 编码统一,utf-8 to unicode
  15.     text = convert_to_unicode(text)
  16.     # 长句切分为短句
  17.     blocks = self.split_2_short_text(text, include_symbol=include_symbol)
  18.     for blk, idx in blocks:
  19.         maybe_errors = self.detect_short(blk, idx)
  20.         for cur_item, begin_idx, end_idx, err_type in maybe_errors:
  21.             # 纠错,逐个处理
  22.             before_sent = blk[:(begin_idx - idx)]
  23.             after_sent = blk[(end_idx - idx):]
  24.             # 困惑集中指定的词,直接取结果
  25.             if err_type == ErrorType.confusion:
  26.                 corrected_item = self.custom_confusion[cur_item]
  27.             else:
  28.                 # 取得所有可能正确的词
  29.                 candidates = self.generate_items(cur_item, fragment=num_fragment)
  30.                 if not candidates:
  31.                     continue
  32.                 corrected_item = self.get_lm_correct_item(cur_item, candidates, before_sent, after_sent,
  33.                                                           threshold=threshold)
  34.             # output
  35.             if corrected_item != cur_item:
  36.                 blk = before_sent + corrected_item + after_sent
  37.                 detail_word = [cur_item, corrected_item, begin_idx, end_idx]
  38.                 details.append(detail_word)
  39.         text_new += blk
  40.     details = sorted(details, key=operator.itemgetter(2))
  41.     return text_new, details

这里面其实还是比较明确的:

  • 分句。一个长句分成多个断句。

  • 对每个短句进行错误检测detect_short

  • 错误点召回可能正确的词。

  • 召回后筛选最佳结果。

在这个框架下,来看看具体pycorrect的错误检测是怎么做的。

混淆词典

直接看源码:

  1. # 自定义混淆集加入疑似错误词典
  2. for confuse in self.custom_confusion:
  3.     idx = sentence.find(confuse)
  4.     if idx > -1:
  5.         maybe_err = [confuse, idx + start_idx, idx + len(confuse) + start_idx, ErrorType.confusion]
  6.         self._add_maybe_error_item(maybe_err, maybe_errors)

这块其实还是比较简单的,其实就是用户自定义了一个词典,这个词典作者叫做混淆词典,我更愿意叫做改写词典,遇到了key,就去找v,直接做这种改写。

不过个人感觉这种遍历整个整个词典然后find的方法复杂度可能比较高,如果是我我还是比较喜欢最大逆向匹配的方式来查字典。

未登录词检测

同样上代码:

  1. if self.is_word_error_detect:
  2.     # 切词
  3.     tokens = self.tokenizer.tokenize(sentence)
  4.     # 未登录词加入疑似错误词典
  5.     for token, begin_idx, end_idx in tokens:
  6.         # pass filter word
  7.         if self.is_filter_token(token):
  8.             continue
  9.         # pass in dict
  10.         if token in self.word_freq:
  11.             continue
  12.         maybe_err = [token, begin_idx + start_idx, end_idx + start_idx, ErrorType.word]
  13.         self._add_maybe_error_item(maybe_err, maybe_errors)

注释其实还是非常友好的,其实就这几个步骤:

  • 切词。

  • 跳过特定词汇的检测。

  • 查字典看是否有低频词(未登录词)出现。

  • 结果整理。

首先就是切词,这里的切词是一个函数,我们也来看看他具体切词是怎么切的:

  1. class Tokenizer(object):
  2.     def __init__(self, dict_path='', custom_word_freq_dict=None, custom_confusion_dict=None):
  3.         self.model = jieba
  4.         self.model.default_logger.setLevel(logging.ERROR)
  5.         # 初始化大词典
  6.         if os.path.exists(dict_path):
  7.             self.model.set_dictionary(dict_path)
  8.         # 加载用户自定义词典
  9.         if custom_word_freq_dict:
  10.             for w, f in custom_word_freq_dict.items():
  11.                 self.model.add_word(w, freq=f)
  12.         # 加载混淆集词典、
  13.         if custom_confusion_dict:
  14.             for k, word in custom_confusion_dict.items():
  15.                 # 添加到分词器的自定义词典中
  16.                 self.model.add_word(k)
  17.                 self.model.add_word(word)
  18.     def tokenize(self, unicode_sentence, mode="search"):
  19.         """
  20.         切词并返回切词位置, search mode用于错误扩召回
  21.         :param unicode_sentence: query
  22.         :param mode: search, default, ngram
  23.         :param HMM: enable HMM
  24.         :return: (w, start, start + width) model='default'
  25.         """
  26.         if mode == 'ngram':
  27.             n = 2
  28.             result_set = set()
  29.             tokens = self.model.lcut(unicode_sentence)
  30.             tokens_len = len(tokens)
  31.             start = 0
  32.             for i in range(0, tokens_len):
  33.                 w = tokens[i]
  34.                 width = len(w)
  35.                 result_set.add((w, start, start + width))
  36.                 for j in range(i, i + n):
  37.                     gram = "".join(tokens[i:j + 1])
  38.                     gram_width = len(gram)
  39.                     if i + j > tokens_len:
  40.                         break
  41.                     result_set.add((gram, start, start + gram_width))
  42.                 start += width
  43.             results = list(result_set)
  44.             result = sorted(results, key=lambda x: x[-1])
  45.         else:
  46.             result = list(self.model.tokenize(unicode_sentence, mode=mode))
  47.         return result

看着很高端,稍微看看源码其实就可以发现用的是以jieba为基础的操作,只不过多了一种n-gram切词而已,其实就是切词以后按照n-gram拼装而已。

切完词后,就是过滤一些不需要检测的词汇,主要是一些数字之类的,来看看具体有哪些:

  1. @staticmethod
  2. def is_filter_token(token):
  3.     result = False
  4.     # pass blank
  5.     if not token.strip():
  6.         result = True
  7.     # pass num
  8.     if token.isdigit():
  9.         result = True
  10.     # pass alpha
  11.     if is_alphabet_string(token.lower()):
  12.         result = True
  13.     # pass not chinese
  14.     if not is_chinese_string(token):
  15.         result = True
  16.     return result
  • 空字符串

  • 数字

  • 字母

  • 非中文

然后就是判断是否是低频词,这个就比较容易,他是构建了一个词典,直接判断是否在里面就好了。

语言模型

NLP领域最基础的东西就要数语言模型了,这里的假设其实是人输入的语言大都是常用的,如果出现了不太常用的东西,其实说明是有错的,带着这个假设,我们来看看利用这个方法是怎么判错的。

  1. # 语言模型检测疑似错误字
  2. try:
  3.     ngram_avg_scores = []
  4.     for n in [23]:
  5.         scores = []
  6.         for i in range(len(sentence) - n + 1):
  7.             word = sentence[i:i + n]
  8.             score = self.ngram_score(list(word))
  9.             scores.append(score)
  10.         if not scores:
  11.             continue
  12.         # 移动窗口补全得分
  13.         for _ in range(n - 1):
  14.             scores.insert(0, scores[0])
  15.             scores.append(scores[-1])
  16.         avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
  17.         ngram_avg_scores.append(avg_scores)
  18.     if ngram_avg_scores:
  19.         # 取拼接后的n-gram平均得分
  20.         sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))
  21.         # 取疑似错字信息
  22.         for i in self._get_maybe_error_index(sent_scores):
  23.             token = sentence[i]
  24.             # pass filter word
  25.             if self.is_filter_token(token):
  26.                 continue
  27.             # pass in stop word dict
  28.             if token in self.stopwords:
  29.                 continue
  30.             # token, begin_idx, end_idx, error_type
  31.             maybe_err = [token, i + start_idx, i + start_idx + 1,
  32.                          ErrorType.char]
  33.             self._add_maybe_error_item(maybe_err, maybe_errors)
  34. except IndexError as ie:
  35.     logger.warn("index error, sentence:" + sentence + str(ie))
  36. except Exception as e:
  37.     logger.warn("detect error, sentence:" + sentence + str(e))

首先这个是基于字来判断的,所以不需要切词,直接把字符串一个一个的拼接成n-gram即可。

要分析整个句子中每个位点字合理,是需要看上下文的,这里分别采用了2-gram和3-gram进行了分析,分别计算了一个叫做ngram_score的东西,具体是这样的:

  1. def ngram_score(self, chars):
  2.     """
  3.     取n元文法得分
  4.     :param chars: list, 以词或字切分
  5.     :return:
  6.     """
  7.     self.check_detector_initialized()
  8.     return self.lm.score(' '.join(chars), bos=False, eos=False)

这里使用的是kenlm来训练的语言模型,然后用score进行得分计算,这个得分实质上就是分析这个句子组合产生的可能性,概率当然就是在 之间了,然后取对数,因此这个得分就是一个非正数了,越接近0,说明这个组合出现的可能性越大,越不可能有错了。

另外,为了保证整个句子的完整性,是需要padding的,代码里做了一个移动窗口的处理,直接看可能有些难懂,但是知道了padding,应该会好明白一些:

  1. # 移动窗口补全得分
  2. for _ in range(n - 1):
  3.     scores.insert(0, scores[0])
  4.     scores.append(scores[-1])

然后就对分数进行根据句子长度的均值计算,计算完之后分别保存了每个字的2-gram得分和3-gram得分,然后后续取了这两个分数的均值,这里的代码这么看:

  1. avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
  2. ngram_avg_scores.append(avg_scores)
  3. if ngram_avg_scores:
  4.     # 取拼接后的n-gram平均得分
  5.     sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))

然后就会开始对这个分数进行分析,最终抽取可能有问题的位点,使用的函数就是_get_maybe_error_index

  1. @staticmethod
  2. def _get_maybe_error_index(scores, ratio=0.6745, threshold=2):
  3.     """
  4.     取疑似错字的位置,通过平均绝对离差(MAD)
  5.     :param scores: np.array
  6.     :param ratio: 正态分布表参数
  7.     :param threshold: 阈值越小,得到疑似错别字越多
  8.     :return: 全部疑似错误字的index: list
  9.     """
  10.     result = []
  11.     scores = np.array(scores)
  12.     if len(scores.shape) == 1:
  13.         scores = scores[:, None]
  14.     median = np.median(scores, axis=0)  # get median of all scores
  15.     margin_median = np.abs(scores - median).flatten()  # deviation from the median
  16.     # 平均绝对离差值
  17.     med_abs_deviation = np.median(margin_median)
  18.     if med_abs_deviation == 0:
  19.         return result
  20.     y_score = ratio * margin_median / med_abs_deviation
  21.     # 打平
  22.     scores = scores.flatten()
  23.     maybe_error_indices = np.where((y_score > threshold) & (scores < median))
  24.     # 取全部疑似错误字的index
  25.     result = list(maybe_error_indices[0])
  26.     return result

思路其实大概说了,就是基于平均离差来算,这其实就是常用异常检测的MAD。说白了就是整个句子,大部分情况是不会出错的,正常情况下打分就会在特定的一个范围内,但是出错的位置的打分会距离这个打分很远(可以理解为和常规语境和语言水平差别很大),我们需要把这几个打分比较远的对应位置提取出来。

另外这里蛮有意思的是,可以看到作者对numpy比较熟悉,可以看看里面这些操作。

结果输出

然后就是一些整理结果输出的操作了,基本的数据处理还是比较容易的,直接看看最终的输出格式吧

  1. import pycorrector
  2. idx_errors = pycorrector.detect('少先队员因该为老人让坐')
  3. print(idx_errors)
  4. # 输出:[['因该', 4, 6, 'word'], ['坐', 10, 11, 'char']]

会把他定的位置和错误类型给指出来,最终只需要整理出这个格式就行。

小结

这里给大家介绍的是pycorrector内baseline的检测方法,让大家理解最基本的错误识别方式。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/347829?site
推荐阅读
相关标签
  

闽ICP备14008679号