赞
踩
直接上代码吧,公式可以在维基百科上搜索“Okapi BM25”查看。有帮助的话就点赞收藏一下吧,有问题可以直接评论,我会进行答复。
import json
import math
import os
import pickle
import sys
from typing import Dict, List, Optional


class BM25:
    """BM25 relevance scorer over a corpus of pre-tokenized documents.

    A model is either built from a corpus (and then persisted with pickle)
    or restored from a previously saved pickle file.
    """

    EPSILON = 0.25  # fraction of the average IDF used in place of negative IDFs
    PARAM_K1 = 1.5  # BM25 hyperparameter k1
    PARAM_B = 0.6  # BM25 hyperparameter b

    # Name of the file written by ``_save_config`` when no path is given.
    CONFIG_FILENAME = "BM25_config.pickle"

    # Attributes that make up the persisted model state.
    _STATE_FIELDS = (
        "corpus_size",
        "wordNumsOfAllDoc",
        "doc_freqs",
        "idf",
        "doc_len",
        "docContainedWord",
    )

    def __init__(self, config: Optional[str] = None, corpus: Optional[Dict] = None):
        """Build the model from ``corpus`` or restore it from ``config``.

        Exactly one of ``config`` and ``corpus`` must be non-empty.

        :param config: path of a pickle file written by ``_save_config``;
            when given, the model state is loaded from it.
        :param corpus: mapping of a unique document id to that document's
            sequence of tokens; when given, the index is built from it and
            then saved to the default config file.
        :raises ValueError: if both or neither of the arguments are provided.
        """
        if (config and corpus) or (not config and not corpus):
            raise ValueError("config 和 corpus 不能同时为空 或 同时不为空")

        self.corpus_size = 0  # number of documents
        # Total token count over all documents; avgdl = wordNumsOfAllDoc / corpus_size.
        self.wordNumsOfAllDoc = 0
        self.doc_freqs = {}  # document id -> {word: occurrences in that document}
        self.idf = {}  # word -> IDF value
        self.doc_len = {}  # document id -> number of tokens
        self.docContainedWord = {}  # word -> set of ids of documents containing it

        if config:
            self._load_config(config)
        else:
            self._initialize(corpus)
            self._save_config()

    def _load_config(self, config: str):
        """Restore the model state from the pickle file at path ``config``.

        Fields missing from the file keep the empty defaults set in
        ``__init__``.

        :raises ValueError: if the file does not contain a dict.
        """
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only load config files that come from a trusted source.
        with open(config, 'rb') as f:
            config_data = pickle.load(f)
        if not isinstance(config_data, dict):
            raise ValueError("BM25 config file does not contain a dict")
        for name in self._STATE_FIELDS:
            setattr(self, name, config_data.get(name, getattr(self, name)))

    def _initialize(self, corpus: Dict):
        """Build the per-document statistics, inverted index and IDF table."""
        for index, document in corpus.items():
            self.corpus_size += 1
            self.doc_len[index] = len(document)
            self.wordNumsOfAllDoc += len(document)

            # Count how often each word occurs in this document.
            frequencies = {}
            for word in document:
                frequencies[word] = frequencies.get(word, 0) + 1
            self.doc_freqs[index] = frequencies

            # Inverted index: map every word to the documents containing it.
            for word in frequencies:
                self.docContainedWord.setdefault(word, set()).add(index)

        idf_sum = 0  # accumulated to derive the average IDF for the epsilon floor
        negative_idfs = []
        for word, containing_docs in self.docContainedWord.items():
            doc_count = len(containing_docs)
            idf = (math.log(self.corpus_size - doc_count + 0.5)
                   - math.log(doc_count + 0.5))
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)

        if not self.idf:
            # Every document is empty: there is no vocabulary to average over.
            return

        # Words found in more than half of the documents get a negative IDF;
        # replace those values with a fraction of the average IDF.
        average_idf = float(idf_sum) / len(self.idf)
        eps = BM25.EPSILON * average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    @property
    def avgdl(self):
        """Average number of tokens per document."""
        return float(self.wordNumsOfAllDoc) / self.corpus_size

    def _save_config(self, path: Optional[str] = None):
        """Persist the model state with pickle.

        :param path: destination file; defaults to ``BM25_config.pickle``
            inside ``sys.path[0]``.
        """
        if path is None:
            path = os.path.join(sys.path[0], self.CONFIG_FILENAME)
        state = {name: getattr(self, name) for name in self._STATE_FIELDS}
        # pickle (not json) so that the file matches what _load_config reads
        # and so that the sets in docContainedWord can be serialized.
        with open(path, 'wb') as f:
            pickle.dump(state, f)

    def get_score(self, query: List, doc_index):
        """Compute the relevance score of one document for a query.

        :param query: list of query words (a mapping is also accepted, in
            which case its values are used as the query words).
        :param doc_index: id of a document in the indexed corpus.
        :return: ``[doc_index, score]``.
        :raises KeyError: if ``doc_index`` is not in the indexed corpus.
        """
        k1 = BM25.PARAM_K1
        b = BM25.PARAM_B
        score = 0
        doc_freqs = self.doc_freqs[doc_index]
        if not doc_freqs:
            # An empty document matches nothing; returning early also avoids
            # dividing by an average length of zero below.
            return [doc_index, score]

        words = query.values() if isinstance(query, dict) else query
        # The length normalization depends only on the document, not the word.
        length_norm = k1 * (1 - b + b * self.doc_len[doc_index] / self.avgdl)
        for word in words:
            if word not in doc_freqs:
                continue
            score += self.idf[word] * doc_freqs[word] * (k1 + 1) / (
                doc_freqs[word] + length_norm)
        return [doc_index, score]

    def get_scores(self, query):
        """Score every indexed document against ``query``.

        :return: list of ``[doc_index, score]`` pairs, one per document.
        """
        return [self.get_score(query, index) for index in self.doc_len]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。