
The generate function in nanoGPT

The function is defined in model.py as a method of the GPT class (F below is torch.nn.functional). It autoregressively extends a batch of token sequences: on each iteration it runs a forward pass, takes the logits at the last position, applies temperature scaling and optional top-k filtering, samples one token, and appends it to the sequence.

@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    # loop, generating one new token per iteration
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]  # keep only the last block_size tokens
        # forward the model to get the logits for the index in the sequence
        logits, _ = self(idx_cond)  # calls the model's forward()
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')  # mask tokens outside the top k with -inf
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)  # draw one token id at random
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
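
The temperature division and the top-k masking are the only nontrivial tensor tricks in the loop. Below is a minimal standalone sketch (toy logits invented for illustration, not produced by the model) showing how a temperature below 1 sharpens the softmax distribution, and how the v[:, [-1]] indexing zeroes out every token outside the top k:

import torch
import torch.nn.functional as F

# Toy logits for one sequence over a 6-token vocabulary (hypothetical values).
logits = torch.tensor([[2.0, 1.0, 0.5, 3.0, -1.0, 0.0]])

# Temperature: dividing by t < 1 sharpens the distribution, t > 1 flattens it.
for t in (0.5, 1.0, 2.0):
    print(t, F.softmax(logits / t, dim=-1))

# Top-k: v holds the k largest logits sorted descending, so v[:, [-1]] is the
# k-th largest value, kept as shape (1, 1) so it broadcasts against logits.
top_k = 3
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
filtered = logits.masked_fill(logits < v[:, [-1]], float('-inf'))

# After softmax, only the 3 highest-scoring tokens get nonzero probability.
print(F.softmax(filtered, dim=-1))

Pushing the temperature toward 0 approaches greedy decoding (effectively torch.argmax over the logits), while torch.multinomial in the function above draws stochastically from the filtered distribution.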
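
For context, a typical call looks like the following sketch, modeled on nanoGPT's sample.py. It assumes a GPT-2 checkpoint loadable via nanoGPT's GPT.from_pretrained helper and the tiktoken tokenizer; treat the prompt and sampling parameters as illustrative choices, not fixed values:

import torch
import tiktoken
from model import GPT  # nanoGPT's model.py

# Load pretrained GPT-2 weights (assumes nanoGPT's from_pretrained helper).
model = GPT.from_pretrained('gpt2')
model.eval()  # generate() expects eval mode (dropout disabled)

# Encode a prompt into a (1, t) LongTensor of token ids.
enc = tiktoken.get_encoding('gpt2')
idx = torch.tensor(enc.encode('Hello, my name is'), dtype=torch.long)[None, ...]

# Sample 20 new tokens with mild temperature and top-k filtering.
out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=200)
print(enc.decode(out[0].tolist()))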