当前位置:   article > 正文

前沿重器[46] RAG开源项目Qanything源码阅读2-离线文件处理

qanything 开源rag框架对比

前沿重器

栏目主要给大家分享各种大厂、顶会的论文和分享,从中抽取关键精华的部分和大家分享,和大家一起把握前沿技术。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。(算起来,专项启动已经是20年的事了!)

2023年文章合集发布了!在这里:又添十万字-CS的陋室2023年文章合集来袭

往期回顾

书接上文,最近选了一个开源的RAG项目进行进一步学习:https://github.com/netease-youdao/QAnything,后续一连几篇,会分几篇,从我的角度,给大家介绍这个项目,预计的目录如下:

本期是离线的文件处理,即对多种不同的文件进行详细的阐述,之前我也有写过类似的文章(心法利器[110] | 知识文档处理和使用流程),不过没落到代码层面,这次借着源码阅读的机会,正好介绍一下:

  • 文件上传。

  • 文件读取和切片。

  • 索引构造。

提前说明,这里忽略了大量的业务代码,聚焦在文件处理和相关算法本身,如新建用户、知识库、文件删除,会有选择的忽略,有需要的可以参考我在文中的思路,在代码里找到对应的位置。

文件上传

文件上传是指将文件从前端传到后端的流程,这个流程的工作在docs\API.md有提到。首先是接口字段:

参数名参数值是否必填参数类型描述说明
files文件二进制File需要上传的文件,可多选,目前仅支持[md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv]
user_idzzpString用户 id
kb_idKBb1dd58e8485443ce81166d24f6febda7String知识库 id
modesoftString上传模式,soft:知识库内存在同名文件时当前文件不再上传,strong:文件名重复的文件强制上传,默认值为 soft

至于文件的上传,作者给出了两种模式,分别是同步和异步。

客户端

客户端只需要请求服务即可,这里穿插一下同步异步请求,以及文件上传的细节,这个直接参考源码就好了,首先是同步的请求源码:

  1. import os
  2. import requests
  3. url = "http://{your_host}:8777/api/local_doc_qa/upload_files"
  4. folder_path = "./docx_data"  # 文件所在文件夹,注意是文件夹!!
  5. data = {
  6.     "user_id""zzp",
  7.     "kb_id""KB6dae785cdd5d47a997e890521acbe1c9",
  8.  "mode""soft"
  9. }
  10. files = []
  11. for root, dirs, file_names in os.walk(folder_path):
  12.     for file_name in file_names:
  13.         if file_name.endswith(".md"):  # 这里只上传后缀是md的文件,请按需修改,支持类型:
  14.             file_path = os.path.join(root, file_name)
  15.             files.append(("files", open(file_path, "rb")))
  16. response = requests.post(url, files=files, data=data)
  17. print(response.text)
  • 发请求用的是通用的requests包。

  • 因为是本地测试,所以使用的就是比较直接的本地文件,直接open就行,文件字段存的是open变量,注意打开方式是rb

