当前位置:   article > 正文



原文:前沿重器[46] RAG开源项目Qanything源码阅读2-离线文件处理


  • 文件上传
  • 文件读取和切片
  • 索引构造




| 参数名 | 示例值 | 参数类型 | 描述 |
| --- | --- | --- | --- |
| user_id | zzp | String | 用户 id |
| kb_id | KBb1dd58e8485443ce81166d24f6febda7 | String | 知识库 id |
| mode | soft | String | 上传模式,soft:知识库内存在同名文件时当前文件不再上传,strong:文件名重复的文件强制上传,默认值为 soft |


1.1 客户端


1.1.1 上传文件同步请求示例
import os
from contextlib import ExitStack

import requests

# Upload every matching file under a local folder to the QAnything
# upload_files API in a single multipart request.
url = "http://{your_host}:8777/api/local_doc_qa/upload_files"
folder_path = "./docx_data"  # folder holding the files -- note: a folder, not a file!
data = {
    "user_id": "zzp",
    "kb_id": "KB6dae785cdd5d47a997e890521acbe1c9",
    "mode": "soft",
}

# ExitStack closes every opened file once the request has been sent,
# instead of leaking one handle per uploaded file.
with ExitStack() as stack:
    files = []
    for root, dirs, file_names in os.walk(folder_path):
        for file_name in file_names:
            # Only .md files are uploaded here; change as needed. Supported types:
            # md, txt, pdf, jpg, png, jpeg, docx, xlsx, pptx, eml, csv
            if file_name.endswith(".md"):
                file_path = os.path.join(root, file_name)
                # Files must be opened in binary mode ("rb") for a multipart upload.
                files.append(("files", stack.enter_context(open(file_path, "rb"))))

    response = requests.post(url, files=files, data=data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 发请求用的是通用的requests包。
  • 因为是本地测试,所以直接使用本地文件,用open打开即可;files字段里存的是open返回的文件对象,注意打开方式是rb(二进制读)。


1.1.2 上传文件异步请求示例
import argparse
import os
import sys
import json
import aiohttp
import asyncio
import time
import random
import string

# Collect the paths of every .docx file under ./docx_data (only .docx is uploaded).
files = []
for root, dirs, file_names in os.walk("./docx_data"):  # source folder
    for file_name in file_names:
        if file_name.endswith(".docx"):  # upload .docx files only
            file_path = os.path.join(root, file_name)
            files.append(file_path)
# Wall-clock duration (seconds) of every completed request.
response_times = []


async def send_request(round_, files):
    """Upload the given file paths to the upload_files API in one multipart request.

    Prints the total payload size (MB), then the response status and latency.
    The latency is also appended to ``response_times``. Request failures are
    printed, not raised. Every opened file handle is closed afterwards.
    """
    url = 'http://{your_host}:8777/api/local_doc_qa/upload_files'
    data = aiohttp.FormData()
    data.add_field('user_id', 'zzp')
    data.add_field('kb_id', 'KBf1dafefdb08742f89530acb7e9ed66dd')
    data.add_field('mode', 'soft')

    total_size = 0
    opened_files = []
    try:
        for file_path in files:
            file_size = os.path.getsize(file_path)
            total_size += file_size
            fh = open(file_path, 'rb')
            opened_files.append(fh)
            data.add_field('files', fh)
        print('size:', total_size / (1024 * 1024))
        try:
            start_time = time.time()
            async with aiohttp.ClientSession() as session:
                async with session.post(url, data=data) as response:
                    end_time = time.time()
                    response_times.append(end_time - start_time)
                    print(f"round_:{round_}, 响应状态码: {response.status}, 响应时间: {end_time - start_time}秒")
        except Exception as e:
            print(f"请求发送失败: {e}")
    finally:
        for fh in opened_files:
            fh.close()


async def main():
    """Send one upload request with the first ``sys.argv[1]`` files and report throughput."""
    start_time = time.time()
    # Number of files per request. A single HTTP request body must stay under
    # 100 MB, so pick this number accordingly.
    num = int(sys.argv[1])
    round_ = 0
    r_files = files[:num]
    tasks = []
    task = asyncio.create_task(send_request(round_, r_files))
    # Without this append, gather() awaits nothing and the request task is
    # cancelled when asyncio.run() shuts the loop down.
    tasks.append(task)
    await asyncio.gather(*tasks)

    end_time = time.time()
    total_requests = len(response_times)
    total_time = end_time - start_time
    qps = total_requests / total_time if total_time > 0 else 0.0
    print(f"total_requests: {total_requests}, total_time: {total_time}s, qps: {qps}")


if __name__ == '__main__':
    asyncio.run(main())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62


1.3 上传文件响应示例

{
  "code": 200, //状态码
  "msg": "success,后台正在飞速上传文件,请耐心等待", //提示信息
  "data": [
    {
      "file_id": "1b6c0781fb9245b2973504cb031cc2f3", //文件id
      "file_name": "网易有道智云平台产品介绍2023.6.ppt", //文件名
      "status": "gray", //文件状态(red:入库失败-切分失败,green:成功入库,yellow:入库失败-milvus失败,gray:正在入库)
      "bytes": 17925, //文件大小(字节数)
      "timestamp": "202401251056" // 上传时间
    },
    {
      "file_id": "aeaec708c7a34952b7de484fb3374f5d",
      "file_name": "有道知识库问答产品介绍.pptx",
      "status": "gray",
      "bytes": 12928, //文件大小(字节数)
      "timestamp": "202401251056" // 上传时间
    }
  ] //文件列表
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20




# Register each uploaded file, wrap it in a LocalFile and report it as "gray"
# (being ingested); the heavy split/embed/insert work runs in a background task.
for file, file_name in zip(files, file_names):
    if file_name in exist_file_names:
        # soft mode: a file with this name already exists in the knowledge base -> skip it
        continue
    # Record the file in the MySQL-side summary table and get its file_id.
    file_id, msg = local_doc_qa.milvus_summary.add_file(user_id, kb_id, file_name, timestamp)
    debug_logger.info(f"{file_name}, {file_id}, {msg}")
    local_file = LocalFile(user_id, kb_id, file, file_id, file_name, local_doc_qa.embeddings)
    local_files.append(local_file)
    local_doc_qa.milvus_summary.update_file_size(file_id, len(local_file.file_content))
    data.append(
        {"file_id": file_id, "file_name": file_name, "status": "gray", "bytes": len(local_file.file_content),
         "timestamp": timestamp})
# Fire-and-forget: the HTTP response is returned before ingestion finishes.
# NOTE(review): asyncio only keeps weak references to tasks; consider storing the
# returned task somewhere so it cannot be garbage-collected mid-run.
asyncio.create_task(local_doc_qa.insert_files_to_milvus(user_id, kb_id, local_files))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12


  • local_doc_qa.milvus_summary.add_file:向指定知识库下面增加文件,这是一个mysql操作,要在mysql数据库内记录在案。

  • local_doc_qa.insert_files_to_milvus:将文档加入到milvus中,当然这里也包含了文件切片、推理向量、存入数据库等一系列操作。


    if exist_file_names:
        msg = f'warning,当前的mode是soft,无法上传同名文件{exist_file_names},如果想强制上传同名文件,请设置mode:strong'
        msg = "success,后台正在飞速上传文件,请耐心等待"
    return sanic_json({"code": 200, "msg": msg, "data": data})
  • 1
  • 2
  • 3
  • 4
  • 5


继续往里面看 qanything_kernel\core\local_doc_qa.py

async def insert_files_to_milvus(self, user_id, kb_id, local_files: List[LocalFile]):
    """Split, embed and insert each file into the user's Milvus knowledge base.

    The per-file status in the summary store is updated as work progresses:
    'red' when splitting or embedding fails, 'green' when the Milvus insert
    succeeds, and 'yellow' when the insert fails. A failure on one file does
    not stop the others.

    Args:
        user_id: owner of the knowledge base.
        kb_id: id of the target knowledge base.
        local_files: LocalFile objects created by the upload handler.

    Raises:
        ValueError: if no Milvus client matches (user_id, kb_id).
    """
    debug_logger.info(f'insert_files_to_milvus: {kb_id}')
    milvus_kv = self.match_milvus_kb(user_id, [kb_id])
    if milvus_kv is None:  # explicit check: `assert` is stripped under `python -O`
        raise ValueError(f'no milvus knowledge base matched user_id={user_id}, kb_id={kb_id}')
    success_list = []
    failed_list = []
    for local_file in local_files:
        # 1) Parse and split the file into chunks.
        start = time.time()
        try:
            # NOTE(review): the OCR callable name follows upstream QAnything; confirm
            # it exists on this class (split_file_to_docs needs an ocr_engine callable).
            local_file.split_file_to_docs(self.get_ocr_result)
            content_length = sum(len(doc.page_content) for doc in local_file.docs)
        except Exception:
            error_info = f'split error: {traceback.format_exc()}'
            debug_logger.error(error_info)
            self.milvus_summary.update_file_status(local_file.file_id, status='red')
            failed_list.append(local_file)
            continue
        end = time.time()
        self.milvus_summary.update_content_length(local_file.file_id, content_length)
        debug_logger.info(f'split time: {end - start} {len(local_file.docs)}')

        # 2) Embed every chunk.
        start = time.time()
        try:
            local_file.create_embedding()
        except Exception:
            error_info = f'embedding error: {traceback.format_exc()}'
            debug_logger.error(error_info)
            self.milvus_summary.update_file_status(local_file.file_id, status='red')
            failed_list.append(local_file)
            continue
        end = time.time()
        debug_logger.info(f'embedding time: {end - start} {len(local_file.embs)}')

        # 3) Insert chunks + vectors into Milvus.
        self.milvus_summary.update_chunk_size(local_file.file_id, len(local_file.docs))
        ret = await milvus_kv.insert_files(local_file.file_id, local_file.file_name, local_file.file_path,
                                           local_file.docs, local_file.embs)
        insert_time = time.time()
        debug_logger.info(f'insert time: {insert_time - end}')
        if ret:
            self.milvus_summary.update_file_status(local_file.file_id, status='green')
            success_list.append(local_file)
        else:
            self.milvus_summary.update_file_status(local_file.file_id, status='yellow')
            failed_list.append(local_file)
    debug_logger.info(
        f"insert_to_milvus: success num: {len(success_list)}, failed num: {len(failed_list)}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46


  • local_file.split_file_to_docs:文件的切片,这里还涉及不同类型的文件处理,例如md、图片等。

  • local_file.create_embedding:看名字就知道了,向量化。

  • milvus_kv.insert_files:存入milvus。






class BaseLoader(ABC):
    def load_and_split(
        self, text_splitter: Optional[TextSplitter] = None
    ) -> List[Document]:
        """Load Documents and split into chunks. Chunks are returned as Documents.

        Do not override this method. It should be considered to be deprecated!

        Args:
            text_splitter: TextSplitter instance to use for splitting documents.
              Defaults to RecursiveCharacterTextSplitter.

        Returns:
            List of Documents.

        Raises:
            ImportError: if no splitter is given and langchain_text_splitters
              is not installed.
        """
        if text_splitter is None:
            # Imported lazily so the dependency is only required when no
            # splitter is supplied by the caller.
            try:
                from langchain_text_splitters import RecursiveCharacterTextSplitter
            except ImportError as e:
                raise ImportError(
                    "Unable to import from langchain_text_splitters. Please specify "
                    "text_splitter or install langchain_text_splitters with "
                    "`pip install -U langchain-text-splitters`."
                ) from e
            _text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
        else:
            _text_splitter = text_splitter
        docs = self.load()
        return _text_splitter.split_documents(docs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28


# URL branch of LocalFile.split_file_to_docs: load the page(s) reachable from
# self.url with MyRecursiveUrlLoader, then split them with a ChineseTextSplitter
# configured for non-PDF text.
loader = MyRecursiveUrlLoader(url=self.url)
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(text_splitter=textsplitter)
  • 1
  • 2
  • 3




  • MyRecursiveUrlLoader:URL加载器,即加载网络链接下的内容。内部直接用了langchain的WebBaseLoader;网页解析使用的是BeautifulSoup,算是爬虫技术里的老朋友了。这里BeautifulSoup主要用于从网页中解析出其中包含的子链接url,方便进一步递归抓取。

  • UnstructuredFileLoader,直接从langchain里面加载的,from langchain.document_loaders import UnstructuredFileLoader。这个也就只用在了markdown里面(.md)。

  • TextLoader,也是直接从langchain里面加载的from langchain.document_loaders import UnstructuredFileLoader, TextLoader 。这个也就只用在了txt里面(.txt)。

  • UnstructuredPaddlePDFLoader:专门用于pdf文件,是作者自己写的类,继承自前面提到的UnstructuredFileLoader,但不局限于此,主要重写的是_get_elements函数。内部写了一个函数pdf_ocr_txt:首先用fitz读取pdf,把每页渲染成图片;然后用ocr_engine来解析(请求ocr接口,本项目里用的是一个triton部署的paddleocr服务);最后用unstructured下的函数partition_text来完成切片(pip install unstructured)。当然后续还会有针对中文的综合切片,后面会说。

  • UnstructuredPaddleImageLoader:用来解析图片的工具,对应jpg、png、jpeg后缀的文件。它同样继承自UnstructuredFileLoader,和PDF不同的是加载部分:图片加载使用的是cv2。加载后的处理和PDF一样,都是走一遍ocr_engine和partition_text。

  • UnstructuredWordDocumentLoader用于处理docx文件,来自langchain。

  • xlsx使用的是pandas,值得注意的是engine使用的是openpyxl,另外文件读取后,作者会把内容转为csv,然后用CSVLoader来处理。

  • CSVLoader顾名思义处理的是csv文件,这里用的是csv.DictReader来读取的。

  • UnstructuredPowerPointLoader用于读取PPT,从langchain里面加载的,from langchain.document_loaders import UnstructuredPowerPointLoader

  • UnstructuredEmailLoader用于读取邮件格式的文件.eml,也是从langchain中加载的,from langchain.document_loaders import UnstructuredEmailLoader


2.2.1 qanything_kernel\utils\loader\csv_loader.py
import csv
from io import TextIOWrapper
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.document_loaders.helpers import detect_file_encodings

class CSVLoader(BaseLoader):
    """Load a `CSV` file into a list of Documents.

    Each document represents one row of the CSV file. The row's cells are
    rendered as ``key: value`` pairs joined by `` & `` and wrapped between two
    dashed separator lines in the document's page_content. An empty cell is
    replaced by the last non-empty value seen earlier in the same column
    (forward fill).

    The source for each document loaded from csv is set to the value of the
    `file_path` argument for all documents by default.
    You can override this by setting the `source_column` argument to the
    name of a column in the CSV file.
    The source of each document will then be set to the value of the column
    with the name specified in `source_column`.

    Output Example:
        .. code-block:: txt

            ------------------------
            column1: value1 & column2: value2 & column3: value3
            ------------------------
    """

    def __init__(
            self,
            file_path: str,
            source_column: Optional[str] = None,
            metadata_columns: Sequence[str] = (),
            csv_args: Optional[Dict] = None,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = False,
    ):
        """
        Args:
            file_path: The path to the CSV file.
            source_column: The name of the column in the CSV file to use as the source.
              Optional. Defaults to None.
            metadata_columns: A sequence of column names to use as metadata. Optional.
            csv_args: A dictionary of arguments to pass to the csv.DictReader.
              Optional. Defaults to None.
            encoding: The encoding of the CSV file. Optional. Defaults to None.
            autodetect_encoding: Whether to try to autodetect the file encoding.
        """
        self.file_path = file_path
        self.source_column = source_column
        self.metadata_columns = metadata_columns
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        """Load data into document objects.

        Raises:
            RuntimeError: if the file cannot be read/parsed, including when
              encoding autodetection finds no encoding that decodes the file.
        """
        docs = []
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self.__read_file(csvfile)
        except UnicodeDecodeError as e:
            if not self.autodetect_encoding:
                raise RuntimeError(f"Error loading {self.file_path}") from e
            detected_encodings = detect_file_encodings(self.file_path)
            for encoding in detected_encodings:
                try:
                    with open(
                            self.file_path, newline="", encoding=encoding.encoding
                    ) as csvfile:
                        docs = self.__read_file(csvfile)
                        break
                except UnicodeDecodeError:
                    continue
            else:
                # No candidate encoding worked: fail loudly instead of silently
                # returning an empty document list.
                raise RuntimeError(f"Error loading {self.file_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {self.file_path}") from e

        return docs

    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        # Last non-empty value seen for each column, used to fill empty cells.
        last_non_empty_values = {}
        for i, row in enumerate(csv_reader):
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else self.file_path
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )

            line_contents = []
            for k, v in row.items():
                if k in self.metadata_columns:
                    continue
                # DictReader stores surplus fields of a long row under a None key
                # (restkey) as a list; neither None nor a list has .strip().
                key = k.strip() if isinstance(k, str) else k
                if isinstance(v, list):
                    v = ",".join(item.strip() for item in v)
                value = v.strip() if v else last_non_empty_values.get(k, v)
                line_contents.append(f"{key}: {value}")
                if v:
                    last_non_empty_values[k] = v
            content = '------------------------\n'
            content += ' & '.join(line_contents)
            content += '\n------------------------'

            metadata = {"source": source, "row": i}
            for col in self.metadata_columns:
                try:
                    metadata[col] = row[col]
                except KeyError:
                    raise ValueError(f"Metadata column '{col}' not found in CSV file.")
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
2.2.2 qanything_kernel\utils\loader\image_loader.py
"""Loader that loads image files."""
from typing import List, Callable

from langchain.document_loaders.unstructured import UnstructuredFileLoader
import os
from typing import Union, Any
import cv2
import base64

class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
    """Loader that OCRs an image file (e.g. PNG/JPG) and partitions the text with unstructured."""

    def __init__(
            self,
            file_path: Union[str, List[str]],
            ocr_engine: Callable,
            mode: str = "single",
            **unstructured_kwargs: Any,
    ):
        """Initialize with file path.

        Args:
            file_path: path of the image to load.
            ocr_engine: callable that receives a dict with keys ``img64`` (base64 of the
                raw decoded pixel buffer), ``height``, ``width`` and ``channels``.
                NOTE(review): its result is assumed to be PaddleOCR-style -- a list of
                lines whose entries look like ``[box, (text, score)]`` -- confirm.
            mode: passed through to UnstructuredFileLoader.
            **unstructured_kwargs: forwarded to ``partition_text``.
        """
        self.ocr_engine = ocr_engine
        super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)

    def _get_elements(self) -> List:
        def image_ocr_txt(filepath, dir_path="tmp_files"):
            """OCR ``filepath`` and write the recognized lines to ``<dir>/tmp_files/<name>.txt``."""
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            # exist_ok avoids a race between the existence check and the mkdir.
            os.makedirs(full_dir_path, exist_ok=True)
            filename = os.path.split(filepath)[-1]
            img_np = cv2.imread(filepath)
            if img_np is None:
                # cv2.imread signals unreadable/unsupported files by returning None.
                raise ValueError(f"Unable to read image file: {filepath}")
            h, w, c = img_np.shape
            img_data = {"img64": base64.b64encode(img_np).decode("utf-8"), "height": h, "width": w, "channels": c}
            result = self.ocr_engine(img_data)
            result = [line for line in result if line]

            ocr_result = [item[1][0] for line in result for item in line]
            txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
            with open(txt_file_path, 'w', encoding='utf-8') as fout:
                fout.write("\n".join(ocr_result))
            return txt_file_path

        txt_file_path = image_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
2.2.3 qanything_kernel\utils\loader\my_recursive_url_loader.py
from typing import Iterator, List, Optional, Set
from urllib.parse import urljoin, urldefrag

import requests

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader

class MyRecursiveUrlLoader(BaseLoader):
    """Loads a URL and, recursively, all child links that start with it."""

    def __init__(
        self,
        url: str,
        exclude_dirs: Optional[str] = None,
        max_depth: int = -1
    ) -> None:
        """Initialize with URL to crawl and any subdirectories to exclude.

        Args:
            url: The URL to crawl.
            exclude_dirs: URL prefixes to exclude. A list of prefixes; a single
                string is treated as one prefix.
            max_depth: Maximum recursion depth below ``url``; values <= 0 mean
                unlimited.
        """
        self.url = url
        if isinstance(exclude_dirs, str):
            # A bare string would otherwise be iterated character by character.
            exclude_dirs = [exclude_dirs]
        self.exclude_dirs = exclude_dirs
        self.max_depth = max_depth

    def _is_excluded(self, url: str) -> bool:
        """Return True if ``url`` starts with any configured excluded prefix."""
        return bool(self.exclude_dirs) and any(
            url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs
        )

    def get_child_links_recursive(
        self, url: str, depth: int, visited: Optional[Set[str]] = None
    ) -> Iterator[Document]:
        """Recursively load ``url`` and every child link that starts with ``url``.

        Each page is loaded at most once. Child links ending with "/" are
        treated as directories and crawled one level deeper.

        Args:
            url: The URL to crawl.
            depth: Current recursion depth (0 for the root URL).
            visited: A set of visited URLs, shared across recursive calls.

        Yields:
            Documents produced by ``WebBaseLoader`` for each crawled page.
        """
        from langchain.document_loaders import WebBaseLoader

        try:
            from bs4 import BeautifulSoup
        except ImportError:
            raise ImportError(
                "The BeautifulSoup package is required for the RecursiveUrlLoader."
            )

        visited = set() if visited is None else visited

        # Stop once we are deeper than allowed (max_depth <= 0 means no limit).
        if 0 < self.max_depth < depth:
            return visited

        # Exclude the links that start with any of the excluded directories
        if self._is_excluded(url):
            return visited

        # The parent call already loaded (and marked) child directories, so only
        # load `url` here when it has not been seen yet (e.g. the crawl root).
        if url not in visited:
            visited.add(url)
            yield from WebBaseLoader(web_path=url).load()

        # Get all links that are relative to the root of the website
        response = requests.get(url, timeout=60)
        soup = BeautifulSoup(response.text, "html.parser")
        all_links = [urljoin(url, link.get("href")) for link in soup.find_all("a")]
        # Keep only children of the current url, without fragments, deduplicated.
        child_links = {urldefrag(link).url for link in all_links if link.startswith(url)}

        for link in sorted(child_links):
            if link in visited or self._is_excluded(link):
                continue
            visited.add(link)
            yield from WebBaseLoader(link).load()
            # If the link is a directory (w/ children) then visit it
            if link.endswith("/"):
                yield from self.get_child_links_recursive(link, depth + 1, visited)

        return visited

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load web pages."""
        return self.get_child_links_recursive(self.url, depth=0)

    def load(self) -> List[Document]:
        """Load web pages."""
        return list(self.lazy_load())

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
2.2.4 qanything_kernel\utils\loader\pdf_loader.py
"""Loader that loads image files."""
from typing import List, Callable

from langchain.document_loaders.unstructured import UnstructuredFileLoader
from unstructured.partition.text import partition_text
import os
import fitz
from tqdm import tqdm
from typing import Union, Any
import numpy as np
import base64

class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Loader that renders each PDF page to an image, OCRs it, and partitions the text with unstructured."""
    def __init__(
        self,
        file_path: Union[str, List[str]],
        ocr_engine: Callable,
        mode: str = "single",
        **unstructured_kwargs: Any,
    ):
        """Initialize with file path.

        Args:
            file_path: path of the PDF to load.
            ocr_engine: callable that receives a dict with keys ``img64`` (base64 of the
                raw page pixel buffer), ``height``, ``width`` and ``channels``.
                NOTE(review): its result is assumed to be PaddleOCR-style -- a list of
                lines whose entries look like ``[box, (text, score)]`` -- confirm.
            mode: passed through to UnstructuredFileLoader.
            **unstructured_kwargs: forwarded to ``partition_text``.
        """
        self.ocr_engine = ocr_engine
        super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs)

    def _get_elements(self) -> List:
        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
            """OCR every page of ``filepath`` into ``<dir>/tmp_files/<name>.txt`` and return that path."""
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            os.makedirs(full_dir_path, exist_ok=True)
            txt_file_path = os.path.join(full_dir_path, "{}.txt".format(os.path.split(filepath)[-1]))
            img_name = os.path.join(full_dir_path, 'tmp.png')
            doc = fitz.open(filepath)
            try:
                with open(txt_file_path, 'w', encoding='utf-8') as fout:
                    for page_no in tqdm(range(doc.page_count)):
                        page = doc.load_page(page_no)
                        pix = page.get_pixmap()
                        img = np.frombuffer(pix.samples, dtype=np.uint8).reshape((pix.h, pix.w, pix.n))

                        img_data = {"img64": base64.b64encode(img).decode("utf-8"), "height": pix.h, "width": pix.w,
                                    "channels": pix.n}
                        result = self.ocr_engine(img_data)
                        result = [line for line in result if line]
                        ocr_result = [item[1][0] for line in result for item in line]
                        fout.write("\n".join(ocr_result))
                        # Separate pages so the last line of one page is not glued
                        # to the first line of the next.
                        fout.write("\n")
            finally:
                doc.close()  # release the PDF file handle even if OCR fails
            if os.path.exists(img_name):
                os.remove(img_name)
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
2.2.5 qanything_kernel\core\local_file.py
from qanything_kernel.utils.general_utils import *
from typing import List, Union, Callable
from qanything_kernel.configs.model_config import UPLOAD_ROOT_PATH, SENTENCE_SIZE, ZH_TITLE_ENHANCE
from langchain.docstore.document import Document
from qanything_kernel.utils.loader.my_recursive_url_loader import MyRecursiveUrlLoader
from langchain.document_loaders import UnstructuredFileLoader, TextLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.document_loaders import UnstructuredExcelLoader
from langchain.document_loaders import UnstructuredEmailLoader
from langchain.document_loaders import UnstructuredPowerPointLoader
from qanything_kernel.utils.loader.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qanything_kernel.utils.custom_log import debug_logger, qa_logger
from qanything_kernel.utils.splitter import ChineseTextSplitter
from qanything_kernel.utils.loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
from qanything_kernel.utils.splitter import zh_title_enhance
from sanic.request import File
import pandas as pd
import os

# Module-level splitter used by LocalFile.split_file_to_docs for a second pass that
# re-splits chunks which are still too long.
# NOTE(review): the keyword arguments after `separators` were lost from this excerpt.
# chunk_size=800 follows the "longer than 800 tokens" comment in split_file_to_docs;
# with the default length_function (len) it counts characters, not tokens -- confirm
# against upstream, which may pass a token-counting length_function.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n", ".", "。", "!", "!", "?", "?", ";", ";", "……", "…", "、", ",", ",", " "],
    chunk_size=800,
)

