当前位置:   article > 正文

BeamSearch算法原理及代码解析_beam search中对词的编码方式

beam search中对词的编码方式

1.算法原理

beam search有一个超参数beam_size,设为 k 。第一个时间步长,选取当前条件概率最大的 k 个词,当做候选输出序列的第一个词。之后的每个时间步长,基于上个步长的输出序列,挑选出所有组合中条件概率最大的 k 个,作为该时间步长下的候选输出序列。始终保持 k 个候选。最后从k 个候选中挑出最优的。

2.中心思想

假设有n句话,每句话的长度为T。encoder的输出shape为(n, T, hidden_dim),扩展成(n*beam_size, T, hidden_dim)。decoder第一次输入shape为(n, 1),扩展到(n*beam_size, 1)。经过一次解码,输出得分的shape为(n*beam_size, vocab_size),路径得分log_prob的shape为(n*beam_size, 1),两者相加得到当前帧的路径得分。reshape到(n, beam_size*vocab_size),取topk(beam_size),得到排序后的索引(n, beam_size),索引除以vocab_size,得到的是每句话的beam_id,用来获取当前路径前一个字;对vocab_size取余,得到的是每句话的token_id,用来获取当前路径下一次字。

3.代码解析

  1. def beam_search():
  2. k_prev_words = torch.full((k, 1), SOS_TOKEN, dtype=torch.long) # (k, 1)
  3. # 此时输出序列中只有sos token
  4. seqs = k_prev_words #当前路径(k, 1)
  5. # 初始化scores向量为0
  6. top_k_scores = torch.zeros(k, 1)
  7. complete_seqs = [] #已完成序列
  8. complete_seqs_scores = [] #已完成序列的得分
  9. step = 1
  10. hidden = torch.zeros(1, k, hidden_size) # encoder的输出: (1, k, hidden_size)
  11. while True:
  12. outputs, hidden = decoder(k_prev_words, hidden) # outputs: (k, seq_len, vocab_size)
  13. next_token_logits = outputs[:,-1,:] # (k, vocab_size)
  14. if step == 1:
  15. # 因为最开始解码的时候只有一个结点<sos>,所以只需要取其中一个结点计算topk
  16. top_k_scores, top_k_words = next_token_logits[0].topk(k, dim=0, largest=True, sorted=True)
  17. else:
  18. # 此时要先展开再计算topk,如上图所示。
  19. # top_k_scores: (k) top_k_words: (k)
  20. top_k_scores, top_k_words = next_token_logits.view(-1).topk(k, 0, True, True)
  21. prev_word_inds = top_k_words / vocab_size # (k) 实际是beam_id,哪个beam就是哪条最优路径
  22. next_word_inds = top_k_words % vocab_size # (k) 实际是token_id,在单词表中的index
  23. # seqs: (k, step) ==> (k, step+1)
  24. seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
  25. # 当前输出的单词不是eos的有哪些(输出其在next_wod_inds中的位置, 实际是beam_id)
  26. incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
  27. next_word != vocab['<eos>']]
  28. # 输出已经遇到eos的句子的beam id(即seqs中的句子索引)
  29. complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
  30. if len(complete_inds) > 0:
  31. complete_seqs.extend(seqs[complete_inds].tolist()) # 加入句子
  32. complete_seqs_scores.extend(top_k_scores[complete_inds]) # 加入句子对应的累加log_prob
  33. # 减掉已经完成的句子的数量,更新k, 下次就不用执行那么多topk了,因为若干句子已经被解码出来了
  34. k -= len(complete_inds)
  35. if k == 0: # 完成
  36. break
  37. # 更新下一次迭代数据, 仅专注于那些还没完成的句子
  38. seqs = seqs[incomplete_inds]
  39. hidden = hidden[prev_word_inds[incomplete_inds]]
  40. top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) #(s, 1) s < k
  41. k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) #(s, 1) s < k
  42. if step > max_length: # decode太长后,直接break掉
  43. break
  44. step += 1
  45. i = complete_seqs_scores.index(max(complete_seqs_scores)) # 寻找score最大的序列
  46. # 有些许问题,在训练初期一直碰不到eos时,此时complete_seqs为空
  47. seq = complete_seqs[i]
  48. return seq

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/547151
推荐阅读
相关标签
  

闽ICP备14008679号