至于异步,则会会复杂一些。

  1. import argparse
  2. import os
  3. import sys
  4. import json
  5. import aiohttp
  6. import asyncio
  7. import time
  8. import random
  9. import string
  10. files = []
  11. for root, dirs, file_names in os.walk("./docx_data"):  # 文件夹
  12.     for file_name in file_names:
  13.         if file_name.endswith(".docx"):  # 只上传docx文件
  14.             file_path = os.path.join(root, file_name)
  15.             files.append(file_path)
  16. print(len(files))
  17. response_times = []
  18. async def send_request(round_, files):
  19.     print(len(files))
  20.     url = 'http://{your_host}:8777/api/local_doc_qa/upload_files'
  21.     data = aiohttp.FormData()
  22.     data.add_field('user_id''zzp')
  23.     data.add_field('kb_id''KBf1dafefdb08742f89530acb7e9ed66dd')
  24.     data.add_field('mode''soft')
  25.     total_size = 0
  26.     for file_path in files:
  27.         file_size = os.path.getsize(file_path)
  28.         total_size += file_size
  29.         data.add_field('files', open(file_path, 'rb'))
  30.     print('size:', total_size / (1024 * 1024))
  31.     try:
  32.         start_time = time.time()
  33.         async with aiohttp.ClientSession() as session:
  34.             async with session.post(url, data=data) as response:
  35.                 end_time = time.time()
  36.                 response_times.append(end_time - start_time)
  37.                 print(f"round_:{round_}, 响应状态码: {response.status}, 响应时间: {end_time - start_time}秒")
  38.     except Exception as e:
  39.         print(f"请求发送失败: {e}")
  40. async def main():
  41.     start_time = time.time()
  42.     num = int(sys.argv[1])  // 一次上传数量,http协议限制一次请求data不能大于100M,请自行控制数量
  43.     round_ = 0
  44.     r_files = files[:num]
  45.     tasks = []
  46.     task = asyncio.create_task(send_request(round_, r_files))
  47.     tasks.append(task)
  48.     await asyncio.gather(*tasks)
  49.     print(f"请求完成")
  50.     end_time = time.time()
  51.     total_requests = len(response_times)
  52.     total_time = end_time - start_time
  53.     qps = total_requests / total_time
  54.     print(f"total_time:{total_time}")
  55. if __name__ == '__main__':
  56.     asyncio.run(main())

请求用的是aiohttp,而且使用的是python的协程,即asyncio一套的python技术,具体细节可以参考这篇博客:https://blog.csdn.net/m0_68949064/article/details/132805165。协程在高密度的http请求下,能有效提升CPU的使用率,提升综合性能,毕竟在请求等待过程,可以做很多别的事,就避免CPU空跑了。

服务端

服务端则比较复杂了,文件上传后要经过大量的校验,并且需要返回最终的处理结果。

文件上传的接口是/api/local_doc_qa/upload_files,我们可以在handlers.py里面找到,排除掉一些校验代码,handlers里面的核心代码是这段(upload_files函数下):

  1. for file, file_name in zip(files, file_names):
  2.     if file_name in exist_file_names:
  3.         continue
  4.     file_id, msg = local_doc_qa.milvus_summary.add_file(user_id, kb_id, file_name, timestamp)
  5.     debug_logger.info(f"{file_name}, {file_id}, {msg}")
  6.     local_file = LocalFile(user_id, kb_id, file, file_id, file_name, local_doc_qa.embeddings)
  7.     local_files.append(local_file)
  8.     local_doc_qa.milvus_summary.update_file_size(file_id, len(local_file.file_content))
  9.     data.append(
  10.         {"file_id": file_id, "file_name": file_name, "status""gray""bytes"len(local_file.file_content),
  11.             "timestamp": timestamp})
  12. asyncio.create_task(local_doc_qa.insert_files_to_milvus(user_id, kb_id, local_files))

这里面的几个关键的函数:

  • local_doc_qa.milvus_summary.add_file:向指定知识库下面增加文件,这是一个mysql操作,要在mysql数据库内记录在案。

  • local_doc_qa.insert_files_to_milvus:将文档加入到milvus中,当然这里也包含了文件切片、推理向量、存入数据库等一系列操作。

回到服务,这里最终还是会收集各种处理的信息,最终以json形式形式返回,这里包括状态码、返回信息以及必要的数据信息(例如文件id、上传后的文件名、更新时间等)

return sanic_json({"code"200"msg": msg, "data": data})

文件处理核心流程

