当前位置:   article > 正文

RedisSearch(附 python demo 代码)_python redis列表搜索

python redis列表搜索

Redis作为一个高性能的键值对存储系统,常被用于缓存、消息队列等场景。然而,对于需要全文搜索的应用,Redis原生的数据结构可能无法满足需求。此时,RedisSearch模块便派上了用场。
RedisSearch是一个源代码可用的Redis模块,为Redis增加了查询、辅助索引和全文搜索功能。它基于RediSearch实现,能够在Redis上执行复杂的多字段查询、聚合、精确短语匹配、数字过滤、地理过滤和向量相似性语义搜索。
以下是RedisSearch的主要特性:

  1. 多字段联合检索:支持在多个字段上进行搜索,满足复杂查询需求。
  2. 高性能增量索引:能够高效地处理大量数据,实现快速的索引构建和更新。
  3. 精确短语搜索:支持精确匹配短语查询,提高搜索准确率。
  4. 数字过滤器和范围:能够根据数字字段进行过滤和范围查询。
  5. 地理过滤:利用Redis的地理命令,实现地理位置相关的搜索功能。
  6. 支持自定义评分函数:类似于Elasticsearch的function_score,可根据特定需求调整搜索结果的排序。
  7. 检索完整的文档内容或只是ID:根据需要选择检索完整文档或仅检索ID。
  8. 支持文档删除和更新与索引垃圾收集:确保索引的实时性和有效性。
  9. 支持部分更新和条件文档更新:方便对文档进行部分修改或基于特定条件的更新。
  10. 支持拼写纠错:提供拼写建议,提高用户体验。
  11. 支持高亮显示:对匹配的关键词进行高亮显示,帮助用户快速识别结果。
  12. 支持聚合分析:对搜索结果进行分组、计数等聚合操作,便于数据分析和挖掘。
  13. 支持配置停用词和同义词:通过配置停用词和同义词,优化搜索效果,提高查询准确率。
  14. 向量存储与KNN检索:支持向量存储和最近邻(KNN)检索,适用于需要相似性搜索的场景。

RedisSearchHelper 工具类

pip install redis redisearch

  • 1
  • 2
import redis
from redisearch import Client, TextField, NumericField, Query, Document, IndexDefinition, IndexType

