当前位置:   article > 正文

BM25算法,python实现(源代码)_bm25训练 python

bm25训练 python

BM25算法,python实现

直接上代码吧,公式在维基百科上搜一下。有帮助的话就点赞收藏一下吧,有问题直接评论,会进行答复。

import json
import math
import os
import pickle
import sys
from typing import Dict, List


class BM25:

    EPSILON = 0.25
    PARAM_K1 = 1.5 # BM25算法中超参数
    PARAM_B = 0.6  # BM25算法中超参数

    def __init__(self, config: str, corpus: Dict):
        """
            初始化BM25模型,可以从参数config中加载,也可以从corpus中获取,参数config和corpus只能有一个不为空
            :param corpus: 文档集, 文档集合应该是字典形式,key为文档的唯一标识,val对应其文本内容
            :param config: 若config不为空,则从其中加载对应的参数
        """
        if (config and corpus) or (not config and not corpus):
            raise ValueError("config 和 corpus 不能同时为空 或 同时不为空")

        self.corpus_size = 0            # 文档数量
        self.wordNumsOfAllDoc = 0       # 用于计算文档集合中平均每篇文档的词数 -> wordNumsOfAllDoc / corpus_size
        self.doc_freqs = {}             # 记录每篇文档中查询词的词频
        self.idf = {}                   # 记录查询词的 IDF
        self.doc_len = {}               # 记录每篇文档的单词数
        self.docContainedWord = {}      # 包含单词 word 的文档集合

        if config:
            self._load_config(config)
        else:
            self._initialize(corpus)
            self._save_config()

    def _load_config(self, config: Dict):
        with open(config, 'rb') as f:
            config_data = pickle.load(f)
        self.corpus_size = config_data.get("corpus_size")
        self.wordNumsOfAllDoc = config_data.get("wordNumsOfAllDoc")
        self.doc_freqs = config_data.get("doc_freqs")
        self.idf = config_data.get("idf")
        self.doc_len = config_data.get("doc_len")
        self.docContainedWord = config_data.get("docContainedWord")

    def _initialize(self, corpus: Dict):
        """
            根据语料库构建倒排索引
        """
        # nd = {} # word -> number of documents containing the word
        for index, document in corpus.items():
            self.corpus_size += 1
            self.doc_len[index] = len(document)  # 文档的单词数
            self.wordNumsOfAllDoc += len(document)

            frequencies = {}  # 一篇文档中单词出现的频率
            for word in document:
                if word not in frequencies:
                    frequencies[word] = 0
                frequencies[word] += 1
            self.doc_freqs[index] = frequencies

            # 构建词到文档的倒排索引,将包含单词的和文档和包含关系进行反向映射
            for word in frequencies.keys():
                if word not in self.docContainedWord:
                    self.docContainedWord[word] = set()
                self.docContainedWord[word].add(index)

        # 计算 idf
        idf_sum = 0  # collect idf sum to calculate an average idf for epsilon value
        negative_idfs = []
        for word in self.docContainedWord.keys():
            doc_nums_contained_word = len(self.docContainedWord[word])
            idf = math.log(self.corpus_size - doc_nums_contained_word +
                           0.5) - math.log(doc_nums_contained_word + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)

        average_idf = float(idf_sum) / len(self.idf)
        eps = BM25.EPSILON * average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    @property
    def avgdl(self):
        return float(self.wordNumsOfAllDoc) / self.corpus_size

    def _save_config(self):
        path = os.path.join(sys.path[0], "BM25_config.pickle")
        with open(path, 'wb') as f:
            json.dump(self.__dict__, f)

    def get_score(self, query: List, doc_index):
        """
        计算查询 q 和文档 d 的相关性分数
        :param query: 查询词列表
        :param doc_index: 为语料库中某篇文档对应的索引
        """
        k1 = BM25.PARAM_K1
        b = BM25.PARAM_B
        score = 0
        doc_freqs = self.doc_freqs[doc_index]
        for word in query.values():
            if word not in doc_freqs:
                continue
            score += self.idf[word] * doc_freqs[word] * (k1 + 1) / (
                doc_freqs[word] + k1 * (1 - b + b * self.doc_len[doc_index] / self.avgdl))
        return [doc_index, score]

    def get_scores(self, query):
        scores = [self.get_score(query, index) for index in self.doc_len.keys()]
        return scores
    

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
'
运行
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/972271
推荐阅读
相关标签
  

闽ICP备14008679号