
transformers beam_search
"""
The beam_search procedure in transformers' generation_beam_search.py:
- The decoder input starts as [N, 1], where N is the batch size; with num_beams = k it is expanded to [N*k, 1].
- The decoder turns this input into logits of shape [N*k, T], where T is the vocabulary size (total number of tokens).
- The logits are added to the running beam_score to form new beam scores; a top-k over them yields next_beam_scores, next_beam_indices and next_beam_tokens (see the sketch right after this block).
- beam_hyps bookkeeping: for each next_beam_* candidate, check whether next_token is <eos>; if it is, store the finished hypothesis, otherwise keep k candidate beams for the next decoder step.
The toy implementation below starts from a number and generates a run of consecutive numbers, stopping once a number ending in 9 is produced.
"""
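
The core of one decoding step boils down to a few tensor operations. The snippet below is a minimal, self-contained sketch of that selection step with made-up shapes (N=2 examples, k=3 beams, a vocabulary of T=5 tokens); it illustrates the flatten-and-topk trick described above and is not the actual transformers code.

import torch

N, k, T = 2, 3, 5                                   # batch size, beams, vocab size
log_probs = torch.randn(N * k, T).log_softmax(-1)   # decoder output for the current step
beam_scores = torch.zeros(N * k)                    # accumulated score of each beam

next_scores = log_probs + beam_scores[:, None]      # [N*k, T]
next_scores = next_scores.view(N, k * T)            # flatten beams and vocab per example
top_scores, top_ids = next_scores.topk(2 * k, dim=1)

beam_indices = top_ids // T                         # which beam each candidate extends
token_ids = top_ids % T                             # which token it appends
print(beam_indices, token_ids, top_scores)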


import torch
from typing import *
from abc import ABC, abstractmethod
from collections import UserDict
# from .file_utils import add_start_docstrings


class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
    :meth:`~transformers.PretrainedModel.beam_sample`.
    """

    @abstractmethod
    # @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        raise NotImplementedError("This is an abstract method.")

    @abstractmethod
    # @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        raise NotImplementedError("This is an abstract method.")

class BeamSearchScorer(BeamScorer):
    def __init__(
        self,
        batch_size: int,
        max_length: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
    ):
        self.max_length = max_length
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups
        self._is_init = False
        # one BeamHypotheses container per example in the batch
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                max_length=self.max_length,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. "
                f"For `num_beams` == 1, one should make use of `greedy_search` instead."
            )
        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
                f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.group_size)
        device = input_ids.device

        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                # this example is already fully generated: next_beam_* is still filled,
                # but only with pad tokens, so that the shapes stay consistent
                assert (
                    len(beam_hyp) >= self.num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to "
                    f"`eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.LongTensor]:
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue
            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
                # append to lists
                best.append(best_hyp)
                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)
        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
            }
        )

class BeamHypotheses:
    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
        """
        Add a new hypothesis to the list.
        """
        # length-penalized score: accumulated score divided by len(hyp) ** length_penalty
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                # drop the current worst hypothesis and update worst_score
                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
        one in the heap, then we are done with this sentence.
        """
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret

class ToyDecoder:
    # @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[torch.LongTensor]:
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states

        # set input_ids as decoder_input_ids
        if "decoder_input_ids" in model_kwargs:
            input_ids = model_kwargs.pop("decoder_input_ids")
        else:
            input_ids = self._prepare_decoder_input_ids_for_generation(
                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
            )

        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            encoder_input_ids=input_ids,  # encoder_input_ids
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            eos_token_id=eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
        )

        is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
        if is_beam_gen_mode:
            batch_size = input_ids.shape[0]
            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
            self.device = input_ids.device
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # duplicate every example num_beams times: [N, 1] -> [N * num_beams, 1]
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
            )
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,
            )

    def _prepare_decoder_input_ids_for_generation(
        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
    ) -> torch.LongTensor:
        # use the last token of the input as the initial decoder input
        decoder_input_ids = input_ids[:, -1].unsqueeze(-1)
        return decoder_input_ids

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        # encoder_outputs: ModelOutput = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
        if is_encoder_decoder:
            pass
            # assert encoder_outputs is not None
            # encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            #     0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            # )
            # model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs

    def _get_logits_processor(
        self,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        encoder_no_repeat_ngram_size: int,
        encoder_input_ids: torch.LongTensor,
        bad_words_ids: List[List[int]],
        min_length: int,
        eos_token_id: int,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        num_beam_groups: int,
        diversity_penalty: float,
    ):
        # logits processors are not needed for this toy example
        return None

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[List] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[torch.LongTensor]:
        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
        # if return_dict_in_generate and self.config.is_encoder_decoder:
        #     encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        #     encoder_hidden_states = (
        #         model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        #     )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        batch_beam_size, cur_len = input_ids.shape
        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        # All beams of one example start from the same token (e.g. <bos>/<eos>), so mask out all but
        # the first beam; otherwise topk would select num_beams copies of the same candidate.
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            # next_token_logits = outputs.logits[:, -1, :]
            next_token_logits = outputs["logits"][:, -1, :]
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )
            # the real implementation applies log_softmax here; the toy model just rescales the logits
            next_token_scores = next_token_logits / 100
            # add the running score of each beam to every candidate token
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # keep 2 * num_beams candidates so num_beams beams survive even if some end in <eos>
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            next_indices = next_tokens // vocab_size  # which beam the candidate came from
            next_tokens = next_tokens % vocab_size    # the candidate token id

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            # reorder the beams and append the chosen tokens
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1
            # model_kwargs = self._update_model_kwargs_for_generation(
            #     outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            # )
            # if model_kwargs["past"] is not None:
            #     model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
            #
            # if beam_scorer.is_done:
            #     break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )
        return sequence_outputs["sequences"]

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        # cut decoder_input_ids if past is used
        # if past is not None:
        #     decoder_input_ids = decoder_input_ids[:, -1:]
        # feed only the last token of every sequence in the batch to the toy decoder
        decoder_input_ids = decoder_input_ids[:, -1:]
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def __call__(  # effectively the forward method; the decoder computation is simplified here
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        input_shape = decoder_input_ids.size()
        decoder_input_ids = decoder_input_ids.view(-1, input_shape[-1])
        shape = tuple(decoder_input_ids.shape) + (100,)  # vocabulary: predictions over the numbers 0-99
        lm_logits = torch.zeros(shape)
        for ids, num in enumerate(decoder_input_ids.squeeze()):
            if (num.item() + 1) % 10 == 0:  # stop generating once a number ending in 9 is reached
                num = 1  # so that num + 1 = 2, the <eos> token
            maxnum = min(num + 1 + 10, 99)
            # favor the immediately following numbers with decreasing logits
            lm_logits[ids, :, num + 1:maxnum] = torch.arange(99, 99 - (maxnum - num - 1), step=-1)
        return {"logits": lm_logits}

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        # if cur_len == 1 and self.config.force_bos_token_to_be_generated:
        #     self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
        # elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
        #     self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
        if cur_len == max_length - 1:
            # force <eos> (token id 2) at the last position
            self._force_token_id_to_be_generated(logits, 2)
        return logits

    @staticmethod
    def _force_token_id_to_be_generated(scores, token_id) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
        scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")

if __name__ == "__main__":
    input_ids = torch.randint(0, 100, (2, 5))
    print(input_ids)
    decoder = ToyDecoder()
    t = decoder._prepare_decoder_input_ids_for_generation(input_ids)
    print(t)
    t1 = decoder.generate(
        input_ids, 8,  # max_length=8
        num_beams=4, num_beam_groups=1, do_sample=False,
        length_penalty=1, early_stopping=True, num_return_sequences=4,
        eos_token_id=2, pad_token_id=1,
    )
    print(t1)

Result (input_ids is generated with torch.randint, so the exact numbers differ between runs):

tensor([[91, 50, 26, 71, 23],
        [25, 22, 31, 20, 71]])
tensor([[23],
        [71]])
tensor([[23, 24, 25, 26, 27, 28, 29,  2],
        [23, 24, 25, 26, 27, 29,  2,  1],
        [23, 24, 25, 26, 28, 29,  2,  1],
        [23, 24, 25, 27, 28, 29,  2,  1],
        [71, 72, 73, 74, 75, 76, 77,  2],
        [71, 72, 73, 74, 75, 77, 78,  2],
        [71, 72, 73, 75, 76, 77, 78,  2],
        [71, 72, 74, 75, 76, 77, 78,  2]])
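
The ordering of these rows and of the accompanying sequence_scores comes from the length-penalized score computed in BeamHypotheses.add. A small standalone sketch of that formula (illustrative values, not taken from the run above):

def hypothesis_score(sum_logprobs: float, length: int, length_penalty: float = 1.0) -> float:
    # Same formula as BeamHypotheses.add: the accumulated score divided by
    # length ** length_penalty. With negative log-probabilities, a larger
    # length_penalty pushes scores toward zero and so favors longer sequences.
    return sum_logprobs / (length ** length_penalty)

print(hypothesis_score(-4.0, 8, 1.0))  # -0.5
print(hypothesis_score(-4.0, 8, 2.0))  # -0.0625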