继续往里面看,这个函数的代码不是很长,我直接放了:

  1. async def insert_files_to_milvus(self, user_id, kb_id, local_files: List[LocalFile]):
  2.     debug_logger.info(f'insert_files_to_milvus: {kb_id}')
  3.     milvus_kv = self.match_milvus_kb(user_id, [kb_id])
  4.     assert milvus_kv is not None
  5.     success_list = []
  6.     failed_list = []
  7.     for local_file in local_files:
  8.         start = time.time()
  9.         try:
  10.             local_file.split_file_to_docs(self.get_ocr_result)
  11.             content_length = sum([len(doc.page_content) for doc in local_file.docs])
  12.         except Exception as e:
  13.             error_info = f'split error: {traceback.format_exc()}'
  14.             debug_logger.error(error_info)
  15.             self.milvus_summary.update_file_status(local_file.file_id, status='red')
  16.             failed_list.append(local_file)
  17.             continue
  18.         end = time.time()
  19.         self.milvus_summary.update_content_length(local_file.file_id, content_length)
  20.         debug_logger.info(f'split time: {end - start} {len(local_file.docs)}')
  21.         start = time.time()
  22.         try:
  23.             local_file.create_embedding()
  24.         except Exception as e:
  25.             error_info = f'embedding error: {traceback.format_exc()}'
  26.             debug_logger.error(error_info)
  27.             self.milvus_summary.update_file_status(local_file.file_id, status='red')
  28.             failed_list.append(local_file)
  29.             continue
  30.         end = time.time()
  31.         debug_logger.info(f'embedding time: {end - start} {len(local_file.embs)}')
  32.         self.milvus_summary.update_chunk_size(local_file.file_id, len(local_file.docs))
  33.         ret = await milvus_kv.insert_files(local_file.file_id, local_file.file_name, local_file.file_path,
  34.                                             local_file.docs, local_file.embs)
  35.         insert_time = time.time()
  36.         debug_logger.info(f'insert time: {insert_time - end}')
  37.         if ret:
  38.             self.milvus_summary.update_file_status(local_file.file_id, status='green')
  39.             success_list.append(local_file)
  40.         else:
  41.             self.milvus_summary.update_file_status(local_file.file_id, status='yellow')
  42.             failed_list.append(local_file)
  43.     debug_logger.info(
  44.         f"insert_to_milvus: success num: {len(success_list)}, failed num: {len(failed_list)}")

除开各种校验和数据的同步更新,主要经历的是这几个流程:

  • local_file.split_file_to_docs:文件的切片,这里还涉及不同类型的文件处理,例如md、图片等。

  • local_file.create_embedding:看名字就知道了,向量化。

  • milvus_kv.insert_files:存入milvus。

这就是文件上传后核心要经历的4个流程,即文件读取、文件切片、向量化和入库,接下来我会逐个展开讲。

文件读取和切片

文件读取和切片在代码里有不少是混合的,所以我也合在一起说了。在代码里,我们能看到,他们目前支持的是这几种格式:md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv,另外还有一个基于url的网页,大概就是这几块的内容,代码里对这几个类型都提供了处理代码,我来逐步解析。

load_and_split

在开始之前,必须了解一下文件读取的这基类BaseLoader,这里对加载、切分都有详细的预定义。这里向大家关注的点只有一个,就是load_and_split,我只把有关的部分放出来,这是一个支持在自定义好加载组件和切片组建后,一条龙使用的函数,注意这个BaseLoader是在langchain_core里的,不是在Qanything项目里的。

  1. class BaseLoader(ABC):
  2.     def load_and_split(
  3.         self, text_splitter: Optional[TextSplitter] = None
  4.     ) -> List[Document]:
  5.         """Load Documents and split into chunks. Chunks are returned as Documents.
  6.         Do not override this method. It should be considered to be deprecated!
  7.         Args:
  8.             text_splitter: TextSplitter instance to use for splitting documents.
  9.               Defaults to RecursiveCharacterTextSplitter.
  10.         Returns:
  11.             List of Documents.
  12.         """
  13.         if text_splitter is None:
  14.             try:
  15.                 from langchain_text_splitters import RecursiveCharacterTextSplitter
  16.             except ImportError as e:
  17.                 raise ImportError(
  18.                     "Unable to import from langchain_text_splitters. Please specify "
  19.                     "text_splitter or install langchain_text_splitters with "
  20.                     "`pip install -U langchain-text-splitters`."
  21.                 ) from e
  22.             _text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
  23.         else:
  24.             _text_splitter = text_splitter
  25.         docs = self.load()
  26.         return _text_splitter.split_documents(docs)

