赞
踩
beam search有一个超参数beam_size,设为 k 。第一个时间步长,选取当前条件概率最大的 k 个词,当做候选输出序列的第一个词。之后的每个时间步长,基于上个步长的输出序列,挑选出所有组合中条件概率最大的 k 个,作为该时间步长下的候选输出序列。始终保持 k 个候选。最后从k 个候选中挑出最优的。
假设有n句话,每句话的长度为T。encoder的输出shape为(n, T, hidden_dim),扩展成(n*beam_size, T, hidden_dim)。decoder第一次输入shape为(n, 1),扩展到(n*beam_size, 1)。经过一次解码,输出得分的shape为(n*beam_size, vocab_size),路径得分log_prob的shape为(n*beam_size, 1),两者相加得到当前帧的路径得分。reshape到(n, beam_size*vocab_size),取topk(beam_size),得到排序后的索引(n, beam_size),索引除以vocab_size,得到的是每句话的beam_id,用来获取当前路径前一个字;对vocab_size取余,得到的是每句话的token_id,用来获取当前路径下一次字。
- def beam_search():
- k_prev_words = torch.full((k, 1), SOS_TOKEN, dtype=torch.long) # (k, 1)
- # 此时输出序列中只有sos token
- seqs = k_prev_words #当前路径(k, 1)
- # 初始化scores向量为0
- top_k_scores = torch.zeros(k, 1)
- complete_seqs = [] #已完成序列
- complete_seqs_scores = [] #已完成序列的得分
- step = 1
- hidden = torch.zeros(1, k, hidden_size) # encoder的输出: (1, k, hidden_size)
- while True:
- outputs, hidden = decoder(k_prev_words, hidden) # outputs: (k, seq_len, vocab_size)
- next_token_logits = outputs[:,-1,:] # (k, vocab_size)
- if step == 1:
- # 因为最开始解码的时候只有一个结点<sos>,所以只需要取其中一个结点计算topk
- top_k_scores, top_k_words = next_token_logits[0].topk(k, dim=0, largest=True, sorted=True)
- else:
- # 此时要先展开再计算topk,如上图所示。
- # top_k_scores: (k) top_k_words: (k)
- top_k_scores, top_k_words = next_token_logits.view(-1).topk(k, 0, True, True)
- prev_word_inds = top_k_words / vocab_size # (k) 实际是beam_id,哪个beam就是哪条最优路径
- next_word_inds = top_k_words % vocab_size # (k) 实际是token_id,在单词表中的index
- # seqs: (k, step) ==> (k, step+1)
- seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
-
- # 当前输出的单词不是eos的有哪些(输出其在next_wod_inds中的位置, 实际是beam_id)
- incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
- next_word != vocab['<eos>']]
- # 输出已经遇到eos的句子的beam id(即seqs中的句子索引)
- complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
-
- if len(complete_inds) > 0:
- complete_seqs.extend(seqs[complete_inds].tolist()) # 加入句子
- complete_seqs_scores.extend(top_k_scores[complete_inds]) # 加入句子对应的累加log_prob
- # 减掉已经完成的句子的数量,更新k, 下次就不用执行那么多topk了,因为若干句子已经被解码出来了
- k -= len(complete_inds)
-
- if k == 0: # 完成
- break
-
- # 更新下一次迭代数据, 仅专注于那些还没完成的句子
- seqs = seqs[incomplete_inds]
- hidden = hidden[prev_word_inds[incomplete_inds]]
- top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) #(s, 1) s < k
- k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) #(s, 1) s < k
-
- if step > max_length: # decode太长后,直接break掉
- break
- step += 1
- i = complete_seqs_scores.index(max(complete_seqs_scores)) # 寻找score最大的序列
- # 有些许问题,在训练初期一直碰不到eos时,此时complete_seqs为空
- seq = complete_seqs[i]
-
- return seq
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。