class RedisSearchHelper:
	"""
	该类封装了对 Redis 和 Redisearch 的基本操作,提供了创建索引、添加文档、批量添加文档、更新文档、删除文档、搜索文档、获取所有文档以及删除索引等功能。
	"""
    def __init__(self, index_name, host='localhost', port=6379, password=None):
        """
        初始化 Redis 客户端和 Redisearch 客户端。参数包括索引名称、Redis 服务器地址、端口和密码。

        :param index_name: 索引名称
        :param host: Redis 服务器地址,默认为 localhost
        :param port: Redis 服务器端口,默认为 6379
        :param password: Redis 密码(如果有的话)
        """
        self.index_name = index_name
        self.redis_client = redis.Redis(host=host, port=port, password=password, decode_responses=True)
        self.search_client = Client(index_name, conn=self.redis_client)

    def create_index(self, schema, definition=None):
        """
        创建一个新的 Redisearch 索引。参数包括索引模式和可选的索引定义。

        :param schema: 索引模式,包含字段及其类型
        :param definition: 可选的索引定义
        """
        try:
            if definition is None:
                self.search_client.create_index(schema)
            else:
                self.search_client.create_index(schema, definition=definition)
        except Exception as e:
            print(f"Failed to create index: {e}")

    def add_document(self, document_id, **fields):
        """
        添加文档到索引。向索引中添加单个文档。参数包括文档 ID 及其字段和值。

        :param document_id: 文档 ID
        :param fields: 字段名和值
        """
        try:
            self.search_client.add_document(document_id, **fields)
        except Exception as e:
            print(f"Failed to add document: {e}")

    def batch_add_documents(self, documents):
        """
        批量添加文档到索引。参数是一个包含多个文档的列表,每个文档是一个字典。

        :param documents: 文档列表,每个文档是一个字典,包含文档ID和字段名值对
        """
        try:
            pipeline = self.search_client.pipeline()
            for doc in documents:
                document_id = doc.pop('id')
                pipeline.add_document(document_id, **doc)
            pipeline.execute()
        except Exception as e:
            print(f"Failed to batch add documents: {e}")

    def update_document(self, document_id, **fields):
        """
        更新已存在的文档。参数包括文档 ID 和要更新的字段和值。

        :param document_id: 文档 ID
        :param fields: 要更新的字段名和值
        """
        try:
            self.search_client.add_document(document_id, replace=True, partial=True, **fields)
        except Exception as e:
            print(f"Failed to update document: {e}")

    def delete_document(self, document_id):
        """
        从索引中删除指定的文档。参数是文档 ID。

        :param document_id: 文档 ID
        """
        try:
            self.search_client.delete_document(document_id)
        except Exception as e:
            print(f"Failed to delete document: {e}")

    def search(self, query_text, with_scores=False, with_payloads=False, with_sorting=False, sort_by=None):
        """
        在索引中执行搜索查询。参数包括查询文本、是否返回得分、是否返回负载、是否启用排序和排序字段。

        :param query_text: 查询文本
        :param with_scores: 是否返回得分
        :param with_payloads: 是否返回额外的负载
        :param with_sorting: 是否启用排序
        :param sort_by: 排序字段
        :return: 查询结果
        """
        q = Query(query_text)
        if with_scores:
            q = q.with_scores()
        if with_payloads:
            q = q.with_payloads()
        if with_sorting and sort_by:
            q = q.sort_by(sort_by)

        try:
            results = self.search_client.search(q)
            return results
        except Exception as e:
            print(f"Failed to execute search: {e}")
            return None

    def delete_index(self):
        """
        获取索引中的所有文档。
        """
        try:
            self.search_client.drop_index(delete_documents=True)
        except Exception as e:
            print(f"Failed to delete index: {e}")

    def get_all_documents(self):
        """
        获取索引中的所有文档。

        :return: 所有文档
        """
        try:
            q = Query('*')
            results = self.search_client.search(q)
            return results
        except Exception as e:
            print(f"Failed to get all documents: {e}")
            return None

if __name__ == "__main__":
    # 创建 RedisSearchHelper 实例
    helper = RedisSearchHelper('exampleIndex')

    # 定义索引模式
    schema = (
        TextField('title', weight=5.0),
        TextField('body'),
        NumericField('year')
    )
    definition = IndexDefinition(prefix=['doc:'], score=0.5, index_type=IndexType.HASH)

    # 创建索引
    helper.create_index(schema, definition=definition)

    # 批量添加文档
    documents = [
        {'id': 'doc:1', 'title': 'Python Programming', 'body': 'Python is a high-level programming language.', 'year': 2023},
        {'id': 'doc:2', 'title': 'Java Programming', 'body': 'Java is another popular programming language.', 'year': 2022}
    ]
    helper.batch_add_documents(documents)

    # 更新文档
    helper.update_document('doc:1', body='Python is a versatile language.')

    # 删除文档
    helper.delete_document('doc:2')

    # 搜索文档
    results = helper.search('programming')
    print(results.docs)

    # 获取所有文档
    all_docs = helper.get_all_documents()
    print(all_docs.docs)

    # 删除索引
    helper.delete_index()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173

关键点总结

  • 初始化和配置:提供了灵活的配置选项来连接 Redis 服务器,并且支持设置索引的定义。
  • 文档操作:提供了添加、批量添加、更新和删除文档的操作,确保对索引数据的全面管理。
  • 查询操作:支持复杂的搜索查询,包括得分、负载和排序等功能,增强了搜索的灵活性。
  • 索引管理:提供了删除索引的功能,确保可以清理和重建索引。

这个类提供了一个高效、灵活的接口来使用 Redisearch 进行全文搜索和索引管理,适用于多种使用场景。


