Source-code analysis of RepetitionPenaltyLogitsProcessor in transformers.generation_utils

This post walks through the part of the source code that tackles repeated phrases in text generation; the code below carries comments explaining each operation in detail.

```python
import torch

from transformers import LogitsProcessor


class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores is the current-step vocabulary distribution [batch_size, vocab_size];
        # input_ids is the sequence fed to the decoder so far [batch_size, seq_len].
        # gather picks out, per batch, the logits of every token that has already been generated.
        score = torch.gather(scores, 1, input_ids)
        # If score < 0, the repetition penalty has to be multiplied to reduce the token probability;
        # if score > 0, it is divided instead, so already-generated tokens always become less likely.
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        # Scatter the penalized logits back into the current-step vocabulary distribution.
        scores.scatter_(1, input_ids, score)
        return scores
```
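
To see the effect concretely, here is a minimal sketch that calls the processor directly; the vocabulary size, logit values, and token ids are toy numbers chosen purely for illustration:

```python
import torch

from transformers import RepetitionPenaltyLogitsProcessor

# Toy setup: batch of 1, vocabulary of 5 tokens; tokens 1 and 3 were already generated.
processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
input_ids = torch.tensor([[1, 3]])                    # [batch_size, seq_len]
scores = torch.tensor([[0.5, 2.0, 0.1, -1.0, 0.3]])   # [batch_size, vocab_size]

penalized = processor(input_ids, scores)
print(penalized)
# token 1:  2.0 / 1.2 ≈ 1.667  (positive logit is divided)
# token 3: -1.0 * 1.2 = -1.2   (negative logit is multiplied)
# all other logits are unchanged
```

Because the values are raw logits rather than probabilities, the sign matters: a positive logit shrinks toward zero when divided by the penalty, while a negative logit moves further below zero when multiplied. Either way a repeated token loses probability mass after softmax, which is exactly why the `torch.where` branch is needed.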

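In practice the processor is rarely called by hand. Passing `repetition_penalty` to `generate` makes the library build a `LogitsProcessorList` that includes `RepetitionPenaltyLogitsProcessor` and applies it at every decoding step. A minimal sketch, assuming `gpt2` as an illustrative checkpoint (any causal LM works the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The weather today is", return_tensors="pt")
# repetition_penalty > 1.0 discourages tokens that already appear in the sequence
output = model.generate(**inputs, max_length=30, repetition_penalty=1.2)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```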