当前位置:   article > 正文

NoRepeatNGramLogitsProcessor的_calc_banned_ngram_tokens

norepeatngramlogitsprocessor

#transformers.generation_logits_process 中 NoRepeatNGramLogitsProcessor 的 _calc_banned_ngram_tokens:根据已生成的序列,计算每个假设(hypothesis)下一步需要禁止的 token,从而保证生成结果中不出现重复的 ngram

import torch
from typing import List, Iterable


def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(
        ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens


# Demo input: two hypotheses of six tokens each, checked with 3-grams.
_demo_ids = torch.LongTensor([[0, 5, 6, 0, 5, 6], [0, 6, 4, 3, 2, 1]])

# Per-hypothesis tables mapping each 2-token prefix to the tokens that followed it.
x = _get_ngrams(3, _demo_ids, 2)
print('x', x)

# Tokens banned at the next step, with cur_len equal to the row length (6).
y = _calc_banned_ngram_tokens(3, _demo_ids, 2, 6)
print('y', y)
 

输出:

x [{(0, 5): [6, 6], (5, 6): [0], (6, 0): [5]}, {(0, 6): [4], (6, 4): [3], (4, 3): [2], (3, 2): [1]}]
y [[0], []] #[0, 5, 6, 0, 5, 6]序列后禁止生成0:末尾的(5, 6)之前出现过且后面跟的是0,再生成0就会重复3-gram (5, 6, 0);第二个序列末尾的(2, 1)未出现过,所以没有被禁止的token.

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/679109
推荐阅读
相关标签
  

闽ICP备14008679号