赞
踩
对于NLG任务,在推理阶段可能经常会遇到“停不下来”的问题,即重复的token被反复预测出来。
例如,输入“Google”,翻译模型可能会翻译为“谷歌谷歌”。
这个问题已经有很多人研究很久了,在模型侧提出的应对方案也有很多,本文介绍最简便的一种处理方法,只需要添加一行代码,就可以有效地改善。
对于这种现象出现的原因,有很多相关的分析和介绍,其中苏神的这篇文章让我感到受益匪浅,从数学的角度分析了为什么会重复,非常建议大家读一下这篇文章。
其实在transformers
的源码中,以及预置了一个参数,用来控制对重复出现token的惩罚,思想非常朴素,最早应该是出现在CTRL的论文中:
https://arxiv.org/pdf/1909.05858.pdf
我们来看一下论文里是怎么描述的:
在生成的时候,就是在计算词表中词汇的概率嘛,如果我们不希望之前出现的token连续出现,那只要把出现过的token对应的得分,人为地降低就好了,也就是给它一个惩罚的力度,让它变小一点。
反应在代码中,就是transformers/generation_utils.py
中的GenerationMixin.generate
方法,其中的repetition_penalty
参数,就是用来控制这个惩罚的,也就是论文中的theta。
这个参数必须为大于0的浮点数,当取值为1.0时,相当于什么也没有做。如果在调用generate的时候给了这个参数,则会创建一个RepetitionPenaltyLogitsProcessor
,简单看一下这个Processor是如何运作的:
class RepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences. Args: repetition_penalty (:obj:`float`): The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. """ def __init__(self, penalty: float): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") self.penalty = penalty def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability score = torch.where(score < 0, score * self.penalty, score / self.penalty) scores.scatter_(1, input_ids, score) return scores
其中input_ids就是generate时,输入的input_ids, scores是每一步推理计算出来的为下一步提供的得分。简单来说,这个类就是根据输入序列的token id,把score里边对应位置的得分取出来,然后惩罚一下这些位置的得分,让它的得分变小,然后把惩罚过的分数,替换掉原来计算出来的得分。
还是以翻译模型为例,采用的模型是opus-mt-en-zh,实例化这个模型:
from transformers import AutoModelWithLMHead,AutoTokenizer
mode_name = 'liam168/trans-opus-mt-en-zh'
model = AutoModelWithLMHead.from_pretrained(mode_name)
tokenizer = AutoTokenizer.from_pretrained(mode_name)
翻译一个词:
text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)
翻译结果为“谷歌谷歌”。可以看到,当输入文本很短时,很容易就出现了重复。
而如果在generate的时候,增加一个参数:
text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
batch['repetition_penalty'] = 1.2 # 论文中默认的参数1.2
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)
翻译结果就变成了只有一个"谷歌"。
再大胆一点,如果把惩罚力度设置为无穷大,也会出问题。当设置惩罚为float('inf')
时,在翻译句子“Google has Google translate”的时候,就会变成“谷歌有Google翻译”,第二个Google就因为被惩罚了而没有翻译成谷歌,而如果惩罚为1.2,则翻译结果为“谷歌有谷歌翻译”。所以惩罚力度设置为多大,还需要自己把握一下。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。