当前位置:   article > 正文

自然语言生成中解码算法:beam-search、topk-topp、基于kmeans的聚类_beam topk topp

beam topk topp

1. 背景

在生成任务中,经过如图1所示的预训练网络输出的是词表中各词的分布概率,我们要从中找出一条最优路径,使得 P ( y 1 , … , y T ∣ x 1 , … , x T ) P\left(y_{1}, \ldots, y_{T} \mid x_{1}, \ldots, x_{T}\right) P(y1,,yTx1,,xT)最大,本文将介绍贪婪搜索、beam-search搜索、基于kmeans聚类的beam-search、topk-topp等方法。
在这里插入图片描述

2. 贪婪搜索

若词表大小为 |V|,输出序列的最大长度为T ,则所有可能的输出序列为 O ( ∣ V ∣ T ) O\left(|V|^{T}\right) O(VT) 种,需要找出使得 P ( y 1 , … , y T ∣ x 1 , … , x T ) P\left(y_{1}, \ldots, y_{T} \mid x_{1}, \ldots, x_{T}\right) P(y1,,yTx1,,xT) 最大的输出序列 。最简单的方法一个是暴力法,列出所有可能路径,取最大值,但是暴力法计算量巨大。另一个是使用贪婪搜索,每次都选择当前时间步输出最大概率的词,贪婪搜索简单,计算量少,但是并不能找到全局最优的解。
在这里插入图片描述
贪婪搜索时间复杂度: O ( V T ) O(VT) O(VT)

3. beam-search束搜索

束搜索(beam search是一种近似查找最大概率的输出序列的方法,束搜索有一个束宽(BS)超参数,如图5.3所示,束搜索每个时间步都是选取当前条件概率最大的BS个结果作为输出。束搜索虽然不一定能找到最优解,但是平衡了计算量和结果精度,计算量比贪婪搜索稍大,但是结果中出现最优解的概率也大大提升。
在这里插入图片描述
beam search时间复杂度: O ( K V T ) O(KVT) O(KVT)

4. 基于kmeans聚类的beam-search

因为在束搜索中有一个固定大小的候选集上限BS,并且在这BS个候选项中有些是语义相似的,这对于生成结果的多样性没有意义,因此本文使用了一种基于聚类的束搜索方法,即:选取BS2个候选项,然后将这BS2个候选项使用聚类算法将其聚类成K个簇,聚类的特征是已经解码的序列的词向量,然后在每个簇中选前BS/K个候选项作为下一步解码候选集。这样做的话,语义相似的选项会在一个簇中,不同簇中候选项含义不同,从而增加了不同语义的响应的可能。基于Kmeans聚类的束搜索流程如图5.4所示
在这里插入图片描述

5. topk-topp方法

束搜索序列模型相比于贪心算法每次取概率最大的词作为输出生成的句子在流畅性、合理性等方面都有很大的提升,但是基于Beam Search的序列模型仍然有很多缺陷,由于束搜索采用最大概率的方式来生成对话,这样会造成生成的对话比较普通,没有多样性,基于Kmeans聚类的束搜索一定程度上提升了生成文本的多样性。

为了解决生成文本普通乏味,缺乏多样性,可以通过随机采样的方法,但是随机采样的方法可能会出现语法错误的问题。当对全体词按照GPT2网络输出的概率来采样,很有可能会采样到概率低的词,从而造成生成的句子不符合常理,甚至是错误的句子。因此应该从GPT2网络输出的概率最高的一些词中采样,即对最有可能的一些词采样,这样才能在保证输出的句子正确的情况下增加句子的多样性。首先通过式(5-1)对GPT2网络输出各个词的概率进行强化,使得在经过softmax函数后各个词概率分步更加的突出,即大概率的词概率更大,小概率的词概率更小。
p ( i ) = e f ( i ) T ∑ j e f ( j ) T ( 5 − 1 ) p(i)=\frac{e^{\frac{f(i)}{T}}}{\sum_{j} e^{\frac{f(j)}{T}}} (5-1) p(i)=jeTf(j)eTf(i)51
其中 f ( i ) f(i) f(i)是GPT2模型输出的logits,T是超参数,一般设置为小于1的某个数。经过上述处理后挑选出K个概率最大的词,之后对这K个词重新用softmax函数计算各个词的分布概率,获取概率后进行采样,接着进行下一时间步的生成,不断重复,由于上述算法每次都是从K个概率最大的词中进行采样,因此称为TopK采样法,TopK采样法如图5.5所示,该图中设置超参数K=2。
在这里插入图片描述
TopK算法也存在一些缺陷,比如当GPT2模型输出的某个词概率非常大,而其他词的概率都非常的小,由于需要采样K个词,因此还是会出现采样到概率特别低的词,因此应该设置一个概率界限值p,然后从概率大的词中开始取K个,若还未取够K个词,所有取到的词的概率和就大于等于界限p,则提前停止取词,这样取出的K个词中就不会出现概率特别低的词。图5.6展示了TopP算法采样过程,其中超参数p=0.9,K=3。TopK-Top采样法就是将TopK采样法和TopP采样法相结合。
在这里插入图片描述

6. 总结

贪心解码(Greedy Decoding):直接选择概率最高的单词。这种方法简单高效,但是可能会导致生成的文本过于单调和重复

随机采样(Random Sampling):按照概率分布随机选择一个单词。这种方法可以增加生成的多样性,但是可能会导致生成的文本不连贯和无意义

Beam Search:维护一个大小为 k 的候选序列集合,每一步从每个候选序列的概率分布中选择概率最高的 k 个单词,然后保留总概率最高的 k 个候选序列。这种方法可以平衡生成的质量和多样性,但是可能会导致生成的文本过于保守和不自然

Top-k 采样:是对前面“贪心策略”的优化,它从排名前 k 的 token 中进行抽样,允许其他分数或概率较高的token 也有机会被选中。在很多情况下,这种抽样带来的随机性有助于提高生成质量,但是有一定概率采样到概率低的词,造成生成错误

7.代码实现

7.1 top-P实现

def top_p_sampler(logits, p):
    probs = torch.softmax(logits, dim=-1)
    # 降序排序
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    # 概率累加
    cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cum_sum_probs > p
    # sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    # 至少保留一个词
    sorted_indices_to_remove[..., 0] = 0
    
    probs_top_p = sorted_probs.clone()
    probs_top_p[sorted_indices_to_remove] = float("-inf")
    probs_to_smaple = torch.softmax(probs_top_p, dim=-1)
    # 采样
    sample_token_id = torch.multinomial(probs_to_smaple, 1)

    token_index = sorted_indices.gather(dim=-1, index=sample_token_id)
    return token_index
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

7.2 topK实现

def topK_sampler(logits, k):
    probs = torch.softmax(logits, dim=-1)
    values, indices = torch.topk(probs, k, dim=-1)
    zeros = logits.new_ones(logits.shape) * float('-inf')
    # 填充
    zeros.scatter_(-1, indices, values)
    probs_to_smaple = torch.softmax(zeros, dim=-1)
    sample_token_id = torch.multinomial(probs_to_smaple, 1)
    return sample_token_id
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

8. 参考

大模型文本生成——解码策略(Top-k & Top-p & Temperature)


关注微信公众号funNLPer快乐起飞
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/737600
推荐阅读