class LocalFile:
    """A file (or URL) added to a knowledge base, plus its parsed chunks and embeddings."""

    def __init__(self, user_id, kb_id, file: Union[File, str], file_id, file_name, embedding, is_url=False, in_milvus=False):
        """
        Args:
            user_id: owner of the knowledge base.
            kb_id: knowledge base id.
            file: the URL when ``is_url`` is True; otherwise either a local file path (str)
                or an uploaded sanic ``File`` whose ``body`` holds the raw bytes.
            file_id: id assigned to this file.
            file_name: file name; uploaded files are saved under this name.
            embedding: embedding backend; must provide ``_get_len_safe_embeddings``.
            is_url: treat ``file`` as a URL to crawl instead of file content.
            in_milvus: stored as-is on the instance.
        """
        self.user_id = user_id
        self.kb_id = kb_id
        self.file_id = file_id
        self.docs: List[Document] = []
        self.embs = []
        self.emb_infer = embedding
        self.url = None
        self.in_milvus = in_milvus
        self.file_name = file_name
        if is_url:
            self.url = file
            self.file_path = "URL"
            self.file_content = b''
        elif isinstance(file, str):
            # Already on disk: read it, no need to rewrite the file onto itself.
            self.file_path = file
            with open(file, 'rb') as f:
                self.file_content = f.read()
        else:
            # Uploaded file: persist it as UPLOAD_ROOT_PATH/<user_id>/<file_id>/<file_name>.
            upload_path = os.path.join(UPLOAD_ROOT_PATH, user_id)
            file_dir = os.path.join(upload_path, self.file_id)
            os.makedirs(file_dir, exist_ok=True)
            self.file_path = os.path.join(file_dir, self.file_name)
            self.file_content = file.body
            with open(self.file_path, "wb") as f:
                f.write(self.file_content)
        debug_logger.info(f'success init localfile {self.file_name}')

    def split_file_to_docs(self, ocr_engine: Callable, sentence_size=SENTENCE_SIZE,
                           using_zh_title_enhance=ZH_TITLE_ENHANCE):
        """Parse the file (or crawl the URL) into chunked Documents and store them in ``self.docs``.

        The loader is chosen by file extension. The result then goes through an
        optional ``zh_title_enhance`` pass and a second split with the
        module-level ``text_splitter``, and every chunk is tagged with
        ``file_id`` / ``file_name`` metadata.

        Args:
            ocr_engine: callable used by the PDF and image loaders to OCR page images.
            sentence_size: passed to ChineseTextSplitter.
            using_zh_title_enhance: when truthy, apply ``zh_title_enhance`` to the docs.

        Raises:
            TypeError: if the file extension is not supported.
        """
        lower_path = self.file_path.lower()
        if self.url:
            debug_logger.info("load url: {}".format(self.url))
            loader = MyRecursiveUrlLoader(url=self.url)
            textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(text_splitter=textsplitter)
        elif lower_path.endswith(".md"):
            loader = UnstructuredFileLoader(self.file_path, mode="elements")
            docs = loader.load()
        elif lower_path.endswith(".txt"):
            loader = TextLoader(self.file_path, autodetect_encoding=True)
            texts_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(texts_splitter)
        elif lower_path.endswith(".pdf"):
            loader = UnstructuredPaddlePDFLoader(self.file_path, ocr_engine)
            texts_splitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
            docs = loader.load_and_split(texts_splitter)
        elif lower_path.endswith((".jpg", ".png", ".jpeg")):
            loader = UnstructuredPaddleImageLoader(self.file_path, ocr_engine, mode="elements")
            texts_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(text_splitter=texts_splitter)
        elif lower_path.endswith(".docx"):
            loader = UnstructuredWordDocumentLoader(self.file_path, mode="elements")
            texts_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
            docs = loader.load_and_split(texts_splitter)
        elif lower_path.endswith(".xlsx"):
            # Convert each sheet to a CSV file next to the workbook, then reuse CSVLoader.
            docs = []
            xlsx = pd.read_excel(self.file_path, engine='openpyxl', sheet_name=None)
            for sheet in xlsx.keys():
                df = xlsx[sheet]
                df.dropna(how='all', inplace=True)  # drop fully empty rows
                csv_file_path = self.file_path[:-5] + '_' + sheet + '.csv'
                df.to_csv(csv_file_path, index=False)
                loader = CSVLoader(csv_file_path, csv_args={"delimiter": ",", "quotechar": '"'})
                docs += loader.load()
        elif lower_path.endswith(".pptx"):
            loader = UnstructuredPowerPointLoader(self.file_path, mode="elements")
            docs = loader.load()
        elif lower_path.endswith(".eml"):
            loader = UnstructuredEmailLoader(self.file_path, mode="elements")
            docs = loader.load()
        elif lower_path.endswith(".csv"):
            loader = CSVLoader(self.file_path, csv_args={"delimiter": ",", "quotechar": '"'})
            docs = loader.load()
        else:
            raise TypeError("文件类型不支持,目前仅支持:[md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv]")
        if using_zh_title_enhance:
            debug_logger.info("using_zh_title_enhance %s", using_zh_title_enhance)
            docs = zh_title_enhance(docs)

        # Second pass: re-split any chunk that is still too long using the
        # module-level RecursiveCharacterTextSplitter.
        debug_logger.info(f"before 2nd split doc lens: {len(docs)}")
        docs = text_splitter.split_documents(docs)
        debug_logger.info(f"after 2nd split doc lens: {len(docs)}")

        # Tag every chunk with the file it came from.
        for doc in docs:
            doc.metadata["file_id"] = self.file_id
            doc.metadata["file_name"] = self.url if self.url else os.path.split(self.file_path)[-1]
        write_check_file(self.file_path, docs)
        if docs:
            debug_logger.info('langchain analysis content head: %s', docs[0].page_content[:100])
        else:
            debug_logger.info('langchain analysis docs is empty!')
        self.docs = docs

    def create_embedding(self):
        """Embed every chunk's page_content and store the vectors in ``self.embs``."""
        self.embs = self.emb_infer._get_len_safe_embeddings([doc.page_content for doc in self.docs])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130


