当前位置:   article > 正文

用create_stuff_documents_chain构建一个完整的知识问答模型_stuffdocumentschain

stuffdocumentschain
  1. from langchain.chains.combine_documents import create_stuff_documents_chain
  2. from langchain_community.llms import QianfanLLMEndpoint
  3. from langchain_community.embeddings import QianfanEmbeddingsEndpoint
  4. from langchain_community.document_loaders import TextLoader
  5. from langchain_community.vectorstores import Chroma
  6. from langchain_core.prompts import ChatPromptTemplate
  7. from langchain_text_splitters import RecursiveCharacterTextSplitter
  8. from langchain.memory import ConversationSummaryMemory
  9. from langchain.chains import ConversationalRetrievalChain
  10. import os
  11. # 初始化大模型
  12. llm_qianfan = QianfanLLMEndpoint(temperature=0.1)
  13. # 导入文档
  14. loder_txt = TextLoader(r'D:\PycharmProjects\MyAgent\texts\text1.txt', encoding='utf8')
  15. docs_txt = loder_txt.load()
  16. # 分割文档
  17. text_splitter_txt = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0,
  18. separators=["\n\n", "\n", " ", "", "。", ","])
  19. documents_txt = text_splitter_txt.split_documents(docs_txt)
  20. embeddings_qf = QianfanEmbeddingsEndpoint()
  21. # 文档导入向量数据库,如果之前已经生成就不重复导入,直接引用
  22. if not os.path.exists("../chroma.sqlite3"):
  23. vectordb = Chroma.from_documents(documents=documents_txt, embedding=embeddings_qf,
  24. persist_directory="D:\PycharmProjects\MyAgent")
  25. else:
  26. vectordb = Chroma(
  27. persist_directory="D:\PycharmProjects\MyAgent",
  28. embedding_function=embeddings_qf,
  29. )
  30. # 创建提示词
  31. prompt = ChatPromptTemplate.from_template("""使用下面的语料来回答本模板最末尾的问题。如果你不知道问题的答案,直接回答"抱歉,这个问题我还不清楚。",禁止随意编造答案。
  32. 为了保证答案尽可能简洁,你的回答必须不超过三句话,你的回答中不可以带有星号。
  33. 请注意!在每次回答结束之后,你都必须接上"感谢您的提问。"作为结束语
  34. 以下是一对问题和答案的样例:
  35. 请问:秦始皇的原名是什么?
  36. 秦始皇原名嬴政。感谢您的提问。
  37. 以下是语料:
  38. <context>
  39. {context}
  40. </context>
  41. Question:{input}""")
  42. # 创建检索链
  43. document_chain = create_stuff_documents_chain(llm_qianfan, prompt)
  44. retriever = vectordb.as_retriever()
  45. memory = ConversationSummaryMemory(
  46. llm=llm_qianfan,
  47. memory_key="chat_history",
  48. return_message=True
  49. )
  50. qa = ConversationalRetrievalChain.from_llm(llm=llm_qianfan, retriever=retriever, memory=memory)
  51. res = qa.invoke({"question": "根据勾股定理能推断出什么理论?"})
  52. print(res["answer"])

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号