# transformers.generation_logits_process: NoRepeatNGramLogitsProcessor uses _calc_banned_ngram_tokens to find, for each hypothesis, the tokens that would complete an already-generated n-gram, so they can be banned and no n-gram is repeated.
import torch
from typing import List, Iterable


def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    # For each hypothesis, build a dict mapping every (ngram_size - 1)-token
    # prefix to the list of tokens that followed it in the generated sequence.
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, look up the last (ngram_size - 1) tokens:
    # any token that would complete an already-seen n-gram is banned.
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
x = _get_ngrams(3, torch.LongTensor([[0, 5, 6, 0, 5, 6], [0, 6, 4, 3, 2, 1]]), 2)
print('x', x)
y = _calc_banned_ngram_tokens(3, torch.LongTensor([[0, 5, 6, 0, 5, 6], [0, 6, 4, 3, 2, 1]]), 2, 6)
print('y', y)
Output:
x [{(0, 5): [6, 6], (5, 6): [0], (6, 0): [5]}, {(0, 6): [4], (6, 4): [3], (4, 3): [2], (3, 2): [1]}]
y [[0], []] # after the sequence [0, 5, 6, 0, 5, 6], generating 0 next is banned, because it would repeat the 3-gram (5, 6, 0).
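These banned lists are then used to mask the logits before the next token is chosen. Below is a minimal sketch of that step; the helper name _apply_no_repeat_ngram, the zero-initialized scores tensor, and the vocab size of 8 are made up for illustration and are not the exact transformers implementation:

def _apply_no_repeat_ngram(ngram_size: int, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    # For each hypothesis, set the logits of tokens that would complete an
    # already-seen n-gram to -inf, so they can never be selected.
    num_hypos, cur_len = input_ids.shape
    banned_batch_tokens = _calc_banned_ngram_tokens(ngram_size, input_ids, num_hypos, cur_len)
    for hypo_idx, banned_tokens in enumerate(banned_batch_tokens):
        scores[hypo_idx, banned_tokens] = -float("inf")
    return scores

# Continuing the example above: for the first sequence [0, 5, 6, 0, 5, 6],
# token 0 is banned, so its logit becomes -inf; the second sequence is untouched.
scores = torch.zeros(2, 8)  # 2 hypotheses, hypothetical vocab size of 8
scores = _apply_no_repeat_ngram(3, torch.LongTensor([[0, 5, 6, 0, 5, 6], [0, 6, 4, 3, 2, 1]]), scores)
print('scores', scores)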