当前位置:   article > 正文

NLP实践——文本生成中停不下来的问题_repetition_penalty

repetition_penalty

NLP实践——文本生成中停不下来的问题

1. 问题概述

对于NLG任务,在推理阶段可能经常会遇到“停不下来”的问题,即重复的token被反复预测出来。
例如,输入“Google”,翻译模型可能会翻译为“谷歌谷歌”。

这个问题已经有很多人研究很久了,在模型侧提出的应对方案也有很多,本文介绍最简便的一种处理方法,只需要添加一行代码,就可以有效地改善。

2. 造成的原因

对于这种现象出现的原因,有很多相关的分析和介绍,其中苏神的这篇文章让我感到受益匪浅,从数学的角度分析了为什么会重复,非常建议大家读一下这篇文章。

3. 解决的方法

其实在transformers的源码中,以及预置了一个参数,用来控制对重复出现token的惩罚,思想非常朴素,最早应该是出现在CTRL的论文中:
https://arxiv.org/pdf/1909.05858.pdf

我们来看一下论文里是怎么描述的:
ctrl
在生成的时候,就是在计算词表中词汇的概率嘛,如果我们不希望之前出现的token连续出现,那只要把出现过的token对应的得分,人为地降低就好了,也就是给它一个惩罚的力度,让它变小一点。

反应在代码中,就是transformers/generation_utils.py中的GenerationMixin.generate方法,其中的repetition_penalty参数,就是用来控制这个惩罚的,也就是论文中的theta。

这个参数必须为大于0的浮点数,当取值为1.0时,相当于什么也没有做。如果在调用generate的时候给了这个参数,则会创建一个RepetitionPenaltyLogitsProcessor,简单看一下这个Processor是如何运作的:

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)
        return scores
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

其中input_ids就是generate时,输入的input_ids, scores是每一步推理计算出来的为下一步提供的得分。简单来说,这个类就是根据输入序列的token id,把score里边对应位置的得分取出来,然后惩罚一下这些位置的得分,让它的得分变小,然后把惩罚过的分数,替换掉原来计算出来的得分。

4. 效果

还是以翻译模型为例,采用的模型是opus-mt-en-zh,实例化这个模型:

from transformers import AutoModelWithLMHead,AutoTokenizer
mode_name = 'liam168/trans-opus-mt-en-zh'
model = AutoModelWithLMHead.from_pretrained(mode_name)
tokenizer = AutoTokenizer.from_pretrained(mode_name)
  • 1
  • 2
  • 3
  • 4

翻译一个词:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)
  • 1
  • 2
  • 3
  • 4

翻译结果为“谷歌谷歌”。可以看到,当输入文本很短时,很容易就出现了重复。

而如果在generate的时候,增加一个参数:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
batch['repetition_penalty'] = 1.2   # 论文中默认的参数1.2
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)
  • 1
  • 2
  • 3
  • 4
  • 5

翻译结果就变成了只有一个"谷歌"。

再大胆一点,如果把惩罚力度设置为无穷大,也会出问题。当设置惩罚为float('inf')时,在翻译句子“Google has Google translate”的时候,就会变成“谷歌有Google翻译”,第二个Google就因为被惩罚了而没有翻译成谷歌,而如果惩罚为1.2,则翻译结果为“谷歌有谷歌翻译”。所以惩罚力度设置为多大,还需要自己把握一下。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/372175
推荐阅读
相关标签
  

闽ICP备14008679号