有这个基类后,只需要继承这个积累就能写自己的加载器了,至于文档切分器,则可以在load_and_split使用的时候传进去,例如这样:

  1. loader = MyRecursiveUrlLoader(url=self.url)
  2. textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
  3. docs = loader.load_and_split(text_splitter=textsplitter)

MyRecursiveUrlLoader是URL加载器(具体后面会讲),初始化以后,再定义一个中文的切分器ChineseTextSplitter(具体后面也会讲),然后直接用loader.load_and_split(text_splitter=textsplitter)即可把加载、切片都给搞定了。

下面就来分开把加载和切片两者的操作讲一遍。

文件读取

在这个基类下,根据不同需要,会有各种不一样的加载器,用于应对多种不同的格式,自定义的加载器直接从BaseLoader继承即可。

  • MyRecursiveUrlLoader,URL加载器,即网络链接下的内容加载,内部直接用了langchain的WebBaseLoader,网页解析则使用的是BeautifulSoup,算是爬虫技术里的老朋友了,BeautifulSoup主要用于解析代码里暗藏的url,方便进一步查询。

  • UnstructuredFileLoader,直接从langchain里面加载的,from langchain.document_loaders import UnstructuredFileLoader。这个也就只用在了markdown里面(.md)。

  • TextLoader,也是直接从langchain里面加载的from langchain.document_loaders import UnstructuredFileLoader, TextLoader 。这个也就只用在了txt里面(.txt)。

  • UnstructuredPaddlePDFLoader,这个是专门用在pdf文件里的,作者自己写的类,继承自前面提到的UnstructuredFileLoader,但不局限在此,主要重写的是_get_elements函数,内部写了一个函数pdf_ocr_txt,首先用fitz读取pdf每页的图片,然后用ocr_engine来解析(请求orc接口,本项目里用的是一个triton部署的paddleocr服务),最后用unstructured下的一个函数partition_text来完成切片(pip install unstructured),当然后续还会有针对中文的综合切片,后面会说。

  • UnstructuredPaddleImageLoader,用来解析图片的工具,对应jpg、png、jpeg后缀文件。同样继承自UnstructuredFileLoader,和PDF不同的是加载部分,图片加载使用的是cv2,加载后和PDF的处理一样,都是走一遍ocr_enginepartition_text

  • UnstructuredWordDocumentLoader用于处理docx文件,来自langchain。

  • xlsx使用的是pandas,值得注意的是engine使用的是openpyxl,另外文件读取后,作者会把内容转为csv,然后用CSVLoader来处理。

  • CSVLoader顾名思义处理的是csv文件,这里用的是csv.DictReader来读取的。

  • UnstructuredPowerPointLoader用于读取PPT,从langchain里面加载的,from langchain.document_loaders import UnstructuredPowerPointLoader

  • UnstructuredEmailLoader用于读取邮件格式的文件.eml,也是从langchain中加载的,from langchain.document_loaders import UnstructuredEmailLoader

至此,所有支持的文件加载都在这里了,这些文件加载都挺有借鉴意义的,后续在做自己的RAG系统的过程中,也可以考虑直接使用。

文件切片