实现 BM25语义相似度检索打分独立进行,并在异步执行后将结果进行权重融合,最后用 bge-large reranker 重新排序

  1. 引入必要的库和模型:引入 Sentence BERT 模型用于计算语义相似度。
  2. 实现嵌入计算:计算查询和文档的 Sentence BERT 嵌入。
  3. 实现 BM25 排序:使用 Redisearch 提供的 BM25 排序。
  4. 融合分数:将语义相似度分数和 BM25 分数按权重融合。
  5. 重新排序:使用 bge-m3 模型对融合后的结果进行重新排序。

自定义 BM25 类

import math
from collections import Counter
from typing import List

class BM25:
    def __init__(self, documents: List[str]):
        """
        初始化 BM25 模型。

        :param documents: 文档列表
        """
        self.documents = [doc.split() for doc in documents]
        self.doc_count = len(self.documents)
        self.avgdl = sum(len(doc) for doc in self.documents) / self.doc_count
        self.k1 = 1.5
        self.b = 0.75
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []

        self.initialize()

    def initialize(self):
        """
        初始化 BM25 模型所需的数据结构。
        """
        df = Counter()
        for doc in self.documents:
            self.doc_len.append(len(doc))
            frequencies = Counter(doc)
            self.doc_freqs.append(frequencies)
            for word in frequencies.keys():
                df[word] += 1

        for word, freq in df.items():
            self.idf[word] = math.log(1 + (self.doc_count - freq + 0.5) / (freq + 0.5))

    def get_scores(self, query: List[str]) -> List[float]:
        """
        计算查询的 BM25 分数。

        :param query: 查询词列表
        :return: 分数列表
        """
        scores = [0.0] * self.doc_count
        for word in query:
            if word not in self.idf:
                continue
            idf = self.idf[word]
            for i in range(self.doc_count):
                if word in self.doc_freqs[i]:
                    freq = self.doc_freqs[i][word]
                    score = idf * (freq * (self.k1 + 1) / (freq + self.k1 * (1 - self.b + self.b * self.doc_len[i] / self.avgdl)))
                    scores[i] += score
        return scores



  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
import redis
from redisearch import Client, TextField, Query
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import asyncio
from typing import List, Dict, Any
from BM25 import BM25  # 需要从自定义文件中导入 BM25 类