文件切片作者也写成了通用的工具,方便调用。与按文件格式区分的各类加载器相比,切片器的通用性更高,毕竟各种文件最终都被解析成了文本。这个通用的切片器就是ChineseTextSplitter,它继承自langchain的CharacterTextSplitter(from langchain.text_splitter import CharacterTextSplitter),重写后更符合中文的使用习惯。直接来看源码吧:qanything_kernel\utils\splitter\chinese_text_splitter.py

from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from qanything_kernel.configs.model_config import SENTENCE_SIZE

class ChineseTextSplitter(CharacterTextSplitter):
    """Sentence-level splitter tuned for Chinese punctuation.

    ``split_text`` first breaks text after sentence-ending punctuation / ellipses.
    Any sentence still longer than ``sentence_size`` characters is split further
    after commas/periods, then after runs of 2+ spaces, and finally after single
    spaces.
    """

    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
        """
        Args:
            pdf: when True, the input is pre-cleaned (newline runs collapsed,
                whitespace turned into spaces) before splitting.
            sentence_size: maximum length (in characters) a sentence may have
                before the finer-grained fallback splits are applied.
            **kwargs: forwarded to ``CharacterTextSplitter``.
        """
        # Initialize the LangChain base class so the attributes its inherited
        # helpers (e.g. split_documents / create_documents) read are set.
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        """Alternative splitter: split on sentence terminators, keeping each terminator
        (plus up to two closing quotes) attached to the preceding sentence."""
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r'\s', ' ', text)
            # No-op: the previous line already turned every newline into a space.
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                # A separator piece is glued onto the sentence it terminates.
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:   # TODO: the splitting logic could be refined further
        """Split ``text`` into sentences of at most ``sentence_size`` characters where possible.

        Returns a list of non-empty strings in document order.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r'\s', " ", text)
            # No-op: the previous line already turned every newline into a space.
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        # If a terminator precedes a closing quote, the quote is the real end of the
        # sentence, so the break goes after the quote; the substitutions above
        # deliberately keep quotes attached.
        text = text.rstrip()  # drop trailing newlines at the end of the paragraph
        # NOTE: the original comment said semicolons, dashes and English double quotes
        # are ignored, but the patterns above do include ';'.
        sentences = [i for i in text.split("\n") if i]

        # Build the result in order instead of splicing via list.index(): index()
        # finds the FIRST equal string, so pieces could be inserted at the wrong
        # position when an identical string already occurs earlier, and it made
        # the loop quadratic.
        result = []
        for ele in sentences:
            if len(ele) <= self.sentence_size:
                result.append(ele)
                continue
            # Level 1: break after commas / periods (plus trailing closing quotes).
            level1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele).split("\n")
            level1_out = []
            for ele_ele1 in level1:
                if len(ele_ele1) <= self.sentence_size:
                    level1_out.append(ele_ele1)
                    continue
                # Level 2: break after runs of two or more spaces.
                level2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1).split("\n")
                level2_out = []
                for ele_ele2 in level2:
                    if len(ele_ele2) > self.sentence_size:
                        # Level 3: break after single spaces.
                        level3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2).split("\n")
                        level2_out.extend(i for i in level3 if i)
                    else:
                        level2_out.append(ele_ele2)
                level1_out.extend(i for i in level2_out if i)
            result.extend(i for i in level1_out if i)
        return result

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61



from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n", ".", "。", "!", "!", "?", "?", ";", ";", "……", "…", "、", ",", ",", " "],
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6


# 这里给每个docs片段的metadata里注入file_id
# Tag every chunk with the id and display name of the file it came from
# (the URL when one is set, otherwise the file's base name).
for chunk in docs:
    chunk.metadata["file_id"] = self.file_id
    chunk.metadata["file_name"] = self.url or os.path.split(self.file_path)[-1]
  • 1
  • 2
  • 3
  • 4



核心的代码同样是在qanything_kernel\core\local_doc_qa.py的 insert_files_to_milvus这个函数下,这里面create_embedding (下面代码24行)就是构造向量的过程,在前面的章节(RAG开源项目Qanything源码阅读1-概述+服务)有提及

    async def insert_files_to_milvus(self, user_id, kb_id, local_files: List[LocalFile]):
        """Split, embed and insert each file into the user's Milvus knowledge base.

        Per file, ``self.milvus_summary.update_file_status`` is set to 'red' when
        splitting or embedding fails, 'green' when ``milvus_kv.insert_files``
        returns a truthy value, and 'yellow' otherwise. A failing file is logged
        and skipped; processing continues with the next one.
        """
        debug_logger.info(f'insert_files_to_milvus: {kb_id}')
        milvus_kv = self.match_milvus_kb(user_id, [kb_id])
        assert milvus_kv is not None
        success_list = []
        failed_list = []

        for local_file in local_files:
            start = time.time()
            try:
                # NOTE(review): this call was dropped from the scraped snippet; restored as
                # upstream QAnything's LocalFile.split_file_to_docs(...) — confirm signature.
                local_file.split_file_to_docs(self.get_ocr_result)
                content_length = sum(len(doc.page_content) for doc in local_file.docs)
            except Exception:
                error_info = f'split error: {traceback.format_exc()}'
                debug_logger.error(error_info)
                self.milvus_summary.update_file_status(local_file.file_id, status='red')
                failed_list.append(local_file)
                continue  # content_length / docs are not usable after a failed split
            end = time.time()
            self.milvus_summary.update_content_length(local_file.file_id, content_length)
            debug_logger.info(f'split time: {end - start} {len(local_file.docs)}')
            start = time.time()
            try:
                local_file.create_embedding()
            except Exception:
                error_info = f'embedding error: {traceback.format_exc()}'
                debug_logger.error(error_info)
                self.milvus_summary.update_file_status(local_file.file_id, status='red')
                failed_list.append(local_file)
                continue
            end = time.time()
            debug_logger.info(f'embedding time: {end - start} {len(local_file.embs)}')

            self.milvus_summary.update_chunk_size(local_file.file_id, len(local_file.docs))
            ret = await milvus_kv.insert_files(local_file.file_id, local_file.file_name, local_file.file_path,
                                               local_file.docs, local_file.embs)
            insert_time = time.time()
            debug_logger.info(f'insert time: {insert_time - end}')
            if ret:
                self.milvus_summary.update_file_status(local_file.file_id, status='green')
                success_list.append(local_file)
            else:
                self.milvus_summary.update_file_status(local_file.file_id, status='yellow')
                failed_list.append(local_file)
        debug_logger.info(
            f"insert_to_milvus: success num: {len(success_list)}, failed num: {len(failed_list)}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

向量化的模型是单独用triton部署的,所以此处是直接请求模型服务获取的(RAG开源项目Qanything源码阅读1-概述+服务4.2 的12行)。

CUDA_VISIBLE_DEVICES=$gpu_id1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed_rerank --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /workspace/qanything_local/logs/debug_logs/embed_rerank_tritonserver.log 2>&1 &
  • 1


def create_embedding(self):
    """Embed the text of every chunk in ``self.docs`` and keep the vectors on ``self.embs``."""
    texts = [chunk.page_content for chunk in self.docs]
    self.embs = self.emb_infer._get_len_safe_embeddings(texts)
  • 1
  • 2


import os
import math
import numpy as np
import time
from typing import Optional
import onnxruntime as ort
from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput
from transformers import AutoTokenizer
    "fp32": np.float32,
    "fp16": np.float16,
class EmbeddingClient:
    """Embedding client: tokenizes locally, runs the model on a Triton server over gRPC.

    Embeddings returned by ``get_embedding`` are L2-normalised row by row.
    """

    # NOTE(review): the value of this constant was lost from the scraped snippet
    # (it is referenced by __init__); 120 is an assumed default — confirm.
    DEFAULT_MAX_RESP_WAIT_S = 120
    embed_version = "local_v0.0.1_20230525_6d4019f1559aef84abc2ab8257e1ad4c"

    def __init__(
        self,
        server_url: str,
        model_name: str,
        model_version: str,
        tokenizer_path: str,
        resp_wait_s: Optional[float] = None,
    ):
        """
        Args:
            server_url: Triton gRPC endpoint.
            model_name / model_version: model to query on the server.
            tokenizer_path: path or name passed to ``AutoTokenizer.from_pretrained``.
            resp_wait_s: per-request client timeout; defaults to
                ``DEFAULT_MAX_RESP_WAIT_S`` when None.
        """
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def get_embedding(self, sentences, max_length=512):
        """Return one L2-normalised embedding (as nested Python lists) per input sentence.

        Args:
            sentences: text(s) passed straight to the tokenizer.
            max_length: tokenizer truncation length.
        """
        # Tokenization happens in this process; only tensors are sent to Triton.
        inputs_data = self._tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors='np')
        inputs_data = {k: v for k, v in inputs_data.items()}
        client = InferenceServerClient(url=self._server_url)
        try:
            # The metadata names the server's input/output tensors and their Triton datatypes.
            model_metadata = client.get_model_metadata(self._model_name, self._model_version)
            inputs_info = {tm.name: tm for tm in model_metadata.inputs}
            output_names = [tm.name for tm in model_metadata.outputs]
            outputs_req = [InferRequestedOutput(name_) for name_ in output_names]
            infer_inputs = []
            for name_, info in inputs_info.items():
                # Cast each tokenizer output to the dtype the server declares for that input.
                data = inputs_data[name_].astype(client_utils.triton_to_np_dtype(info.datatype))
                infer_input = InferInput(name_, data.shape, info.datatype)
                infer_input.set_data_from_numpy(data)
                infer_inputs.append(infer_input)
            results = client.infer(
                model_name=self._model_name,
                inputs=infer_inputs,
                model_version=self._model_version,
                outputs=outputs_req,
                client_timeout=self._response_wait_t,
            )
            y_pred = {name_: results.as_numpy(name_) for name_ in output_names}
        finally:
            # A new client is created per call; close it so channels are not leaked.
            client.close()
        # Take index 0 along axis 1 of the "output" tensor as the sentence vector.
        # NOTE(review): presumably (batch, seq_len, dim) with the first token as [CLS] — confirm.
        embeddings = y_pred["output"][:, 0]
        norm_arr = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings_normalized = embeddings / norm_arr
        return embeddings_normalized.tolist()

    def getModelVersion(self):
        """Return the hard-coded embedding model version string."""
        return self.embed_version
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 首先可以看到,tokenizer依旧是本服务做的。

  • 服务请求主要由 client 负责,triton 提供的是 grpc 接口,输入和输出的数据结构分别参考 InferInput 和 InferRequestedOutput。

  • 细节:对模型的输出结果,作者还做了额外的处理,主要是归一化——用 np.linalg.norm 按行求二范数(默认即二范数),然后把向量都除以这个二范数。

  • 另外留意到,作者保留了模型的版本号(embed_version),便于在模型迭代时做版本控制。

完成后,就可以开始灌库了,qanything_kernel\core\local_file.py 的 milvus_kv.insert_files。milvus自己是有开源的库的,即pymilvus,作者自己写了一个完整的类MilvusClient,至于pymilvus具体教程大家可以看:https://zhuanlan.zhihu.com/p/676124465

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, \
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from functools import partial
import time
import copy
from datetime import datetime
from qanything_kernel.utils.custom_log import debug_logger
from langchain.docstore.document import Document
import math
from itertools import groupby
from typing import List

# 混合检索
from .es_client import ElasticsearchClient
from qanything_kernel.configs.model_config import HYBRID_SEARCH

#ret = await milvus_kv.insert_files(local_file.file_id, local_file.file_name, local_file.file_path,
#                                               local_file.docs, local_file.embs)

async def insert_files(self, file_id, file_name, file_path, docs, embs, batch_size=1000):
    """Insert a file's chunks and embeddings into Milvus (and ES when hybrid search is on).

    Rows are written in batches of ``batch_size`` into ``self.partitions[0]``;
    chunk ids are ``f"{file_id}_{idx}"``. Returns False as soon as one batch
    insert raises (the error is logged), True otherwise.
    """
    debug_logger.info(f'now inser_file {file_name}')
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d%H%M")
    loop = asyncio.get_running_loop()
    contents = [doc.page_content for doc in docs]
    num_docs = len(docs)
    for batch_start in range(0, num_docs, batch_size):
        batch_end = min(batch_start + batch_size, num_docs)
        # Column-oriented batch: one list per schema field.
        data = [[] for _ in range(len(self.sess.schema))]
        for idx in range(batch_start, batch_end):
            cont = contents[idx]
            emb = embs[idx]
            chunk_id = f'{file_id}_{idx}'
            # Column order must match MilvusClient.fields:
            # chunk_id, file_id, file_name, file_path, timestamp, content, embedding
            data[0].append(chunk_id)
            data[1].append(file_id)
            data[2].append(file_name)
            data[3].append(file_path)
            data[4].append(timestamp)
            data[5].append(cont)
            data[6].append(emb)
        # Run the blocking pymilvus insert on the thread pool so the event loop is not blocked.
        try:
            debug_logger.info('Inserting into Milvus...')
            mr = await loop.run_in_executor(
                self.executor, partial(self.partitions[0].insert, data=data))
            debug_logger.info(f'{file_name} {mr}')
        except Exception as e:
            debug_logger.error(f'Milvus insert file_id:{file_id}, file_name:{file_name} failed: {e}')
            return False
    # Hybrid retrieval: mirror the same chunks into Elasticsearch.
    if self.hybrid_search:
        debug_logger.info(f'now inser_file for es: {file_name}')
        for batch_start in range(0, num_docs, batch_size):
            batch_end = min(batch_start + batch_size, num_docs)
            data_es = []
            for idx in range(batch_start, batch_end):
                data_es_item = {
                    'file_id': file_id,
                    'content': contents[idx],
                    'metadata': {
                        'file_name': file_name,
                        'file_path': file_path,
                        'chunk_id': f'{file_id}_{idx}',
                        'timestamp': timestamp,
                    },
                }
                data_es.append(data_es_item)
            try:
                debug_logger.info('Inserting into es ...')
                # Only the last batch asks ES to refresh.
                mr = await self.client.insert(data=data_es, refresh=batch_end == num_docs)
                debug_logger.info(f'{file_name} {mr}')
            except Exception as e:
                debug_logger.error(f'ES insert file_id: {file_id}\nfile_name: {file_name}\nfailed: {e}')
                return False
    return True
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • milvus 使用的是 pymilvus 工具来读写,其中 self.partitions[0].insert 就是用来存储数据的,此处可以注意到 data 内有很多不同的字段。

  • 执行代码使用的是loop.run_in_executor,有留意到,在MilvusClient内有一个self.executor,这个的定义在这个类的__init__内,self.executor = ThreadPoolExecutor(max_workers=10),这里新建了一个线程池,新技能get。

  • 下方是 ES 的数据灌入。个人感觉,ES 的数据处理写在这个位置并不是很合适,应该单独拆出来处理,毕竟两类逻辑混在一起,代码不太好读。

MilvusClient 类
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, \
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from functools import partial
import time
import copy
from datetime import datetime
from qanything_kernel.utils.custom_log import debug_logger
from langchain.docstore.document import Document
import math
from itertools import groupby
from typing import List

# 混合检索
from .es_client import ElasticsearchClient
from qanything_kernel.configs.model_config import HYBRID_SEARCH

class MilvusFailed(Exception):
    """Exception type reserved for Milvus-related failures."""

class MilvusClient:
    """Wrapper around a pymilvus collection named after ``user_id``, with one partition per kb_id.

    When ``HYBRID_SEARCH`` is enabled, chunks are also written to / searched in
    Elasticsearch indices named ``"{user_id}++{kb_id}"``.
    """

    def __init__(self, mode, user_id, kb_ids, *, threshold=1.1, client_timeout=3):
        self.user_id = user_id
        self.kb_ids = kb_ids
        if mode == 'local':
            self.host = MILVUS_HOST_LOCAL
        else:
            self.host = MILVUS_HOST_ONLINE
        self.port = MILVUS_PORT
        self.user = MILVUS_USER
        self.password = MILVUS_PASSWORD
        self.db_name = MILVUS_DB_NAME
        self.client_timeout = client_timeout
        self.threshold = threshold
        self.sess: Collection = None
        self.partitions: List[Partition] = []
        # Pool used to run blocking pymilvus calls off the caller's thread / event loop.
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.top_k = VECTOR_SEARCH_TOP_K
        self.search_params = {"metric_type": "L2", "params": {"nprobe": 256}}
        if mode == 'local':
            self.create_params = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 2048}}
        else:
            self.create_params = {"metric_type": "L2", "index_type": "GPU_IVF_FLAT", "params": {"nlist": 2048}}
        self.last_init_ts = time.time() - 100  # minus 100 so the very first init is never rejected

        # Hybrid retrieval (vector + Elasticsearch)
        self.hybrid_search = HYBRID_SEARCH
        if self.hybrid_search:
            self.index_name = [f"{user_id}++{kb_id}" for kb_id in kb_ids]
            self.client = ElasticsearchClient(index_name=self.index_name)

    @property
    def fields(self):
        """Collection schema fields; insert_files fills its columns in this order."""
        fields = [
            FieldSchema(name='chunk_id', dtype=DataType.VARCHAR, max_length=64, is_primary=True),
            FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name='file_name', dtype=DataType.VARCHAR, max_length=640),
            FieldSchema(name='file_path', dtype=DataType.VARCHAR, max_length=640),
            FieldSchema(name='timestamp', dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name='content', dtype=DataType.VARCHAR, max_length=4000),
            FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=768),
        ]
        return fields

    def _expand_candidates(self, cands):
        """Expand candidates with neighbouring chunks, except csv/xlsx hits which are kept as-is.

        Returns the csv/xlsx candidates (with ``metadata["kernel"]`` set to their
        own content) followed by the expanded results for all other candidates.
        """
        need_expand, not_need_expand = [], []
        for doc in cands:
            # csv and xlsx files are not passed through expand_cand_docs
            if doc.metadata['file_name'].lower().split('.')[-1] in ['csv', 'xlsx']:
                doc.metadata["kernel"] = doc.page_content
                not_need_expand.append(doc)
            else:
                need_expand.append(doc)
        return not_need_expand + self.expand_cand_docs(need_expand)

    def parse_batch_result(self, batch_result):
        """Turn raw Milvus search hits into one list of (expanded) Documents per query."""
        new_result = []
        for result in batch_result:
            result.sort(key=lambda x: x.score)
            valid_results = [cand for cand in result if cand.score <= self.threshold]
            if not valid_results:  # nothing within the threshold: fall back to the top_k closest hits
                valid_results = result[:self.top_k]
            new_cands = [
                Document(page_content=cand.entity.get('content'),
                         metadata={"score": cand.score, "file_id": cand.entity.get('file_id'),
                                   "file_name": cand.entity.get('file_name'),
                                   "chunk_id": cand.entity.get('chunk_id')})
                for cand in valid_results
            ]
            new_result.append(self._expand_candidates(new_cands))
        return new_result

    # Hybrid retrieval
    def parse_es_batch_result(self, es_records, milvus_records):
        """Return ES hits not already selected from Milvus, as (expanded) Documents."""
        # Collect the chunk_ids Milvus already contributes, using the same selection
        # rule as parse_batch_result (within threshold, else top_k).
        milvus_records_seen = set()
        for result in milvus_records:
            result.sort(key=lambda x: x.score)
            within_threshold = [cand for cand in result if cand.score <= self.threshold]
            for cand in (within_threshold or result[:self.top_k]):
                milvus_records_seen.add(cand.entity.get('chunk_id'))
        new_cands = []
        for es_record in es_records:
            # NOTE(review): assumes ES document ids equal chunk_ids — confirm against ElasticsearchClient.
            if es_record['id'] not in milvus_records_seen:
                doc = Document(page_content=es_record['content'],
                               metadata={"score": es_record['score'], "file_id": es_record['file_id'],
                                         "file_name": es_record['metadata']['file_name'],
                                         "chunk_id": es_record['metadata']['chunk_id']})
                new_cands.append(doc)
        return self._expand_candidates(new_cands)

    @property
    def output_fields(self):
        """Scalar fields returned by search/query."""
        return ['chunk_id', 'file_id', 'file_name', 'file_path', 'timestamp', 'content']

    def init(self):
        """Connect, open (or create and index) the user's collection and ensure each kb partition exists."""
        try:
            connections.connect(host=self.host, port=self.port, user=self.user,
                                password=self.password, db_name=self.db_name)  # timeout=3 [cannot set]
            if utility.has_collection(self.user_id):
                self.sess = Collection(self.user_id)
                debug_logger.info(f'collection {self.user_id} exists')
            else:
                schema = CollectionSchema(self.fields)
                debug_logger.info(f'create collection {self.user_id} {schema}')
                self.sess = Collection(self.user_id, schema)
                self.sess.create_index(field_name="embedding", index_params=self.create_params)
            for kb_id in self.kb_ids:
                if not self.sess.has_partition(kb_id):
                    self.sess.create_partition(kb_id)
            self.partitions = [Partition(self.sess, kb_id) for kb_id in self.kb_ids]
            debug_logger.info('partitions: %s', self.kb_ids)
            # NOTE(review): restored — Milvus search/query need a loaded collection; confirm.
            self.sess.load()
        except Exception as e:
            # NOTE(review): the original handler body was lost from the snippet; logging keeps
            # init() best-effort rather than raising.
            debug_logger.error(f'milvus init failed: {e}')

    def __search_emb_sync(self, embs, expr='', top_k=None, client_timeout=None, queries=None):
        if not top_k:
            top_k = self.top_k
        milvus_records = self.sess.search(data=embs, partition_names=self.kb_ids, anns_field="embedding",
                                          param=self.search_params, limit=top_k,
                                          output_fields=self.output_fields, expr=expr, timeout=client_timeout)
        milvus_records_proc = self.parse_batch_result(milvus_records)

        # Hybrid retrieval
        if self.hybrid_search:
            es_records = self.client.search(queries)
            es_records_proc = self.parse_es_batch_result(es_records, milvus_records)
            # NOTE(review): es_records_proc is computed but not merged into the returned
            # result here — confirm whether (and how) it should be.

        return milvus_records_proc

    def search_emb_async(self, embs, expr='', top_k=None, client_timeout=None, queries=None):
        """Run the search on self.executor and block until it finishes (synchronous for the caller)."""
        if not top_k:
            top_k = self.top_k
        future = self.executor.submit(self.__search_emb_sync, embs, expr, top_k, client_timeout, queries)
        return future.result()

    def query_expr_async(self, expr, output_fields=None, client_timeout=None):
        """Run a scalar ``expr`` query on self.executor and block for the result."""
        if client_timeout is None:
            client_timeout = self.client_timeout
        if not output_fields:
            output_fields = self.output_fields
        future = self.executor.submit(
            partial(self.sess.query, partition_names=self.kb_ids, output_fields=output_fields, expr=expr,
                    timeout=client_timeout))
        return future.result()

    async def insert_files(self, file_id, file_name, file_path, docs, embs, batch_size=1000):
        """Insert a file's chunks into Milvus (and ES when hybrid search is on); False on first failure."""
        debug_logger.info(f'now inser_file {file_name}')
        now = datetime.now()
        timestamp = now.strftime("%Y%m%d%H%M")
        loop = asyncio.get_running_loop()
        contents = [doc.page_content for doc in docs]
        num_docs = len(docs)
        for batch_start in range(0, num_docs, batch_size):
            batch_end = min(batch_start + batch_size, num_docs)
            data = [[] for _ in range(len(self.sess.schema))]

            for idx in range(batch_start, batch_end):
                cont = contents[idx]
                emb = embs[idx]
                chunk_id = f'{file_id}_{idx}'
                # Column order follows self.fields.
                data[0].append(chunk_id)
                data[1].append(file_id)
                data[2].append(file_name)
                data[3].append(file_path)
                data[4].append(timestamp)
                data[5].append(cont)
                data[6].append(emb)

            # Run the blocking insert on the pool so the event loop is not blocked.
            try:
                debug_logger.info('Inserting into Milvus...')
                mr = await loop.run_in_executor(
                    self.executor, partial(self.partitions[0].insert, data=data))
                debug_logger.info(f'{file_name} {mr}')
            except Exception as e:
                debug_logger.error(f'Milvus insert file_id:{file_id}, file_name:{file_name} failed: {e}')
                return False

        # Hybrid retrieval
        if self.hybrid_search:
            debug_logger.info(f'now inser_file for es: {file_name}')
            for batch_start in range(0, num_docs, batch_size):
                batch_end = min(batch_start + batch_size, num_docs)
                data_es = []
                for idx in range(batch_start, batch_end):
                    data_es_item = {
                        'file_id': file_id,
                        'content': contents[idx],
                        'metadata': {
                            'file_name': file_name,
                            'file_path': file_path,
                            'chunk_id': f'{file_id}_{idx}',
                            'timestamp': timestamp,
                        },
                    }
                    data_es.append(data_es_item)

                try:
                    debug_logger.info('Inserting into es ...')
                    mr = await self.client.insert(data=data_es, refresh=batch_end == num_docs)
                    debug_logger.info(f'{file_name} {mr}')
                except Exception as e:
                    debug_logger.error(f'ES insert file_id: {file_id}\nfile_name: {file_name}\nfailed: {e}')
                    return False

        return True

    def delete_collection(self):
        """Drop the user's collection and, with hybrid search, its ES indices."""
        # NOTE(review): the Milvus calls were lost from the snippet; restored — confirm.
        self.sess.release()
        utility.drop_collection(self.user_id)
        # Hybrid retrieval
        if self.hybrid_search:
            index_name_delete = []
            for index_name in self.client.indices.get_alias().keys():
                if index_name.startswith(f"{self.user_id}++"):
                    index_name_delete.append(index_name)
            # NOTE(review): confirm ElasticsearchClient exposes delete_index(list_of_names).
            self.client.delete_index(index_name_delete)

    def delete_partition(self, partition_name):
        """Drop a kb partition and, with hybrid search, the matching ES index/indices."""
        part = Partition(self.sess, partition_name)
        # NOTE(review): the Milvus calls were lost from the snippet; restored — confirm.
        part.release()
        self.sess.drop_partition(partition_name)
        # Hybrid retrieval
        if self.hybrid_search:
            index_name_delete = []
            if isinstance(partition_name, str):
                index_name_delete.append(f"{self.user_id}++{partition_name}")
            elif isinstance(partition_name, list) and isinstance(partition_name[0], str):
                for kb_id in partition_name:
                    index_name_delete.append(f"{self.user_id}++{kb_id}")
            else:
                debug_logger.info(f"##ES## - kb_ids not valid: {partition_name}")
            # NOTE(review): confirm ElasticsearchClient exposes delete_index(list_of_names).
            self.client.delete_index(index_name_delete)
            debug_logger.info(f"##ES## - success delete kb_ids: {partition_name}")

    def delete_files(self, files_id):
        """Delete all chunks of the given file ids from Milvus (and ES with hybrid search)."""
        self.sess.delete(expr=f"file_id in {files_id}")
        debug_logger.info('milvus delete files_id: %s', files_id)
        # Hybrid retrieval
        if self.hybrid_search:
            es_records = self.client.search(files_id, field='file_id')
            delete_index_ids = {}
            for record in es_records:
                delete_index_ids.setdefault(record['index'], []).append(record['id'])
            for index, ids in delete_index_ids.items():
                self.client.delete_chunks(index_name=index, ids=ids)
            debug_logger.info(f"##ES## - success delete files_id: {files_id}")

    def get_files(self, files_id):
        """Return the file ids from ``files_id`` that have at least one chunk stored."""
        res = self.query_expr_async(expr=f"file_id in {files_id}", output_fields=["file_id"])
        valid_ids = [result['file_id'] for result in res]
        return valid_ids

    def seperate_list(self, ls: List[int]) -> List[List[int]]:
        """Split a sorted list of ints into runs of consecutive values.

        >>> MilvusClient.seperate_list(None, [1, 2, 3, 5, 6, 8])
        [[1, 2, 3], [5, 6], [8]]
        """
        if not ls:
            return []
        lists = []
        current = [ls[0]]
        for prev, cur in zip(ls, ls[1:]):
            if prev + 1 == cur:
                current.append(cur)
            else:
                lists.append(current)
                current = [cur]
        lists.append(current)
        return lists

    def process_group(self, group):
        """Merge hits of one file with neighbouring chunks, up to CHUNK_SIZE characters each.

        ``group`` holds candidates sharing one file_id. Returns one Document per run
        of consecutive chunk indices; its score is derived from the smallest score of
        the original hits in that run, and ``metadata["kernel"]`` joins those hits'
        contents with '|'.
        """
        new_cands = []
        # Sort by the numeric suffix of chunk_id ("{file_id}_{idx}").
        group.sort(key=lambda x: int(x.metadata['chunk_id'].split('_')[-1]))
        id_set = set()
        file_id = group[0].metadata['file_id']
        file_name = group[0].metadata['file_name']
        group_scores_map = {}
        # First collect every chunk_id that may be needed: a [-200, +200) window around each hit.
        cand_chunks_set = set()
        for cand_doc in group:
            current_chunk_id = int(cand_doc.metadata['chunk_id'].split('_')[-1])
            group_scores_map[current_chunk_id] = cand_doc.metadata['score']
            cand_chunks_set.update(file_id + '_' + str(i)
                                   for i in range(current_chunk_id - 200, current_chunk_id + 200))

        cand_chunks = list(cand_chunks_set)
        group_relative_chunks = self.query_expr_async(expr=f"file_id == \"{file_id}\" and chunk_id in {cand_chunks}",
                                                      output_fields=["chunk_id", "content"])
        group_chunk_map = {int(item['chunk_id'].split('_')[-1]): item['content'] for item in group_relative_chunks}
        for cand_doc in group:
            current_chunk_id = int(cand_doc.metadata['chunk_id'].split('_')[-1])
            id_set.add(current_chunk_id)
            docs_len = len(cand_doc.page_content)
            # Grow outwards (k = 1, 2, ...), trying the right then the left neighbour,
            # until adding one more chunk would exceed CHUNK_SIZE characters.
            for k in range(1, 200):
                break_flag = False
                for expand_index in [current_chunk_id + k, current_chunk_id - k]:
                    if expand_index in group_chunk_map:  # dict lookup instead of O(n) list scan
                        merge_content = group_chunk_map[expand_index]
                        if docs_len + len(merge_content) > CHUNK_SIZE:
                            break_flag = True
                            break
                        docs_len += len(merge_content)
                        id_set.add(expand_index)
                if break_flag:
                    break

        for id_seq in self.seperate_list(sorted(id_set)):
            try:
                doc = Document(page_content=" ".join(group_chunk_map[chunk_idx] for chunk_idx in id_seq),
                               metadata={"score": 0, "file_id": file_id,
                                         "file_name": file_name})
                doc_score = min(group_scores_map[chunk_idx] for chunk_idx in id_seq
                                if chunk_idx in group_scores_map)
                doc.metadata["score"] = float(format(1 - doc_score / math.sqrt(2), '.4f'))
                doc.metadata["kernel"] = '|'.join(group_chunk_map[chunk_idx] for chunk_idx in id_seq
                                                  if chunk_idx in group_scores_map)
                new_cands.append(doc)
            except Exception as e:
                debug_logger.error(f"process_group error: {e}. maybe chunks in ES not exists in Milvus. Please delete the file and upload again.")
        return new_cands

    def expand_cand_docs(self, cand_docs):
        """Group candidates by file_id and expand each group with process_group."""
        cand_docs = sorted(cand_docs, key=lambda x: x.metadata['file_id'])
        # groupby requires input sorted by the same key
        m_grouped = [list(group) for key, group in groupby(cand_docs, key=lambda x: x.metadata['file_id'])]
        debug_logger.info('milvus group number: %s', len(m_grouped))

        # A separate pool: process_group calls query_expr_async, which submits to and waits
        # on self.executor, so running process_group on self.executor could leave its
        # workers waiting on each other.
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(self.process_group, group) for group in m_grouped if group]
            new_cands = []
            # Collect in submission order so the output order is deterministic.
            for future in futures:
                result = future.result()
                if result is not None:
                    new_cands.extend(result)
        return new_cands

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398

Sanic 是什么?怎么使用?一文带你快速上手 Sanic
aiohttp 官方文档:Welcome to AIOHTTP — aiohttp 3.8.6 documentation
Python asyncio 文档:asyncio — Asynchronous I/O — Python 3.12.0 documentation
掌握异步网络编程利器:Python aiohttp使用教程及代码示例
正排索引 vs 倒排索引 - 搜索引擎具体原理
ES高频面试问题:一张图带你读懂 Elasticsearch 中“正排索引(正向索引)”和“倒排索引(反向索引)”区别