文件切片作者也是写成了通用的工具,方便调用,而且这个相比各种文件格式,这里的泛用性会更高,毕竟都解析成文本了,这个比较通用ChineseTextSplitter,继承自langchain的from langchain.text_splitter import CharacterTextSplitter,重写后,更符合中文的使用习惯。直接来看源码吧。

  1. class ChineseTextSplitter(CharacterTextSplitter):
  2.     def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
  3.         super().__init__(**kwargs)
  4.         self.pdf = pdf
  5.         self.sentence_size = sentence_size
  6.     def split_text1(self, text: str) -> List[str]:
  7.         if self.pdf:
  8.             text = re.sub(r"\n{3,}""\n", text)
  9.             text = re.sub('\s'' ', text)
  10.             text = text.replace("\n\n""")
  11.         sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
  12.         sent_list = []
  13.         for ele in sent_sep_pattern.split(text):
  14.             if sent_sep_pattern.match(ele) and sent_list: 
  15.                 sent_list[-1] += ele
  16.             elif ele:
  17.                 sent_list.append(ele)
  18.         return sent_list
  19.     def split_text(self, text: str) -> List[str]:   ##此处需要进一步优化逻辑
  20.         if self.pdf:
  21.             text = re.sub(r"\n{3,}", r"\n", text)
  22.             text = re.sub('\s'" ", text)
  23.             text = re.sub("\n\n""", text)
  24.         text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # 单字符断句符
  25.         text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # 英文省略号
  26.         text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # 中文省略号
  27.         text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
  28.         # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
  29.         text = text.rstrip()  # 段尾如果有多余的\n就去掉它
  30.         # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
  31.         ls = [i for i in text.split("\n"if i]
  32.         for ele in ls:
  33.             if len(ele) > self.sentence_size:
  34.                 ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
  35.                 ele1_ls = ele1.split("\n")
  36.                 for ele_ele1 in ele1_ls:
  37.                     if len(ele_ele1) > self.sentence_size:
  38.                         ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
  39.                         ele2_ls = ele_ele2.split("\n")
  40.                         for ele_ele2 in ele2_ls:
  41.                             if len(ele_ele2) > self.sentence_size:
  42.                                 ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
  43.                                 ele2_id = ele2_ls.index(ele_ele2)
  44.                                 ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n"if i] + ele2_ls[
  45.                                                                                                        ele2_id + 1:]
  46.                         ele_id = ele1_ls.index(ele_ele1)
  47.                         ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
  48.                 id = ls.index(ele)
  49.                 ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
  50.         return ls

实际使用的应该是split_text,不带1那个,这里涉及了很多逻辑和替换,主要都是为了做句子片段的划分,这里的正则大家也可以多多了解和尝试。

在此基础上,都会再过第二次切分,这次切分旨在对长度太长(800tokens+)的进行进一步切分,此处使用的是langchain的RecursiveCharacterTextSplitter

  1. from langchain.text_splitter import RecursiveCharacterTextSplitter
  2. text_splitter = RecursiveCharacterTextSplitter(
  3.     separators=["\n"".""。""!""!""?""?"";"";""……""…""、"","","" "],
  4.     chunk_size=400,
  5.     length_function=num_tokens,
  6. )

后面,为了确保信息的存储的可查性(检索这段话后,能找到对应的文章),还把文件id和文件名都给记录到doc内(说白了就是正排)。

  1. # 这里给每个docs片段的metadata里注入file_id
  2. for doc in docs:
  3.     doc.metadata["file_id"] = self.file_id
  4.     doc.metadata["file_name"] = self.url if self.url else os.path.split(self.file_path)[-1]

索引构造

在对文本进行好切片后,就可以开始跑模型准备向数据库灌数据了。此处我把他叫做索引构造,主要包括数据转化和灌库两个操作。

核心的代码同样是在local_doc_qa.insert_files_to_milvus这个函数下,这里面create_embedding就是构造向量的过程,在前面的章节(前沿重器[45] RAG开源项目Qanything源码阅读1-概述+服务)有提及,向量化的模型是单独用triton部署的,所以此处是直接请求模型服务获取的。

CUDA_VISIBLE_DEVICES=$gpu_id1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed_rerank --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /workspace/qanything_local/logs/debug_logs/embed_rerank_tritonserver.log 2>&1 &

而请求方面,先放一个调用的关键入口。

  1. def create_embedding(self):
  2.     self.embs = self.emb_infer._get_len_safe_embeddings([doc.page_content for doc in self.docs])

