赞
踩
def sample_logits(out: paddle.Tensor, temperature: float = 1.0, top_p: float = 0.8): """ 对模型输出的logits进行采样。 Args: out (paddle.Tensor): 模型输出的logits张量,形状为[Batch, vocab_size]。 temperature (float): 温度参数,用于调节采样的多样性,默认为1.0。 top_p (float): Top-p截断参数,用于稳定和控制采样概率分布,默认为0.8。 Returns: paddle.Tensor: 采样结果,形状为[Batch, 1],每个元素表示一个样本中采样得到的词的索引。 """ # 确保top_p和temperature都是非负值 top_p = max(0.0, min(1.0, top_p)) temperature = max(0.0, temperature) # 将out转换为概率分布 probs = paddle.nn.functional.softmax(out, axis=-1) # 根据top_p截断概率分布 sorted_probs = paddle.sort(probs, descending=True) cumulative_probs = paddle.cumsum(sorted_probs, axis=-1) cutoff_mask = cumulative_probs > top_p for i,ii in enumerate(paddle.argmax(cutoff_mask.astype("int"),-1)): probs[i][probs[i]<ii] = 0.0 probs[i][paddle.argmax(probs[i])]=0 # 对概率分布进行温度调节 if temperature != 1.0: probs = paddle.pow(probs, 1.0 / temperature) # 归一化概率分布 probs /= paddle.sum(probs,axis=-1, keepdim=True) # 如果top_p为0,则选择概率最大的位置;否则按照概率分布随机采样 if top_p != 0: sampled_indices = paddle.multinomial(probs, num_samples=1) else: sampled_indices = paddle.argmax(probs, axis=-1, keepdim=True) return sampled_indices
该函数的作用是对模型输出的logits进行采样,返回采样结果。
首先,函数定义了输入参数out、temperature和top_p,分别表示模型输出的logits张量、温度参数和Top-p截断参数。
然后,通过将out进行softmax操作,将其转化为概率分布。softmax操作通过将logits的每个元素从线性尺度转化为非线性概率分布来实现。
接下来,根据Top-p截断参数对概率分布进行截断操作。首先,将概率分布按照概率值进行排序,并计算累积概率。然后,根据Top-p截断参数,确定需要截断的位置。截断操作的方式是将概率值小于截断位置的元素置为0,并保留概率值最大的元素。
在对概率分布进行温度调节之前,需要对概率分布进行归一化操作,将概率值除以概率总和,使得概率分布的概率值之和为1。
最后,根据温度参数和Top-p截断参数选择采样的方式。如果Top-p截断参数不为0,则根据概率分布随机采样一个样本;如果Top-p截断参数为0,则选择概率值最大的元素作为采样结果。
最终,函数返回采样结果。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。