
Innovation Training Log, 2024-04-24: A First Look at RAG


1. What is RAG?

RAG is short for Retrieval Augmented Generation. It combines the capabilities of a retrieval model and a generative model to improve performance on text-generation tasks. Concretely, RAG lets a Large Language Model (LLM) draw not only on its internal knowledge when generating an answer, but also retrieve and use information from external data sources.

My own mental model: the large model is like a person, and the external data sources that RAG retrieves from are that person's books and electronic reference material. RAG is the person's ability to look things up in those materials and produce work grounded in them.

For example, someone who reads ten lines at a glance and comprehends quickly can absorb new knowledge, digest it, and turn it into output; that corresponds to a well-built, high-performing RAG setup. Someone who reads poorly and struggles with new material corresponds to a poorly built RAG setup with weak performance.

In the figure, I compare human intelligence to RAG, the human to the AI, and the external knowledge source to a vector database (which is usually used together with RAG). The better the RAG implementation, the "smarter" the system, and the stronger the AI's capability.

2. The RAG Working Pipeline

First we collect what will be inserted into the vector database, i.e. the actual documents, structured knowledge, and manuals: read the text, split it into chunks, compute vector embeddings, and insert them into the vector database.

When a user queries the large model, the query is first vectorized, then the vector store is searched for the most similar knowledge, which is injected into the prompt as background; the LLM then generates the answer.
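To make the two phases concrete, here is a minimal, self-contained sketch of the pipeline in plain Python. It is not Langchain-Chatchat code: embed() is a toy character-frequency vector standing in for a real embedding model, and a plain list of chunks stands in for a real vector database; only the order of operations matters here.

from dataclasses import dataclass
from typing import List
import math

def embed(text: str) -> List[float]:
    # Toy "embedding": a normalized character-frequency vector.
    # A real pipeline would call an embedding model instead.
    vec = [0.0] * 128
    for ch in text:
        vec[ord(ch) % 128] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]

def cosine(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

@dataclass
class Chunk:
    text: str
    vector: List[float]

# --- Ingestion: read the document, split it, embed each chunk, "insert" into the store ---
document = ("RAG combines retrieval and generation. "
            "The retriever finds the most relevant chunks. "
            "The LLM answers the question using those chunks as context.")
store = [Chunk(t.strip(), embed(t)) for t in document.split(".") if t.strip()]  # naive splitting

# --- Query: vectorize the question, retrieve top-k similar chunks, inject them into the prompt ---
query = "What does the retriever do?"
q_vec = embed(query)
top_k = sorted(store, key=lambda c: cosine(q_vec, c.vector), reverse=True)[:2]

context = "\n".join(c.text for c in top_k)
prompt = (f"Answer the question using only the context below.\n"
          f"<context>\n{context}\n</context>\n<question>{query}</question>")
print(prompt)  # this prompt would then be sent to the LLM

A real system swaps in a proper embedding model, a persistent vector store, and an LLM call on the final prompt, but the shape of the pipeline stays the same.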

3. Implementing RAG

On GitHub there is a demo of a RAG implementation as a web application: Langchain-Chatchat.

We likewise plan to build, as a web application, a vector database that can be queried for knowledge retrieval, so let's start by reading this project's code.

3.1. The Web Application Entry Point: Mounting Routes

This part actually has little to do with RAG itself; it belongs to the networking side. But since it is the entry point of the whole application, it is worth exploring.

First, the project's README reveals that the web application also has an online API reference.

In that API reference we can see the Knowledge Base endpoints, which are the part that involves the vector database.

We can globally search for these endpoints in the IDE to find where the routes are exposed.

As it turns out, they are mounted in server/api.py, so let's dig into that file. Among other things it contains calls like these:

app.post("/knowledge_base/create_knowledge_base",
         tags=["Knowledge Base Management"],
         response_model=BaseResponse,
         summary="创建知识库"
         )(create_kb)

app.post("/knowledge_base/delete_knowledge_base",
         tags=["Knowledge Base Management"],
         response_model=BaseResponse,
         summary="删除知识库"
         )(delete_kb)

app.get("/knowledge_base/list_files",
        tags=["Knowledge Base Management"],
        response_model=ListResponse,
        summary="获取知识库内的文件列表"
        )(list_files)

app.post("/knowledge_base/search_docs",
         tags=["Knowledge Base Management"],
         response_model=List[DocumentWithVSId],
         summary="搜索知识库"
         )(search_docs)

Clicking through on the argument passed to each call, i.e. a handler such as create_kb, takes us to a file named kb_api.py, which exposes that function (create_kb).

So, starting from the entry point where the routes are mounted, we have found the module that interacts with the vector database.
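One detail of the mounting style above is worth spelling out: app.post(...) returns a decorator, and calling it directly on a function is equivalent to using the @ syntax, which is why the handlers can live in other modules (kb_api.py, kb_doc_api.py) and only get mounted in api.py. A tiny self-contained FastAPI sketch (not project code) to illustrate:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class EchoResponse(BaseModel):
    msg: str

# The handler could just as well be imported from another module.
def echo(msg: str) -> EchoResponse:
    return EchoResponse(msg=msg)

# app.post(...) returns a decorator; calling it on `echo` registers the route,
# exactly as @app.post("/echo", ...) above the function definition would.
app.post("/echo", response_model=EchoResponse, summary="echo the input")(echo)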

3.2. Interacting with the Vector Database

Now let's look at the functions that interact with the vector database.

Reading the knowledge-base architecture through its interaction functions

First, note this part of create_kb:

kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
try:
    kb.create_kb()

The name alone gives it away: this is the factory method design pattern. You do not get a handle on the knowledge base directly; instead you first obtain a knowledge-base service from the factory that provides them.

The get_service function looks like this:

@staticmethod
def get_service(kb_name: str,
                vector_store_type: Union[str, SupportedVSType],
                embed_model: str = EMBEDDING_MODEL,
                ) -> KBService:
    if isinstance(vector_store_type, str):
        vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
    if SupportedVSType.FAISS == vector_store_type:
        from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
        return FaissKBService(kb_name, embed_model=embed_model)
    elif SupportedVSType.PG == vector_store_type:
        from server.knowledge_base.kb_service.pg_kb_service import PGKBService
        return PGKBService(kb_name, embed_model=embed_model)
    elif SupportedVSType.MILVUS == vector_store_type:
        from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
        return MilvusKBService(kb_name, embed_model=embed_model)
    elif SupportedVSType.ZILLIZ == vector_store_type:
        from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
        return ZillizKBService(kb_name, embed_model=embed_model)
    elif SupportedVSType.DEFAULT == vector_store_type:
        from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
        return MilvusKBService(kb_name,
                               embed_model=embed_model)  # other milvus parameters are set in model_config.kbs_config
    elif SupportedVSType.ES == vector_store_type:
        from server.knowledge_base.kb_service.es_kb_service import ESKBService
        return ESKBService(kb_name, embed_model=embed_model)
    elif SupportedVSType.CHROMADB == vector_store_type:
        from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
        return ChromaKBService(kb_name, embed_model=embed_model)
    elif SupportedVSType.DEFAULT == vector_store_type:  # kb_exists of default kbservice is False, to make validation easier.
        from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
        return DefaultKBService(kb_name)

So what is this doing? Based on the requested vector store type, it decides which vector database the new service will be backed by: it might be Chroma, it might be Faiss, and so on.

In short, it returns an instance of a KBService subclass. KBService itself is not instantiable, because it is an abstract class.

The class definition can be found under server/knowledge_base/kb_service.

@abstractmethod
def do_create_kb(self):
    """
    创建知识库子类实自己逻辑
    """
    pass

The @abstractmethod decorator in the class definition tells us this is an abstract class.

So where are its implementations? After some digging: server/knowledge_base/kb_service contains a large number of implementation classes, one per supported database.

While reading the code I noticed that the project's default vector database is Faiss, so let's look at faiss_kb_service.

class FaissKBService(KBService):
    vs_path: str
    kb_path: str
    vector_name: str = None

The inheritance from KBService is right there in the class definition.

Back to the spot where the KBService is created through KBServiceFactory:

kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
try:
    kb.create_kb()

Tracing create_kb, we find:

def create_kb(self):
    """
    创建知识库
    """
    if not os.path.exists(self.doc_path):
        os.makedirs(self.doc_path)
    self.do_create_kb()
    status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
    return status

As you can see, create_kb calls do_create_kb() on self (the instance itself). That is exactly the abstract method mentioned above, so the logic that actually runs depends on how each subclass overrides it.

def do_create_kb(self):
    if not os.path.exists(self.vs_path):
        os.makedirs(self.vs_path)
    self.load_vector_store()

def load_vector_store(self) -> ThreadSafeFaiss:
    return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                           vector_name=self.vector_name,
                                           embed_model=self.embed_model)

Faiss, for example, has its own way of creating its database.
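Put differently, the KBService hierarchy is a textbook template-method pattern: the base class fixes the overall create_kb flow and defers the store-specific step to do_create_kb. A stripped-down sketch of that relationship (class names follow the project, the method bodies are simplified placeholders):

from abc import ABC, abstractmethod

class KBService(ABC):
    def __init__(self, kb_name: str):
        self.kb_name = kb_name

    def create_kb(self):                  # template method: fixed skeleton shared by all backends
        print(f"registering {self.kb_name} in the relational DB ...")
        self.do_create_kb()               # the varying step, supplied by each subclass

    @abstractmethod
    def do_create_kb(self):
        ...

class FaissKBService(KBService):
    def do_create_kb(self):               # Faiss-specific creation logic (placeholder)
        print(f"creating a local FAISS index folder for {self.kb_name}")

class MilvusKBService(KBService):
    def do_create_kb(self):               # Milvus-specific creation logic (placeholder)
        print(f"creating a Milvus collection for {self.kb_name}")

FaissKBService("samples").create_kb()
MilvusKBService("samples").create_kb()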

The design is therefore clear: a four-layer architecture of web layer, static factory, abstract class, and concrete implementation classes, as shown in the figure below:

Mapping from Abstract Working Pipeline to Code 

We now know how to obtain a vector database service. But where is it used, and how? As the RAG working pipeline above describes, when a user asks the model to perform a task, we first retrieve similar knowledge from the vector database to enrich the prompt, and only then ask the question. How does that flow map onto the code, and how do we use the retrieval capability the vector database provides?

Finding the entry point of the RAG flow

To find it, I again went through server/api.py, which includes:

app.post("/chat/chat",
         tags=["Chat"],
         summary="与llm模型对话(通过LLMChain)",
         )(chat)

app.post("/chat/search_engine_chat",
         tags=["Chat"],
         summary="与搜索引擎对话",
         )(search_engine_chat)

app.post("/chat/feedback",
         tags=["Chat"],
         summary="返回llm模型对话评分",
         )(chat_feedback)

app.post("/chat/knowledge_base_chat",
         tags=["Chat"],
         summary="与知识库对话")(knowledge_base_chat)

app.post("/chat/file_chat",
         tags=["Knowledge Base Management"],
         summary="文件对话"
         )(file_chat)

app.post("/chat/agent_chat",
         tags=["Chat"],
         summary="与agent对话")(agent_chat)

At first I assumed the /chat/chat endpoint contained the RAG flow, but after reading the code I realized it never queries the vector database.

After some more digging I found the /chat/knowledge_base_chat endpoint:

async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(
                                  SCORE_THRESHOLD,
                                  description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
                                  ge=0,
                                  le=2
                              ),
                              history: List[History] = Body(
                                  [],
                                  description="历史对话",
                                  examples=[[
                                      {"role": "user",
                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                      {"role": "assistant",
                                       "content": "虎头虎脑"}]]
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: Optional[int] = Body(
                                  None,
                                  description="限制LLM生成Token数量,默认None代表模型最大值"
                              ),
                              prompt_name: str = Body(
                                  "default",
                                  description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                              ),
                              request: Request = None,
                              ):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        docs = await run_in_threadpool(search_docs,
                                       query=query,
                                       knowledge_base_name=knowledge_base_name,
                                       top_k=top_k,
                                       score_threshold=score_threshold)

        # 加入reranker
        if USE_RERANKER:
            reranker_model_path = get_model_path(RERANKER_MODEL)
            reranker_model = LangchainReranker(top_n=top_k,
                                               device=embedding_device(),
                                               max_length=RERANKER_MAX_LENGTH,
                                               model_name_or_path=reranker_model_path
                                               )
            print("-------------before rerank-----------------")
            print(docs)
            docs = reranker_model.compress_documents(documents=docs,
                                                     query=query)
            print("------------after rerank------------------")
            print(docs)
        context = "\n".join([doc.page_content for doc in docs])

        if len(docs) == 0:  # 如果没有找到相关文档,使用empty模板
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        source_documents = []
        for inum, doc in enumerate(docs):
            filename = doc.metadata.get("source")
            parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
            base_url = request.base_url
            url = f"{base_url}knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)

        if len(source_documents) == 0:  # 没有找到相关文档
            source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"answer": token}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps({"answer": answer,
                              "docs": source_documents},
                             ensure_ascii=False)

        await task

    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, model_name, prompt_name))

The function signature is very long, with a pile of parameters, but the ones that really matter boil down to query, the user's question; most of the rest are parameters needed to call the langchain libraries or to talk to the vector database. The number of related vectors to fetch, top_k, is part of the RAG technique itself and is also required.
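Seen from the client side, most of those parameters can simply be left at their defaults. Below is a hedged example of calling this endpoint with the requests library, assuming the API server is listening on the project's default localhost:7861 and that a knowledge base named samples exists; adjust these to your own setup:

import json
import requests

payload = {
    "query": "什么是RAG?",
    "knowledge_base_name": "samples",
    "top_k": 3,                # how many chunks to retrieve
    "score_threshold": 0.5,    # smaller score = more relevant; 1 effectively disables filtering
    "history": [],
    "stream": False,
}
# The endpoint replies with server-sent events even when stream is False,
# so the answer still arrives as "data: ..." lines that we parse one by one.
resp = requests.post("http://127.0.0.1:7861/chat/knowledge_base_chat", json=payload, stream=True)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data:"):
        try:
            print(json.loads(line[len("data:"):]))
        except json.JSONDecodeError:
            pass  # skip keep-alive/ping events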

Reading the source

First, it obtains the knowledge-base service (the knowledge base may of course not exist):

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
    return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

Then it selects the LLM model instance:

model = get_ChatOpenAI(
    model_name=model_name,
    temperature=temperature,
    max_tokens=max_tokens,
    callbacks=[callback],
)

Then it retrieves the related documents (the top k of them) from the corresponding vector database:

docs = await run_in_threadpool(search_docs,
                               query=query,
                               knowledge_base_name=knowledge_base_name,
                               top_k=top_k,
                               score_threshold=score_threshold)

The search_docs in this async call is exposed from server/knowledge_base/kb_doc_api.py:

def search_docs(
        query: str = Body("", description="用户输入", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD,
                                      description="知识库匹配相关度阈值,取值范围在0-1之间,"
                                                  "SCORE越小,相关度越高,"
                                                  "取到1相当于不筛选,建议设置在0.5左右",
                                      ge=0, le=1),
        file_name: str = Body("", description="文件名称,支持 sql 通配符"),
        metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[DocumentWithVSId]:
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    data = []
    if kb is not None:
        if query:
            docs = kb.search_docs(query, top_k, score_threshold)
            data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
        elif file_name or metadata:
            data = kb.list_docs(file_name=file_name, metadata=metadata)
            for d in data:
                if "vector" in d.metadata:
                    del d.metadata["vector"]
    return data

Again it first obtains the database service, then calls the service's search_docs function (obviously each vector database has its own concrete implementation), and returns the top_k results whose similarity passes the threshold.
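The threshold semantics deserve a note: as the parameter description says, the score behaves like a distance, so smaller means more relevant, and only hits at or below score_threshold are kept before the top_k cut. A small sketch of that filtering logic on hypothetical data (this is not the project's actual do_search implementation):

from typing import List, Tuple

def filter_hits(hits: List[Tuple[str, float]], score_threshold: float, top_k: int) -> List[Tuple[str, float]]:
    # Keep only hits whose score (a distance: smaller = more relevant) passes the threshold,
    # then return at most top_k of them, best first.
    kept = [h for h in hits if h[1] <= score_threshold]
    return sorted(kept, key=lambda h: h[1])[:top_k]

# Hypothetical (chunk, score) pairs as a vector store might return them.
hits = [("chunk about RAG", 0.31), ("unrelated chunk", 0.92), ("chunk about FAISS", 0.44)]
print(filter_hits(hits, score_threshold=0.5, top_k=3))
# -> [('chunk about RAG', 0.31), ('chunk about FAISS', 0.44)]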

if len(docs) == 0:  # 如果没有找到相关文档,使用empty模板
    prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
    [i.to_msg_template() for i in history] + [input_msg])

chain = LLMChain(prompt=chat_prompt, llm=model)

Next, the prompt template is selected, and the prompt for the current turn is built from the conversation history.

After that, LangChain's LLMChain provides the middleware that will actually carry out the user's task.

# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
    chain.acall({"context": context, "question": query}),
    callback.done),
)

A background async task is then started, with the documents retrieved from the vector database as the knowledge context and the user's input as the question.

source_documents = []
for inum, doc in enumerate(docs):
    filename = doc.metadata.get("source")
    parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
    base_url = request.base_url
    url = f"{base_url}knowledge_base/download_doc?" + parameters
    text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
    source_documents.append(text)

if len(source_documents) == 0:  # 没有找到相关文档
    source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")

LLM products usually list the references they consulted when answering (Kimi, for example); this part simply assembles those reference strings.

return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, model_name, prompt_name))

Finally, the model's answer is returned.

This flow is exactly how RAG's working pipeline maps onto the code.

Embedding knowledge into the knowledge base

This part is comparatively straightforward. server/api.py contains:

app.post("/knowledge_base/upload_docs",
         tags=["Knowledge Base Management"],
         response_model=BaseResponse,
         summary="上传文件到知识库,并/或进行向量化"
         )(upload_docs)

The corresponding upload_docs lives in server/knowledge_base/kb_doc_api.py.

def upload_docs(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
        override: bool = Form(False, description="覆盖已有文件"),
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
    API接口:上传文件,并/或向量化
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    file_names = list(docs.keys())

    # 先将上传的文件保存到磁盘
    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
        filename = result["data"]["file_name"]
        if result["code"] != 200:
            failed_files[filename] = result["msg"]

        if filename not in file_names:
            file_names.append(filename)

    # 对保存的文件进行向量化
    if to_vector_store:
        result = update_docs(
            knowledge_base_name=knowledge_base_name,
            file_names=file_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failed_files.update(result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})

The most important piece here is the save_vector_store function, although that belongs to each database's own implementation.

Let's look at the Faiss one:

def load_vector_store(self) -> ThreadSafeFaiss:
    return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                           vector_name=self.vector_name,
                                           embed_model=self.embed_model)

def load_vector_store(
        self,
        kb_name: str,
        vector_name: str = None,
        create: bool = True,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
    self.atomic.acquire()
    vector_name = vector_name or embed_model
    cache = self.get((kb_name, vector_name))  # 用元组比拼接字符串好一些
    if cache is None:
        item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
        self.set((kb_name, vector_name), item)
        with item.acquire(msg="初始化"):
            self.atomic.release()
            logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
            vs_path = get_vs_path(kb_name, vector_name)

            if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
            elif create:
                # create an empty vector store
                if not os.path.exists(vs_path):
                    os.makedirs(vs_path)
                vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                vector_store.save_local(vs_path)
            else:
                raise RuntimeError(f"knowledge base {kb_name} not exist.")
            item.obj = vector_store
            item.finish_loading()
    else:
        self.atomic.release()
    return self.get((kb_name, vector_name))

This module is really a cache: every lookup first checks whether a handle to this vector store already exists. If it does, it is returned directly; if not, it is loaded once. The loading itself is concentrated in:

def get(self, key: str) -> ThreadSafeObject:
    if cache := self._cache.get(key):
        cache.wait_for_loading()
        return cache

So what does it return? A handle to the corresponding vector store, defined as follows:

class ThreadSafeFaiss(ThreadSafeObject):
    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        with self.acquire():
            if not os.path.isdir(path) and create_path:
                os.makedirs(path)
            ret = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return ret

    def clear(self):
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
            assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return ret

Essentially, it is an object that holds the vectorized documents.
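For reference, the save_local/load_local calls seen above come from LangChain's FAISS wrapper. A minimal sketch of using that wrapper directly, outside the project, assuming langchain-community and a locally available HuggingFace embedding model are installed (the model name and paths are examples only):

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-zh")  # example model

docs = [
    Document(page_content="RAG injects retrieved chunks into the prompt.", metadata={"source": "tmp.txt"}),
    Document(page_content="FAISS stores the chunk embeddings on disk.", metadata={"source": "tmp.txt"}),
]

vector_store = FAISS.from_documents(docs, embeddings)   # embed and index the documents
vector_store.save_local("vector_store/samples")         # the call ThreadSafeFaiss.save() wraps

# Recent langchain-community versions require allow_dangerous_deserialization when loading;
# older versions do not accept the argument.
reloaded = FAISS.load_local("vector_store/samples", embeddings, allow_dangerous_deserialization=True)
print(reloaded.similarity_search_with_score("what does RAG do?", k=1))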

4. Trying the Application Out

The README already explains how to use it, but I want to add a few notes here.

First, you can skip downloading the LLM (if you are not going to use that service), but the embedding model must be downloaded. If git clone from Hugging Face does not work for you, downloading the files manually from the site is fine too.
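If git clone keeps failing, the huggingface_hub library can download the model too. A sketch, assuming the embedding model is BAAI/bge-large-zh; check EMBEDDING_MODEL and MODEL_PATH in configs/model_config.py for the exact name and the local path the project expects:

from huggingface_hub import snapshot_download

# If huggingface.co is unreachable, setting the HF_ENDPOINT environment variable
# to a mirror before running this may help.
snapshot_download(
    repo_id="BAAI/bge-large-zh",       # example; match your EMBEDDING_MODEL setting
    local_dir="models/bge-large-zh",   # example path; point MODEL_PATH in model_config.py at it
)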

Second, if your machine cannot be set up with a CUDA environment, you will not be able to load and run the LLM. You can simply drop the LLM service, though, since the vector knowledge-base service remains usable.

Just comment out the part of the startup script that loads and runs the LLM (the complete startup script follows):

  1. import asyncio
  2. import multiprocessing as mp
  3. import os
  4. import subprocess
  5. import sys
  6. from multiprocessing import Process
  7. from datetime import datetime
  8. from pprint import pprint
  9. from langchain_core._api import deprecated
  10. try:
  11. import numexpr
  12. n_cores = numexpr.utils.detect_number_of_cores()
  13. os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
  14. except:
  15. pass
  16. sys.path.append(os.path.dirname(os.path.dirname(__file__)))
  17. from configs import (
  18. LOG_PATH,
  19. log_verbose,
  20. logger,
  21. LLM_MODELS,
  22. EMBEDDING_MODEL,
  23. TEXT_SPLITTER_NAME,
  24. FSCHAT_CONTROLLER,
  25. FSCHAT_OPENAI_API,
  26. FSCHAT_MODEL_WORKERS,
  27. API_SERVER,
  28. WEBUI_SERVER,
  29. HTTPX_DEFAULT_TIMEOUT,
  30. )
  31. from server.utils import (fschat_controller_address, fschat_model_worker_address,
  32. fschat_openai_api_address, get_httpx_client, get_model_worker_config,
  33. MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
  34. from server.knowledge_base.migrate import create_tables
  35. import argparse
  36. from typing import List, Dict
  37. from configs import VERSION
  38. @deprecated(
  39. since="0.3.0",
  40. message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
  41. removal="0.3.0")
  42. def create_controller_app(
  43. dispatch_method: str,
  44. log_level: str = "INFO",
  45. ) -> FastAPI:
  46. import fastchat.constants
  47. fastchat.constants.LOGDIR = LOG_PATH
  48. from fastchat.serve.controller import app, Controller, logger
  49. logger.setLevel(log_level)
  50. controller = Controller(dispatch_method)
  51. sys.modules["fastchat.serve.controller"].controller = controller
  52. MakeFastAPIOffline(app)
  53. app.title = "FastChat Controller"
  54. app._controller = controller
  55. return app
  56. def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
  57. """
  58. kwargs包含的字段如下:
  59. host:
  60. port:
  61. model_names:[`model_name`]
  62. controller_address:
  63. worker_address:
  64. 对于Langchain支持的模型:
  65. langchain_model:True
  66. 不会使用fschat
  67. 对于online_api:
  68. online_api:True
  69. worker_class: `provider`
  70. 对于离线模型:
  71. model_path: `model_name_or_path`,huggingface的repo-id或本地路径
  72. device:`LLM_DEVICE`
  73. """
  74. import fastchat.constants
  75. fastchat.constants.LOGDIR = LOG_PATH
  76. import argparse
  77. parser = argparse.ArgumentParser()
  78. args = parser.parse_args([])
  79. for k, v in kwargs.items():
  80. setattr(args, k, v)
  81. if worker_class := kwargs.get("langchain_model"): # Langchian支持的模型不用做操作
  82. from fastchat.serve.base_model_worker import app
  83. worker = ""
  84. # 在线模型API
  85. elif worker_class := kwargs.get("worker_class"):
  86. from fastchat.serve.base_model_worker import app
  87. worker = worker_class(model_names=args.model_names,
  88. controller_addr=args.controller_address,
  89. worker_addr=args.worker_address)
  90. # sys.modules["fastchat.serve.base_model_worker"].worker = worker
  91. sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
  92. # 本地模型
  93. else:
  94. from configs.model_config import VLLM_MODEL_DICT
  95. if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
  96. import fastchat.serve.vllm_worker
  97. from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
  98. from vllm import AsyncLLMEngine
  99. from vllm.engine.arg_utils import AsyncEngineArgs
  100. args.tokenizer = args.model_path
  101. args.tokenizer_mode = 'auto'
  102. args.trust_remote_code = True
  103. args.download_dir = None
  104. args.load_format = 'auto'
  105. args.dtype = 'auto'
  106. args.seed = 0
  107. args.worker_use_ray = False
  108. args.pipeline_parallel_size = 1
  109. args.tensor_parallel_size = 1
  110. args.block_size = 16
  111. args.swap_space = 4 # GiB
  112. args.gpu_memory_utilization = 0.90
  113. args.max_num_batched_tokens = None # 一个批次中的最大令牌(tokens)数量,这个取决于你的显卡和大模型设置,设置太大显存会不够
  114. args.max_num_seqs = 256
  115. args.disable_log_stats = False
  116. args.conv_template = None
  117. args.limit_worker_concurrency = 5
  118. args.no_register = False
  119. args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
  120. args.engine_use_ray = False
  121. args.disable_log_requests = False
  122. # 0.2.1 vllm后要加的参数, 但是这里不需要
  123. args.max_model_len = None
  124. args.revision = None
  125. args.quantization = None
  126. args.max_log_len = None
  127. args.tokenizer_revision = None
  128. # 0.2.2 vllm需要新加的参数
  129. args.max_paddings = 256
  130. if args.model_path:
  131. args.model = args.model_path
  132. if args.num_gpus > 1:
  133. args.tensor_parallel_size = args.num_gpus
  134. for k, v in kwargs.items():
  135. setattr(args, k, v)
  136. engine_args = AsyncEngineArgs.from_cli_args(args)
  137. engine = AsyncLLMEngine.from_engine_args(engine_args)
  138. worker = VLLMWorker(
  139. controller_addr=args.controller_address,
  140. worker_addr=args.worker_address,
  141. worker_id=worker_id,
  142. model_path=args.model_path,
  143. model_names=args.model_names,
  144. limit_worker_concurrency=args.limit_worker_concurrency,
  145. no_register=args.no_register,
  146. llm_engine=engine,
  147. conv_template=args.conv_template,
  148. )
  149. sys.modules["fastchat.serve.vllm_worker"].engine = engine
  150. sys.modules["fastchat.serve.vllm_worker"].worker = worker
  151. sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
  152. else:
  153. from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
  154. args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
  155. args.max_gpu_memory = "22GiB"
  156. args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
  157. args.load_8bit = False
  158. args.cpu_offloading = None
  159. args.gptq_ckpt = None
  160. args.gptq_wbits = 16
  161. args.gptq_groupsize = -1
  162. args.gptq_act_order = False
  163. args.awq_ckpt = None
  164. args.awq_wbits = 16
  165. args.awq_groupsize = -1
  166. args.model_names = [""]
  167. args.conv_template = None
  168. args.limit_worker_concurrency = 5
  169. args.stream_interval = 2
  170. args.no_register = False
  171. args.embed_in_truncate = False
  172. for k, v in kwargs.items():
  173. setattr(args, k, v)
  174. if args.gpus:
  175. if args.num_gpus is None:
  176. args.num_gpus = len(args.gpus.split(','))
  177. if len(args.gpus.split(",")) < args.num_gpus:
  178. raise ValueError(
  179. f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
  180. )
  181. os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
  182. gptq_config = GptqConfig(
  183. ckpt=args.gptq_ckpt or args.model_path,
  184. wbits=args.gptq_wbits,
  185. groupsize=args.gptq_groupsize,
  186. act_order=args.gptq_act_order,
  187. )
  188. awq_config = AWQConfig(
  189. ckpt=args.awq_ckpt or args.model_path,
  190. wbits=args.awq_wbits,
  191. groupsize=args.awq_groupsize,
  192. )
  193. worker = ModelWorker(
  194. controller_addr=args.controller_address,
  195. worker_addr=args.worker_address,
  196. worker_id=worker_id,
  197. model_path=args.model_path,
  198. model_names=args.model_names,
  199. limit_worker_concurrency=args.limit_worker_concurrency,
  200. no_register=args.no_register,
  201. device=args.device,
  202. num_gpus=args.num_gpus,
  203. max_gpu_memory=args.max_gpu_memory,
  204. load_8bit=args.load_8bit,
  205. cpu_offloading=args.cpu_offloading,
  206. gptq_config=gptq_config,
  207. awq_config=awq_config,
  208. stream_interval=args.stream_interval,
  209. conv_template=args.conv_template,
  210. embed_in_truncate=args.embed_in_truncate,
  211. )
  212. sys.modules["fastchat.serve.model_worker"].args = args
  213. sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
  214. # sys.modules["fastchat.serve.model_worker"].worker = worker
  215. sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
  216. MakeFastAPIOffline(app)
  217. app.title = f"FastChat LLM Server ({args.model_names[0]})"
  218. app._worker = worker
  219. return app
  220. def create_openai_api_app(
  221. controller_address: str,
  222. api_keys: List = [],
  223. log_level: str = "INFO",
  224. ) -> FastAPI:
  225. import fastchat.constants
  226. fastchat.constants.LOGDIR = LOG_PATH
  227. from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
  228. from fastchat.utils import build_logger
  229. logger = build_logger("openai_api", "openai_api.log")
  230. logger.setLevel(log_level)
  231. app.add_middleware(
  232. CORSMiddleware,
  233. allow_credentials=True,
  234. allow_origins=["*"],
  235. allow_methods=["*"],
  236. allow_headers=["*"],
  237. )
  238. sys.modules["fastchat.serve.openai_api_server"].logger = logger
  239. app_settings.controller_address = controller_address
  240. app_settings.api_keys = api_keys
  241. MakeFastAPIOffline(app)
  242. app.title = "FastChat OpeanAI API Server"
  243. return app
  244. def _set_app_event(app: FastAPI, started_event: mp.Event = None):
  245. @app.on_event("startup")
  246. async def on_startup():
  247. if started_event is not None:
  248. started_event.set()
  249. def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
  250. import uvicorn
  251. import httpx
  252. from fastapi import Body
  253. import time
  254. import sys
  255. from server.utils import set_httpx_config
  256. set_httpx_config()
  257. app = create_controller_app(
  258. dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
  259. log_level=log_level,
  260. )
  261. _set_app_event(app, started_event)
  262. # add interface to release and load model worker
  263. @app.post("/release_worker")
  264. def release_worker(
  265. model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
  266. # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
  267. new_model_name: str = Body(None, description="释放后加载该模型"),
  268. keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
  269. ) -> Dict:
  270. available_models = app._controller.list_models()
  271. if new_model_name in available_models:
  272. msg = f"要切换的LLM模型 {new_model_name} 已经存在"
  273. logger.info(msg)
  274. return {"code": 500, "msg": msg}
  275. if new_model_name:
  276. logger.info(f"开始切换LLM模型:从 {model_name}{new_model_name}")
  277. else:
  278. logger.info(f"即将停止LLM模型: {model_name}")
  279. if model_name not in available_models:
  280. msg = f"the model {model_name} is not available"
  281. logger.error(msg)
  282. return {"code": 500, "msg": msg}
  283. worker_address = app._controller.get_worker_address(model_name)
  284. if not worker_address:
  285. msg = f"can not find model_worker address for {model_name}"
  286. logger.error(msg)
  287. return {"code": 500, "msg": msg}
  288. with get_httpx_client() as client:
  289. r = client.post(worker_address + "/release",
  290. json={"new_model_name": new_model_name, "keep_origin": keep_origin})
  291. if r.status_code != 200:
  292. msg = f"failed to release model: {model_name}"
  293. logger.error(msg)
  294. return {"code": 500, "msg": msg}
  295. if new_model_name:
  296. timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register
  297. while timer > 0:
  298. models = app._controller.list_models()
  299. if new_model_name in models:
  300. break
  301. time.sleep(1)
  302. timer -= 1
  303. if timer > 0:
  304. msg = f"sucess change model from {model_name} to {new_model_name}"
  305. logger.info(msg)
  306. return {"code": 200, "msg": msg}
  307. else:
  308. msg = f"failed change model from {model_name} to {new_model_name}"
  309. logger.error(msg)
  310. return {"code": 500, "msg": msg}
  311. else:
  312. msg = f"sucess to release model: {model_name}"
  313. logger.info(msg)
  314. return {"code": 200, "msg": msg}
  315. host = FSCHAT_CONTROLLER["host"]
  316. port = FSCHAT_CONTROLLER["port"]
  317. if log_level == "ERROR":
  318. sys.stdout = sys.__stdout__
  319. sys.stderr = sys.__stderr__
  320. uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
  321. def run_model_worker(
  322. model_name: str = LLM_MODELS[0],
  323. controller_address: str = "",
  324. log_level: str = "INFO",
  325. q: mp.Queue = None,
  326. started_event: mp.Event = None,
  327. ):
  328. import uvicorn
  329. from fastapi import Body
  330. import sys
  331. from server.utils import set_httpx_config
  332. set_httpx_config()
  333. kwargs = get_model_worker_config(model_name)
  334. host = kwargs.pop("host")
  335. port = kwargs.pop("port")
  336. kwargs["model_names"] = [model_name]
  337. kwargs["controller_address"] = controller_address or fschat_controller_address()
  338. kwargs["worker_address"] = fschat_model_worker_address(model_name)
  339. model_path = kwargs.get("model_path", "")
  340. kwargs["model_path"] = model_path
  341. app = create_model_worker_app(log_level=log_level, **kwargs)
  342. _set_app_event(app, started_event)
  343. if log_level == "ERROR":
  344. sys.stdout = sys.__stdout__
  345. sys.stderr = sys.__stderr__
  346. # add interface to release and load model
  347. @app.post("/release")
  348. def release_model(
  349. new_model_name: str = Body(None, description="释放后加载该模型"),
  350. keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
  351. ) -> Dict:
  352. if keep_origin:
  353. if new_model_name:
  354. q.put([model_name, "start", new_model_name])
  355. else:
  356. if new_model_name:
  357. q.put([model_name, "replace", new_model_name])
  358. else:
  359. q.put([model_name, "stop", None])
  360. return {"code": 200, "msg": "done"}
  361. uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
  362. def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
  363. import uvicorn
  364. import sys
  365. from server.utils import set_httpx_config
  366. set_httpx_config()
  367. controller_addr = fschat_controller_address()
  368. app = create_openai_api_app(controller_addr, log_level=log_level)
  369. _set_app_event(app, started_event)
  370. host = FSCHAT_OPENAI_API["host"]
  371. port = FSCHAT_OPENAI_API["port"]
  372. if log_level == "ERROR":
  373. sys.stdout = sys.__stdout__
  374. sys.stderr = sys.__stderr__
  375. uvicorn.run(app, host=host, port=port)
  376. def run_api_server(started_event: mp.Event = None, run_mode: str = None):
  377. from server.api import create_app
  378. import uvicorn
  379. from server.utils import set_httpx_config
  380. set_httpx_config()
  381. app = create_app(run_mode=run_mode)
  382. _set_app_event(app, started_event)
  383. host = API_SERVER["host"]
  384. port = API_SERVER["port"]
  385. uvicorn.run(app, host=host, port=port)
  386. def run_webui(started_event: mp.Event = None, run_mode: str = None):
  387. from server.utils import set_httpx_config
  388. set_httpx_config()
  389. host = WEBUI_SERVER["host"]
  390. port = WEBUI_SERVER["port"]
  391. cmd = ["streamlit", "run", "webui.py",
  392. "--server.address", host,
  393. "--server.port", str(port),
  394. "--theme.base", "light",
  395. "--theme.primaryColor", "#165dff",
  396. "--theme.secondaryBackgroundColor", "#f5f5f5",
  397. "--theme.textColor", "#000000",
  398. ]
  399. if run_mode == "lite":
  400. cmd += [
  401. "--",
  402. "lite",
  403. ]
  404. p = subprocess.Popen(cmd)
  405. started_event.set()
  406. p.wait()
  407. def parse_args() -> argparse.ArgumentParser:
  408. parser = argparse.ArgumentParser()
  409. parser.add_argument(
  410. "-a",
  411. "--all-webui",
  412. action="store_true",
  413. help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
  414. dest="all_webui",
  415. )
  416. parser.add_argument(
  417. "--all-api",
  418. action="store_true",
  419. help="run fastchat's controller/openai_api/model_worker servers, run api.py",
  420. dest="all_api",
  421. )
  422. parser.add_argument(
  423. "--llm-api",
  424. action="store_true",
  425. help="run fastchat's controller/openai_api/model_worker servers",
  426. dest="llm_api",
  427. )
  428. parser.add_argument(
  429. "-o",
  430. "--openai-api",
  431. action="store_true",
  432. help="run fastchat's controller/openai_api servers",
  433. dest="openai_api",
  434. )
  435. parser.add_argument(
  436. "-m",
  437. "--model-worker",
  438. action="store_true",
  439. help="run fastchat's model_worker server with specified model name. "
  440. "specify --model-name if not using default LLM_MODELS",
  441. dest="model_worker",
  442. )
  443. parser.add_argument(
  444. "-n",
  445. "--model-name",
  446. type=str,
  447. nargs="+",
  448. default=LLM_MODELS,
  449. help="specify model name for model worker. "
  450. "add addition names with space seperated to start multiple model workers.",
  451. dest="model_name",
  452. )
  453. parser.add_argument(
  454. "-c",
  455. "--controller",
  456. type=str,
  457. help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
  458. dest="controller_address",
  459. )
  460. parser.add_argument(
  461. "--api",
  462. action="store_true",
  463. help="run api.py server",
  464. dest="api",
  465. )
  466. parser.add_argument(
  467. "-p",
  468. "--api-worker",
  469. action="store_true",
  470. help="run online model api such as zhipuai",
  471. dest="api_worker",
  472. )
  473. parser.add_argument(
  474. "-w",
  475. "--webui",
  476. action="store_true",
  477. help="run webui.py server",
  478. dest="webui",
  479. )
  480. parser.add_argument(
  481. "-q",
  482. "--quiet",
  483. action="store_true",
  484. help="减少fastchat服务log信息",
  485. dest="quiet",
  486. )
  487. parser.add_argument(
  488. "-i",
  489. "--lite",
  490. action="store_true",
  491. help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
  492. dest="lite",
  493. )
  494. args = parser.parse_args()
  495. return args, parser
  496. def dump_server_info(after_start=False, args=None):
  497. import platform
  498. import langchain
  499. import fastchat
  500. from server.utils import api_address, webui_address
  501. print("\n")
  502. print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
  503. print(f"操作系统:{platform.platform()}.")
  504. print(f"python版本:{sys.version}")
  505. print(f"项目版本:{VERSION}")
  506. print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
  507. print("\n")
  508. models = LLM_MODELS
  509. if args and args.model_name:
  510. models = args.model_name
  511. print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
  512. print(f"当前启动的LLM模型:{models} @ {llm_device()}")
  513. for model in models:
  514. pprint(get_model_worker_config(model))
  515. print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
  516. if after_start:
  517. print("\n")
  518. print(f"服务端运行信息:")
  519. if args.openai_api:
  520. print(f" OpenAI API Server: {fschat_openai_api_address()}")
  521. if args.api:
  522. print(f" Chatchat API Server: {api_address()}")
  523. if args.webui:
  524. print(f" Chatchat WEBUI Server: {webui_address()}")
  525. print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
  526. print("\n")
  527. async def start_main_server():
  528. import time
  529. import signal
  530. def handler(signalname):
  531. """
  532. Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
  533. Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
  534. """
  535. def f(signal_received, frame):
  536. raise KeyboardInterrupt(f"{signalname} received")
  537. return f
  538. # This will be inherited by the child process if it is forked (not spawned)
  539. signal.signal(signal.SIGINT, handler("SIGINT"))
  540. signal.signal(signal.SIGTERM, handler("SIGTERM"))
  541. mp.set_start_method("spawn")
  542. manager = mp.Manager()
  543. run_mode = None
  544. queue = manager.Queue()
  545. args, parser = parse_args()
  546. if args.all_webui:
  547. args.openai_api = True
  548. args.model_worker = True
  549. args.api = True
  550. args.api_worker = True
  551. args.webui = True
  552. elif args.all_api:
  553. args.openai_api = True
  554. args.model_worker = True
  555. args.api = True
  556. args.api_worker = True
  557. args.webui = False
  558. elif args.llm_api:
  559. args.openai_api = True
  560. args.model_worker = True
  561. args.api_worker = True
  562. args.api = False
  563. args.webui = False
  564. if args.lite:
  565. args.model_worker = False
  566. run_mode = "lite"
  567. dump_server_info(args=args)
  568. if len(sys.argv) > 1:
  569. logger.info(f"正在启动服务:")
  570. logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
  571. processes = {"online_api": {}, "model_worker": {}}
  572. def process_count():
  573. return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2
  574. if args.quiet or not log_verbose:
  575. log_level = "ERROR"
  576. else:
  577. log_level = "INFO"
  578. controller_started = manager.Event()
  579. if args.openai_api:
  580. process = Process(
  581. target=run_controller,
  582. name=f"controller",
  583. kwargs=dict(log_level=log_level, started_event=controller_started),
  584. daemon=True,
  585. )
  586. processes["controller"] = process
  587. process = Process(
  588. target=run_openai_api,
  589. name=f"openai_api",
  590. daemon=True,
  591. )
  592. processes["openai_api"] = process
  593. # model_worker_started = []
  594. # if args.model_worker:
  595. # for model_name in args.model_name:
  596. # config = get_model_worker_config(model_name)
  597. # if not config.get("online_api"):
  598. # e = manager.Event()
  599. # model_worker_started.append(e)
  600. # process = Process(
  601. # target=run_model_worker,
  602. # name=f"model_worker - {model_name}",
  603. # kwargs=dict(model_name=model_name,
  604. # controller_address=args.controller_address,
  605. # log_level=log_level,
  606. # q=queue,
  607. # started_event=e),
  608. # daemon=True,
  609. # )
  610. # processes["model_worker"][model_name] = process
  611. #
  612. # if args.api_worker:
  613. # for model_name in args.model_name:
  614. # config = get_model_worker_config(model_name)
  615. # if (config.get("online_api")
  616. # and config.get("worker_class")
  617. # and model_name in FSCHAT_MODEL_WORKERS):
  618. # e = manager.Event()
  619. # model_worker_started.append(e)
  620. # process = Process(
  621. # target=run_model_worker,
  622. # name=f"api_worker - {model_name}",
  623. # kwargs=dict(model_name=model_name,
  624. # controller_address=args.controller_address,
  625. # log_level=log_level,
  626. # q=queue,
  627. # started_event=e),
  628. # daemon=True,
  629. # )
  630. # processes["online_api"][model_name] = process
  631. api_started = manager.Event()
  632. if args.api:
  633. process = Process(
  634. target=run_api_server,
  635. name=f"API Server",
  636. kwargs=dict(started_event=api_started, run_mode=run_mode),
  637. daemon=True,
  638. )
  639. processes["api"] = process
  640. webui_started = manager.Event()
  641. if args.webui:
  642. process = Process(
  643. target=run_webui,
  644. name=f"WEBUI Server",
  645. kwargs=dict(started_event=webui_started, run_mode=run_mode),
  646. daemon=True,
  647. )
  648. processes["webui"] = process
  649. if process_count() == 0:
  650. parser.print_help()
  651. else:
  652. try:
  653. # 保证任务收到SIGINT后,能够正常退出
  654. if p := processes.get("controller"):
  655. p.start()
  656. p.name = f"{p.name} ({p.pid})"
  657. controller_started.wait() # 等待controller启动完成
  658. if p := processes.get("openai_api"):
  659. p.start()
  660. p.name = f"{p.name} ({p.pid})"
  661. for n, p in processes.get("model_worker", {}).items():
  662. p.start()
  663. p.name = f"{p.name} ({p.pid})"
  664. for n, p in processes.get("online_api", []).items():
  665. p.start()
  666. p.name = f"{p.name} ({p.pid})"
  667. # for e in model_worker_started:
  668. # e.wait()
  669. if p := processes.get("api"):
  670. p.start()
  671. p.name = f"{p.name} ({p.pid})"
  672. api_started.wait()
  673. if p := processes.get("webui"):
  674. p.start()
  675. p.name = f"{p.name} ({p.pid})"
  676. webui_started.wait()
  677. dump_server_info(after_start=True, args=args)
  678. while True:
  679. cmd = queue.get()
  680. e = manager.Event()
  681. if isinstance(cmd, list):
  682. model_name, cmd, new_model_name = cmd
  683. if cmd == "start": # 运行新模型
  684. logger.info(f"准备启动新模型进程:{new_model_name}")
  685. process = Process(
  686. target=run_model_worker,
  687. name=f"model_worker - {new_model_name}",
  688. kwargs=dict(model_name=new_model_name,
  689. controller_address=args.controller_address,
  690. log_level=log_level,
  691. q=queue,
  692. started_event=e),
  693. daemon=True,
  694. )
  695. process.start()
  696. process.name = f"{process.name} ({process.pid})"
  697. processes["model_worker"][new_model_name] = process
  698. e.wait()
  699. logger.info(f"成功启动新模型进程:{new_model_name}")
  700. elif cmd == "stop":
  701. if process := processes["model_worker"].get(model_name):
  702. time.sleep(1)
  703. process.terminate()
  704. process.join()
  705. logger.info(f"停止模型进程:{model_name}")
  706. else:
  707. logger.error(f"未找到模型进程:{model_name}")
  708. elif cmd == "replace":
  709. if process := processes["model_worker"].pop(model_name, None):
  710. logger.info(f"停止模型进程:{model_name}")
  711. start_time = datetime.now()
  712. time.sleep(1)
  713. process.terminate()
  714. process.join()
  715. process = Process(
  716. target=run_model_worker,
  717. name=f"model_worker - {new_model_name}",
  718. kwargs=dict(model_name=new_model_name,
  719. controller_address=args.controller_address,
  720. log_level=log_level,
  721. q=queue,
  722. started_event=e),
  723. daemon=True,
  724. )
  725. process.start()
  726. process.name = f"{process.name} ({process.pid})"
  727. processes["model_worker"][new_model_name] = process
  728. e.wait()
  729. timing = datetime.now() - start_time
  730. logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。")
  731. else:
  732. logger.error(f"未找到模型进程:{model_name}")
  733. # for process in processes.get("model_worker", {}).values():
  734. # process.join()
  735. # for process in processes.get("online_api", {}).values():
  736. # process.join()
  737. # for name, process in processes.items():
  738. # if name not in ["model_worker", "online_api"]:
  739. # if isinstance(p, dict):
  740. # for work_process in p.values():
  741. # work_process.join()
  742. # else:
  743. # process.join()
  744. except Exception as e:
  745. logger.error(e)
  746. logger.warning("Caught KeyboardInterrupt! Setting stop event...")
  747. finally:
  748. for p in processes.values():
  749. logger.warning("Sending SIGKILL to %s", p)
  750. # Queues and other inter-process communication primitives can break when
  751. # process is killed, but we don't care here
  752. if isinstance(p, dict):
  753. for process in p.values():
  754. process.kill()
  755. else:
  756. p.kill()
  757. for p in processes.values():
  758. logger.info("Process status: %s", p)
  759. if __name__ == "__main__":
  760. create_tables()
  761. if sys.version_info < (3, 10):
  762. loop = asyncio.get_event_loop()
  763. else:
  764. try:
  765. loop = asyncio.get_running_loop()
  766. except RuntimeError:
  767. loop = asyncio.new_event_loop()
  768. asyncio.set_event_loop(loop)
  769. loop.run_until_complete(start_main_server())
  770. # 服务启动后接口调用示例:
  771. # import openai
  772. # openai.api_key = "EMPTY" # Not support yet
  773. # openai.api_base = "http://localhost:8888/v1"
  774. # model = "chatglm3-6b"
  775. # # create a chat completion
  776. # completion = openai.ChatCompletion.create(
  777. # model=model,
  778. # messages=[{"role": "user", "content": "Hello! What is your name?"}]
  779. # )
  780. # # print the completion
  781. # print(completion.choices[0].message.content)

After starting it up, it looks like this:

LLM chat is of course still unusable, since no LLM was ever loaded, but I can confirm that the vector knowledge base works: I uploaded a tmp.txt file into it.

The web service side also reports that the vector embedding completed normally.
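For completeness, the same upload-and-search round trip can also be driven from a script instead of the web UI. A hedged example with requests, assuming the API server runs on the default localhost:7861 and the samples knowledge base exists:

import requests

API = "http://127.0.0.1:7861"

# Upload tmp.txt into the "samples" knowledge base and embed it.
with open("tmp.txt", "rb") as f:
    r = requests.post(
        f"{API}/knowledge_base/upload_docs",
        files=[("files", ("tmp.txt", f, "text/plain"))],
        data={"knowledge_base_name": "samples", "override": "true", "to_vector_store": "true"},
    )
print(r.json())

# Search the knowledge base for the chunks most similar to a query.
r = requests.post(
    f"{API}/knowledge_base/search_docs",
    json={"query": "RAG", "knowledge_base_name": "samples", "top_k": 3, "score_threshold": 1.0},
)
print(r.json())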