这里实际的调用挺深的,首先对于local,有YouDaoLocalEmbeddings,这里是包装向量模型的,里面更多是考虑并发的concurrent代码,向量是内部的embedding_client(一个EmbeddingClient实例)负责的(当然EmbeddingClient下还有concurrent的代码),这个应该才是算法比较关心的部分吧,我直接把EmbeddingClient的核心代码放出来。

  1. import os
  2. import math
  3. import numpy as np
  4. import time
  5. from typing import Optional
  6. import onnxruntime as ort
  7. from tritonclient import utils as client_utils
  8. from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput
  9. from transformers import AutoTokenizer
  10. WEIGHT2NPDTYPE = {
  11.     "fp32": np.float32,
  12.     "fp16": np.float16,
  13. }
  14. class EmbeddingClient:
  15.     DEFAULT_MAX_RESP_WAIT_S = 120
  16.     embed_version = "local_v0.0.1_20230525_6d4019f1559aef84abc2ab8257e1ad4c"
  17.     def __init__(
  18.         self,
  19.         server_url: str,
  20.         model_name: str,
  21.         model_version: str,
  22.         tokenizer_path: str,
  23.         resp_wait_s: Optional[float] = None,
  24.     ):
  25.         self._server_url = server_url
  26.         self._model_name = model_name
  27.         self._model_version = model_version
  28.         self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
  29.         self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
  30.     def get_embedding(self, sentences, max_length=512):
  31.         # Setting up client
  32.     
  33.         inputs_data = self._tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors='np')
  34.         inputs_data = {k: v for k, v in inputs_data.items()}
  35.     
  36.         client = InferenceServerClient(url=self._server_url)
  37.         model_config = client.get_model_config(self._model_name, self._model_version)
  38.         model_metadata = client.get_model_metadata(self._model_name, self._model_version)
  39.     
  40.         inputs_info = {tm.name: tm for tm in model_metadata.inputs}
  41.         outputs_info = {tm.name: tm for tm in model_metadata.outputs}
  42.         output_names = list(outputs_info)
  43.         outputs_req = [InferRequestedOutput(name_) for name_ in outputs_info]
  44.         infer_inputs = []
  45.         for name_ in inputs_info:
  46.             data = inputs_data[name_]
  47.             infer_input = InferInput(name_, data.shape, inputs_info[name_].datatype)
  48.     
  49.             target_np_dtype = client_utils.triton_to_np_dtype(inputs_info[name_].datatype)
  50.             data = data.astype(target_np_dtype)
  51.     
  52.             infer_input.set_data_from_numpy(data)
  53.             infer_inputs.append(infer_input)
  54.     
  55.         results = client.infer(
  56.             model_name=self._model_name,
  57.             model_version=self._model_version,
  58.             inputs=infer_inputs,
  59.             outputs=outputs_req,
  60.             client_timeout=120,
  61.         )
  62.         y_pred = {name_: results.as_numpy(name_) for name_ in output_names}
  63.         embeddings = y_pred["output"][:,0]
  64.         norm_arr = np.linalg.norm(embeddings, axis=1, keepdims=True)
  65.         embeddings_normalized = embeddings / norm_arr
  66.         return embeddings_normalized.tolist()
  67.     
  68.     def getModelVersion(self):
  69.         return self.embed_version
  • 首先可以看到,tokenizer依旧是本服务做的。

  • 服务的请求主要是client负责,triton是一个grpc接口(GRPC我很早之前写过,可以参考系统学习),输入和输出的数据结构参考InferInputInferRequestedOutput

  • 细节,对模型的输出结果,结果作者还做了额外的处理,主要是做了一个归一化,用np.linalg.norm求了二范数(默认),然后想了都除以了这个二范数。

  • 有留意到,对模型的版本,作者有可以保留,方便进行模型迭代的版本可控性。

GRPC文章:

