赞
踩
话不多说,直接上代码
import math
from collections import Counter


def _ngram_counts(sentence, n):
    """Return ``(counts, length)`` for one whitespace-tokenised sentence.

    ``counts`` maps each lower-cased n-gram (tokens joined by a single
    space) to its number of occurrences; ``length`` is the token count.
    A sentence with fewer than ``n`` tokens yields an empty counter.
    """
    words = sentence.strip().split()
    # Clamp at zero: a sentence shorter than n contributes no n-grams.
    windows = max(len(words) - n + 1, 0)
    counts = Counter(' '.join(words[i:i + n]).lower() for i in range(windows))
    return counts, len(words)


def count_ngram(candidate, references, n):
    """Compute corpus-level clipped n-gram precision and brevity penalty.

    Args:
        candidate: sequence of candidate sentences (strings).
        references: sequence of reference sets; each reference set is a
            sequence of sentences indexed in parallel with ``candidate``.
        n: n-gram order.

    Returns:
        ``(pr, bp)`` where ``pr`` is the clipped n-gram precision summed over
        all sentences (0.0 when the candidate has no n-grams) and ``bp`` is
        the brevity penalty.
    """
    clipped_count = 0
    count = 0
    r = 0  # accumulated best-matching reference length
    c = 0  # accumulated candidate length
    for si, cand_sentence in enumerate(candidate):
        # Build the n-gram count table and length of every reference
        # for this sentence.
        ref_counts = []
        ref_lengths = []
        for reference in references:
            ngrams, length = _ngram_counts(reference[si], n)
            ref_counts.append(ngrams)
            ref_lengths.append(length)

        # Candidate sentence.
        cand_counts, cand_len = _ngram_counts(cand_sentence, n)
        clipped_count += clip_count(cand_counts, ref_counts)
        # Number of candidate n-grams; never negative, even when the
        # sentence is shorter than n.
        count += sum(cand_counts.values())
        r += best_length_match(ref_lengths, cand_len)
        c += cand_len

    # Guard on the denominator so an n-gram-free candidate scores 0.0
    # instead of dividing by zero.
    pr = float(clipped_count) / count if count > 0 else 0.0
    bp = brevity_penalty(c, r)
    return pr, bp


def clip_count(cand_d, ref_ds):
    """Sum candidate n-gram counts, each clipped to its max reference count.

    Args:
        cand_d: mapping of n-gram -> count in the candidate.
        ref_ds: sequence of mappings of n-gram -> count, one per reference.

    Returns:
        The total clipped count; an n-gram absent from every reference
        contributes 0.
    """
    return sum(
        min(cand_total, max((ref.get(ngram, 0) for ref in ref_ds), default=0))
        for ngram, cand_total in cand_d.items()
    )


def best_length_match(ref_l, cand_l):
    """Return the reference length closest to the candidate length.

    On ties the earliest entry of ``ref_l`` wins.

    Raises:
        ValueError: if ``ref_l`` is empty.
    """
    return min(ref_l, key=lambda ref: abs(cand_l - ref))


def brevity_penalty(c, r):
    """Return the brevity penalty for candidate length ``c`` and reference length ``r``.

    The penalty is 1.0 when the candidate is longer than the reference,
    0.0 for an empty candidate, and ``exp(1 - r / c)`` otherwise.
    """
    if c > r:
        return 1.0
    if c == 0:
        # An empty candidate cannot be scored; avoid dividing by zero.
        return 0.0
    return math.exp(1 - float(r) / c)


def BLEU(candidate, references, n_gram):
    """Return the clipped ``n_gram``-order precision of ``candidate``.

    NOTE: only the precision is returned; the brevity penalty computed by
    ``count_ngram`` is discarded here.
    """
    pr, _ = count_ngram(candidate, references, n_gram)
    return pr


# Demo: each list element is treated as one (single-token) sentence.
candidate = ['我', '爱', '你']
references = [['我', '爱', '你']]
bleu = BLEU(candidate, references, 1)
# Bare expression kept so a notebook cell displays the score.
bleu
其中第三个参数 1 表示 n-gram 的阶数 n(此处即 unigram),运行结果为 1.0。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。