赞
踩
Beam search是一种动态规划算法,能够极大的减少搜索空间,增加搜索效率,并且其误差在可接受范围内,常被用于Sequence to Sequence模型,CTC解码等应用中
对于 T × N T\times N T×N的时间序列,如果我们要遍历所有可能能,则其所需的时间复杂度为 O ( N + N 2 + N 3 + . . . + N T ) \mathcal{O}(N+N^2+N^3+...+N^T) O(N+N2+N3+...+NT),在每一时间节点,所需遍历的节点数呈指数增加。对于Viterbi算法来说,时间复杂度为 O ( N + ( T − 1 ) N 2 ) \mathcal{O}(N+(T-1)N^2) O(N+(T−1)N2),在每个时间节点输入为N个best节点,需要比较的次数为 N 2 N^2 N2,然而这个时间复杂度还是太高。在N比较大的情况下,Beam Search为更好的选择,其时间复杂度为 O ( N + ( T − 1 ) ∗ b e a m s i z e ∗ N ) \mathcal{O}(N+(T-1)*beamsize*N) O(N+(T−1)∗beamsize∗N),每个时间节点的输入为beamsize个best节点,需要比较的次数为 b e a m s i z e ∗ N beamsize*N beamsize∗N
如上图所示,常规的beam search在每个时间节点,对输入的每个节点比较N次,并从
b
e
a
m
s
i
z
e
∗
N
beamsize*N
beamsize∗N个比较结果中,选择
b
e
a
m
s
i
z
e
beamsize
beamsize个结果作为下一时间节点的输入,其python的简单实现如下
import numpy as np import math def beam_search(nodes, topk=1): # log-likelihood可以相加 paths = {'A':math.log(nodes[0]['A']), 'B': math.log(nodes[0]['B']), 'C':math.log(nodes[0]['C'])} calculations = [] for l in range(1, len(nodes)): # 拷贝当前路径 paths_ = paths.copy() paths = {} nows = {} cur_cal = 0 for i in nodes[l].keys(): # 计算到达节点i的所有路径 for j in paths_.keys(): nows[j+i] = paths_[j]+math.log(nodes[l][i]) cur_cal += 1 calculations.append(cur_cal) # 选择topk条路径 indices = np.argpartition(list(nows.values()), -topk)[-topk:] # 保存topk路径 for k in indices: paths[list(nows.keys())[k]] = list(nows.values())[k] print(f'calculation number {calculations}') return paths nodes = [{'A':0.1, 'B':0.3, 'C':0.6}, {'A':0.2, 'B':0.4, 'C':0.4}, {'A':0.6, 'B':0.2, 'C':0.2}, {'A': 0.3, 'B': 0.3, 'C': 0.4}] print(beam_search(nodes, topk=2)) 输出结果: calculation number [9, 6, 6] {'CBAA': -3.1419147837320724, 'CBAC': -2.854232711280291, 'CCAC': -2.854232711280291}
我们可以看到,在 N = 3 N=3 N=3, b e a m s i z e = 2 beamsize=2 beamsize=2的情况下,每个节点的比较次数为6。
在CTC算法中,由于添加了blank以及重复字符串无blank合并的规则,例如ab
可能aab
,abb
,a blank b
等多种情况的输入,因此ab
的可能性应该为多种情况log概率之和,而不能通过单条beam进行搜索,因此可以采用改进版的prefix beam search,其代码如下
""" Code from https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0 Author: Awni Hannun CTC decoder in python, 简单例子可能不太效率 用于CTC模型的输出的前缀beam search 更多细节参考 https://distill.pub/2017/ctc/#inference https://arxiv.org/abs/1408.2873 """ import numpy as np import math import collections NEG_INF = -float("inf") def make_new_beam(): fn = lambda: (NEG_INF, NEG_INF) return collections.defaultdict(fn) def logsumexp(*args): """ Stable log sum exp. """ if all(a == NEG_INF for a in args): return NEG_INF a_max = max(args) lsp = math.log(sum(math.exp(a - a_max) for a in args)) return a_max + lsp def decode(probs, beam_size=100, blank=0): """ 对给定输出概率进行预测 Arguments: probs: 输出概率 (e.g. post-softmax) for each time step. Should be an array of shape (time x output dim). beam_size (int): Size of the beam to use during inference. blank (int): Index of the CTC blank label. Returns the output label sequence and the corresponding negative log-likelihood estimated by the decoder. """ T, S = probs.shape probs = np.log(probs) # 在beam中的元素为(prefix, (p_blank, p_no_blank)) # 初始beam为空序列,第一个是前缀,第二个是后接blank的log概率,第三个是后接非blank的log概率 # 我们需要后接blank和后接非blank两种情况,来区分重复字符是否应该被合并,对于后接blank的情况,重复字符就不会被合并 beam = [(tuple(), (0.0, NEG_INF))] for t in range(T): # 沿时间维度循环 # 存储下一个候选集的预设置字典,每次新的时间节点都会重设 next_beam = make_new_beam() for s in range(S): # 沿词表维度循环 p = probs[t, s] # p_b和p_nb分别为在当前时刻下前缀后接blank和非blank的log概率 for prefix, (p_b, p_nb) in beam: # 对beam进行循环 # 如果s为blank,那么前缀不会改变 # 因为后接的是blank,所以只需要更新前缀不变的情况下后接blank的log概率 if s == blank: n_p_b, n_p_nb = next_beam[prefix] n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p) next_beam[prefix] = (n_p_b, n_p_nb) continue # 记录前缀最后一个字符,用于判断当前字符与前缀最后一个字符是否相同 end_t = prefix[-1] if prefix else None n_prefix = prefix + (s,) # n_prefix代表next prefix n_p_b, n_p_nb = next_beam[n_prefix] # n_p_b代表 next probability of blank # 将新的字符s加到prefix后面并将整体加入到beam中 # 因为后接的是非blank,所以只需要更新后接非blank的log概率 if s != end_t: n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p) else: # 如果后接s是重复的,那么我们在更新后接非blank的log概率时, # 不包括上一时刻后接非blank的概率。CTC算法会合并没有用blank分隔的重复字符 n_p_nb = logsumexp(n_p_nb, p_b + p) # 这里是加入语言模型分数的好地方 next_beam[n_prefix] = (n_p_b, n_p_nb) # 这是合并的情况,如果s重复出现了,前缀也不会改变,我们也更新前缀不变的情况下后接非blank的log概率 if s == end_t: n_p_b, n_p_nb = next_beam[prefix] n_p_nb = logsumexp(n_p_nb, p_nb + p) next_beam[prefix] = (n_p_b, n_p_nb) # 在进入下一时间步之前,排序并裁剪beam beam = sorted(next_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True) beam = beam[:beam_size] best = beam[0] return best[0], -logsumexp(*best[1]) if __name__ == "__main__": np.random.seed(3) time = 50 output_dim = 20 probs = np.random.rand(time, output_dim) probs = probs / np.sum(probs, axis=1, keepdims=True) labels, score = decode(probs) print(labels) print("Score {:.3f}".format(score))
与常规BS不同的地方主要在于, PBS区分了几种情况以及log probability的计算方式
BS根据不同的场景可以有不同的写法,其主要目的在于在每个时间点选择TOPK的路径继续搜索,达到增加搜索效率的目的,在BS的搜索过程中,如果是生成字符串,我们还可以加入语言模型的分数,得到更好的结果:
Y
∗
=
l
o
g
P
(
Y
∣
X
)
+
α
l
o
g
(
P
l
m
(
Y
)
)
+
β
l
e
n
(
Y
)
Y^*=logP(Y|X)+\alpha log(P_{lm}(Y))+\beta len(Y)
Y∗=logP(Y∣X)+αlog(Plm(Y))+βlen(Y)
语言模型的加入地方一般为字符串扩增时。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。