完成后,就可以开始灌库了,milvus_kv.insert_files。milvus自己是有开源的库的,即pymilvus,作者自己写了一个完整的类MilvusClient,至于pymilvus具体教程大家可以看:https://zhuanlan.zhihu.com/p/676124465。这里我不展开具体的使用方法了,不过还是可以从灌库的源码里挑出一些重要的细节。

  1. async def insert_files(self, file_id, file_name, file_path, docs, embs, batch_size=1000):
  2.     debug_logger.info(f'now inser_file {file_name}')
  3.     now = datetime.now()
  4.     timestamp = now.strftime("%Y%m%d%H%M")
  5.     loop = asyncio.get_running_loop()
  6.     contents = [doc.page_content for doc in docs]
  7.     num_docs = len(docs)
  8.     for batch_start in range(0, num_docs, batch_size):
  9.         batch_end = min(batch_start + batch_size, num_docs)
  10.         data = [[] for _ in range(len(self.sess.schema))]
  11.         for idx in range(batch_start, batch_end):
  12.             cont = contents[idx]
  13.             emb = embs[idx]
  14.             chunk_id = f'{file_id}_{idx}'
  15.             data[0].append(chunk_id)
  16.             data[1].append(file_id)
  17.             data[2].append(file_name)
  18.             data[3].append(file_path)
  19.             data[4].append(timestamp)
  20.             data[5].append(cont)
  21.             data[6].append(emb)
  22.         # 执行插入操作
  23.         try:
  24.             debug_logger.info('Inserting into Milvus...')
  25.             mr = await loop.run_in_executor(
  26.                 self.executor, partial(self.partitions[0].insert, data=data))
  27.             debug_logger.info(f'{file_name} {mr}')
  28.         except Exception as e:
  29.             debug_logger.error(f'Milvus insert file_id:{file_id}, file_name:{file_name} failed: {e}')
  30.             return False
  31.     # 混合检索
  32.     if self.hybrid_search:
  33.         debug_logger.info(f'now inser_file for es: {file_name}')
  34.         for batch_start in range(0, num_docs, batch_size):
  35.             batch_end = min(batch_start + batch_size, num_docs)
  36.             data_es = []
  37.             for idx in range(batch_start, batch_end):
  38.                 data_es_item = {
  39.                     'file_id': file_id,
  40.                     'content': contents[idx],
  41.                     'metadata': {
  42.                         'file_name': file_name,
  43.                         'file_path': file_path,
  44.                         'chunk_id': f'{file_id}_{idx}',
  45.                         'timestamp': timestamp,
  46.                     }
  47.                 }
  48.                 data_es.append(data_es_item)
  49.             try:
  50.                 debug_logger.info('Inserting into es ...')
  51.                 mr = await self.client.insert(data=data_es, refresh=batch_end==num_docs)
  52.                 debug_logger.info(f'{file_name} {mr}')
  53.             except Exception as e:
  54.                 debug_logger.error(f'ES insert file_id: {file_id}\nfile_name: {file_name}\nfailed: {e}')
  55.                 return False
  56.     return True
  • milvus使用的是pymilvus工具来读写,其中self.partitions[0].insert就是用存储数据的,此处可以注意到data内有很多不同的字段。

  • 执行代码使用的是loop.run_in_executor,有留意到,在MilvusClient内有一个self.executor,这个的定义在这个类的__init__内,self.executor = ThreadPoolExecutor(max_workers=10),这里新建了一个线程池,新技能get。

  • 下方是ES的数据灌入。个人感觉,这个ES数据处理写在这个位置并不是很合适,应该单独出来处理,毕竟混合代码不太好看到。

小结

本文离线的文件处理,我看了挺久,而且写的时间也很长。我自己看完的收获还挺大的,原本是对文档处理比较生疏,但这次看完对这块的理解比较深了,而且通过通篇阅读,也能了解到作者的设计思路,希望大家也能在阅读本文的过程中有所收获吧。

下一篇,在线推理,敬请期待。

84264a505a39326e635f5ed1199cda56.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/885507
推荐阅读
相关标签
  

闽ICP备14008679号