class RedisSearchHelper:
    def __init__(self, index_name: str, host='localhost', port=6379, password=None, 
                 model_name='sentence-transformers/all-MiniLM-L6-v2', 
                 reranker_model_name='moka-ai/bge-large-reranker', 
                 cache_dir='./cache', use_custom_bm25=False):
        """
        初始化 RedisSearch 客户端和模型。

        :param index_name: 索引名称
        :param host: Redis 服务器地址,默认为 localhost
        :param port: Redis 服务器端口,默认为 6379
        :param password: Redis 密码(如果有的话)
        :param model_name: 用于计算语义相似度的模型名称,默认为 'sentence-transformers/all-MiniLM-L6-v2'
        :param reranker_model_name: 用于重排序的模型名称,默认为 'moka-ai/bge-large-reranker'
        :param cache_dir: 模型缓存目录,默认为 './cache'
        :param use_custom_bm25: 是否使用自定义 BM25 实现,默认为 False
        """
        self.index_name = index_name
        self.redis_client = redis.Redis(host=host, port=port, password=password, decode_responses=True)
        self.search_client = Client(index_name, conn=self.redis_client)
        self.model = SentenceTransformer(model_name, cache_folder=cache_dir)
        self.reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_model_name, cache_dir=cache_dir)
        self.reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_model_name, cache_dir=cache_dir)
        self.use_custom_bm25 = use_custom_bm25
        self.bm25 = None
        if use_custom_bm25:
            self.bm25 = self.initialize_custom_bm25()

    def initialize_custom_bm25(self) -> BM25:
        """
        初始化自定义 BM25 模型。
        """
        doc_ids = self.redis_client.keys("doc:*")
        documents = [self.redis_client.hget(doc_id, "body") for doc_id in doc_ids]
        return BM25(documents)

    async def bm25_search(self, query_text: str, top_k: int) -> List[Dict[str, Any]]:
        """
        使用 BM25 进行搜索。

        :param query_text: 查询文本
        :param top_k: 返回的文档数量
        :return: BM25 搜索结果
        """
        if self.use_custom_bm25:
            tokenized_query = query_text.split(" ")
            doc_scores = self.bm25.get_scores(tokenized_query)
            top_n_indices = np.argsort(doc_scores)[::-1][:top_k]
            doc_ids = [self.redis_client.keys("doc:*")[i] for i in top_n_indices]
            bm25_results = [{'id': doc_id, 'score': doc_scores[i]} for i, doc_id in enumerate(doc_ids)]
            return bm25_results
        else:
            q = Query(query_text).paging(0, top_k)
            try:
                bm25_results = self.search_client.search(q)
                return [{'id': doc.id, 'score': doc.score} for doc in bm25_results.docs]
            except Exception as e:
                print(f"Failed to execute BM25 search: {e}")
                return []

    async def semantic_search(self, query_text: str, top_k: int) -> List[Dict[str, Any]]:
        """
        使用语义相似度进行搜索。

        :param query_text: 查询文本
        :param top_k: 返回的文档数量
        :return: 语义相似度搜索结果
        """
        try:
            query_embedding = self.model.encode(query_text)
            doc_ids = self.redis_client.keys("doc:*")
            doc_texts = [self.redis_client.hget(doc_id, "body") for doc_id in doc_ids]
            doc_embeddings = self.model.encode(doc_texts)
            semantic_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings).numpy()[0]
            top_n_indices = np.argsort(semantic_scores)[::-1][:top_k]
            semantic_results = [{'id': doc_ids[i], 'body': doc_texts[i], 'score': semantic_scores[i]} for i in top_n_indices]
            return semantic_results
        except Exception as e:
            print(f"Failed to execute semantic search: {e}")
            return []

    def rerank(self, query_text: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        使用重排序模型对文档进行重排序。

        :param query_text: 查询文本
        :param docs: 待重排序的文档
        :return: 重排序后的文档
        """
        try:
            inputs = self.reranker_tokenizer(
                [query_text] * len(docs), 
                [doc['body'] for doc in docs], 
                return_tensors='pt', 
                padding=True, 
                truncation=True
            )
            outputs = self.reranker_model(**inputs)
            rerank_scores = outputs.logits.squeeze().detach().numpy()
            for i, doc in enumerate(docs):
                doc['rerank_score'] = rerank_scores[i]
            return sorted(docs, key=lambda x: x['rerank_score'], reverse=True)
        except Exception as e:
            print(f"Failed to rerank: {e}")
            return docs

    async def search(self, query_text: str, bm25_weight: float = 0.5, semantic_weight: float = 0.5, top_k: int = 10) -> List[Dict[str, Any]]:
        """
        综合使用 BM25 和语义相似度进行搜索,并融合结果进行重排序。

        :param query_text: 查询文本
        :param bm25_weight: BM25 分数的权重
        :param semantic_weight: 语义相似度分数的权重
        :param top_k: 返回的文档数量
        :return: 最终的搜索结果
        """
        bm25_task = self.bm25_search(query_text, top_k)
        semantic_task = self.semantic_search(query_text, top_k)

        bm25_results, semantic_results = await asyncio.gather(bm25_task, semantic_task)

        if not bm25_results and not semantic_results:
            print("Failed to retrieve results from both BM25 and semantic search.")
            return []

        final_scores = []
        if bm25_results:
            for bm25_doc in bm25_results:
                doc_id = bm25_doc['id']
                bm25_score = bm25_doc['score']
                semantic_score = next((doc['score'] for doc in semantic_results if doc['id'] == doc_id), 0) if semantic_results else 0
                final_score = bm25_weight * bm25_score + semantic_weight * semantic_score
                final_scores.append({'id': doc_id, 'score': final_score})

        if semantic_results:
            for semantic_doc in semantic_results:
                doc_id = semantic_doc['id']
                if not any(doc['id'] == doc_id for doc in final_scores):
                    final_scores.append({
                        'id': doc_id, 
                        'score': semantic_weight * semantic_doc['score']
                    })

        final_scores = sorted(final_scores, key=lambda x: x['score'], reverse=True)
        top_results = final_scores[:top_k]
        return top_results

# 示例用法
async def main():
    # 假设已经在 Redis 中创建了文档
    redis_search_helper = RedisSearchHelper(index_name='my_index', use_custom_bm25=True)
    query_text = "example search query"
    
    search_results = await redis_search_helper.search(query_text, bm25_weight=0.5, semantic_weight=0.5, top_k=5)
    for result in search_results:
        print(result)

# 如果直接在脚本中运行
if __name__ == "__main__":
    asyncio.run(main())

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170

为了提高检索速度和准确率,我们可以具体实现以下优化策略。这些策略包括BM25的并行处理和稀疏表示,以及RedisSearch的配置优化、索引优化和模型融合。下面是针对这些优化的代码示例:

1,BM25 算法优化
1.1 并行处理

使用 concurrent.futures 库进行并行计算,以提高BM25评分的计算速度。此示例使用线程池来并行计算每个文档的BM25分数,适用于多核处理器

import math
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

class BM25:
    def __init__(self, documents: List[str], k1: float = 1.5, b: float = 0.75):
        """
        初始化 BM25 模型。

        :param documents: 文档列表,其中每个文档是一个字符串,文档会被分词
        :param k1: BM25 的 k1 参数,控制词频的影响
        :param b: BM25 的 b 参数,控制文档长度的标准化
        """
        self.documents = [doc.split() for doc in documents]
        self.doc_count = len(self.documents)
        self.avgdl = sum(len(doc) for doc in self.documents) / self.doc_count
        self.k1 = k1
        self.b = b
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []
        self.initialize()

    def initialize(self):
        """
        初始化 BM25 模型所需的数据结构,包括计算 IDF(逆文档频率)和文档词频。
        """
        df = Counter()
        for doc in self.documents:
            self.doc_len.append(len(doc))
            frequencies = Counter(doc)
            self.doc_freqs.append(frequencies)
            for word in frequencies.keys():
                df[word] += 1

        for word, freq in df.items():
            self.idf[word] = math.log(1 + (self.doc_count - freq + 0.5) / (freq + 0.5))

    def _calculate_score_for_doc(self, doc_idx: int, query: List[str]) -> float:
        """
        计算指定文档的 BM25 分数。

        :param doc_idx: 文档索引
        :param query: 查询词列表
        :return: 文档的 BM25 分数
        """
        score = 0.0
        doc = self.documents[doc_idx]
        freq = Counter(doc)
        for word in query:
            if word in self.idf:
                idf = self.idf[word]
                if word in freq:
                    word_freq = freq[word]
                    score += idf * (word_freq * (self.k1 + 1) / (word_freq + self.k1 * (1 - self.b + self.b * len(doc) / self.avgdl)))
        return score

    def get_scores(self, query: List[str]) -> List[float]:
        """
        计算所有文档对于查询的 BM25 分数。

        :param query: 查询词列表
        :return: 所有文档的 BM25 分数列表
        """
        scores = [0.0] * self.doc_count
        with ThreadPoolExecutor() as executor:
            future_to_doc_idx = {executor.submit(self._calculate_score_for_doc, i, query): i for i in range(self.doc_count)}
            for future in as_completed(future_to_doc_idx):
                doc_idx = future_to_doc_idx[future]
                try:
                    scores[doc_idx] = future.result()
                except Exception as exc:
                    print(f'Document {doc_idx} generated an exception: {exc}')
        return scores

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
解释:
  • __init__:初始化BM25模型,计算文档长度和平均文档长度,设置BM25的参数。
  • initialize:计算每个词的IDF值,并记录每个文档的词频。
  • _calculate_score_for_doc:为指定文档计算BM25分数。计算过程涉及查询中每个词的IDF值和文档中词频的加权。
  • get_scores:使用线程池并行计算所有文档的BM25分数,从而提高计算速度。
1.2 稀疏表示

使用稀疏矩阵表示文档和查询的词项,只计算实际出现的词项的BM25分数,以减少计算量和内存占用。

from scipy.sparse import csr_matrix
from collections import Counter
import math
from typing import List

class SparseBM25:
    def __init__(self, documents: List[str], k1: float = 1.5, b: float = 0.75):
        """
        初始化 BM25 模型,使用稀疏矩阵表示文档。

        :param documents: 文档列表,其中每个文档是一个字符串,文档会被分词
        :param k1: BM25 的 k1 参数,控制词频的影响
        :param b: BM25 的 b 参数,控制文档长度的标准化
        """
        self.documents = [doc.split() for doc in documents]
        self.doc_count = len(self.documents)
        self.avgdl = sum(len(doc) for doc in self.documents) / self.doc_count
        self.k1 = k1
        self.b = b
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []
        self.initialize()

    def initialize(self):
        """
        初始化 BM25 模型所需的数据结构,包括计算 IDF(逆文档频率)和文档词频。
        """
        df = Counter()
        for doc in self.documents:
            self.doc_len.append(len(doc))
            frequencies = Counter(doc)
            self.doc_freqs.append(frequencies)
            for word in frequencies.keys():
                df[word] += 1

        for word, freq in df.items():
            self.idf[word] = math.log(1 + (self.doc_count - freq + 0.5) / (freq + 0.5))

    def _build_sparse_matrix(self):
        """
        构建文档的稀疏矩阵表示。
        """
        rows, cols, data = [], [], []
        word_to_idx = {}
        word_idx = 0
        for doc_idx, doc in enumerate(self.documents):
            freq = Counter(doc)
            for word in freq:
                if word not in word_to_idx:
                    word_to_idx[word] = word_idx
                    word_idx += 1
                rows.append(doc_idx)
                cols.append(word_to_idx[word])
                data.append(freq[word])
        self.sparse_matrix = csr_matrix((data, (rows, cols)), shape=(self.doc_count, len(word_to_idx)))
        self.word_to_idx = word_to_idx

    def get_scores(self, query: List[str]) -> List[float]:
        """
        计算所有文档对于查询的 BM25 分数,使用稀疏矩阵。

        :param query: 查询词列表
        :return: 所有文档的 BM25 分数列表
        """
        if not hasattr(self, 'sparse_matrix'):
            self._build_sparse_matrix()
        scores = [0.0] * self.doc_count
        for word in query:
            if word in self.idf:
                idf = self.idf[word]
                if word in self.word_to_idx:
                    col_idx = self.word_to_idx[word]
                    word_freqs = self.sparse_matrix[:, col_idx].toarray().flatten()
                    for doc_idx, freq in enumerate(word_freqs):
                        doc_len = self.doc_len[doc_idx]
                        score = idf * (freq * (self.k1 + 1) / (freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)))
                        scores[doc_idx] += score
        return scores


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
解释:
  • __init__:初始化稀疏BM25模型,计算文档长度和平均文档长度,设置BM25的参数。
  • initialize:计算每个词的IDF值,并记录每个文档的词频。
  • _build_sparse_matrix:构建稀疏矩阵表示文档的词项,使用scipy.sparse.csr_matrix来节省空间。
  • get_scores:利用稀疏矩阵计算BM25分数,避免了对所有词项的计算,只处理实际出现的词项。
总结
  • 并行处理:通过线程池并行计算BM25分数,提高计算速度。
  • 稀疏表示:使用稀疏矩阵减少内存占用和计算量,仅处理实际出现的词项,提高处理效率。

RediSearch/redisearch-getting-started
RediSearch - Redis Powered Search Engine
微服务-RedisSearch 使用详解
RediSearch/RediSearch

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/1016318
推荐阅读
相关标签
  

闽ICP备14008679号