赞
踩
RAG(Retrieval-Augmented Generation)是一种结合了检索和生成模型的方法,主要用于解决序列到序列的任务,如问答、对话系统、文本摘要等。它的核心思想是通过从大量文档中检索相关信息,然后利用这些信息来增强生成模型的输出。
原理如下图:
使用langchain框架用python代码实现,代码如下:
- import os
- import faiss
- from langchain.retrievers import ContextualCompressionRetriever
- from langchain_community.vectorstores import FAISS
- from langchain_core.prompts import PromptTemplate
- from langchain_huggingface import HuggingFaceEmbeddings
- from langchain_community.llms.ollama import Ollama
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnablePassthrough
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- import config as cfg
- from log_util import LogUtil
- from auto_directory_loader import AutoDirectoryLoader
- from BCEmbedding.tools.langchain import BCERerank
-
-
-
- doc_path = cfg.load_doc_dir
-
- # 在线 embedding model
- embedding_model_name = 'maidalun1020/bce-embedding-base_v1'
-
- model1_path = r'F:\ai\ai_model\maidalun1020_bce_embedding_base_v1'
- model2_path = r'F:\ai\ai_model\maidalun1020_bce_reranker_base_v1'
-
- # 本地模型路径
- embedding_model_kwargs = {'device': 'cuda:0'}
- embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
-
-
- embeddings = HuggingFaceEmbeddings(
- model_name=model1_path,
- model_kwargs=embedding_model_kwargs,
- encode_kwargs=embedding_encode_kwargs
- )
-
- reranker_args = {'model': model2_path, 'top_n': 5, 'device': 'cuda:0'}
-
- reranker = BCERerank(**reranker_args)
-
- # 检查FAISS向量库是否存在
- if os.path.exists(cfg.faiss_index_path):
- # 如果存在,从本地加载
- LogUtil.info("FAISS index exists. Loading from local path...")
-
- vectorstore = FAISS.load_local(cfg.faiss_index_path, embeddings, allow_dangerous_deserialization=True)
- LogUtil.info("FAISS index exists. Loading from local path...")
-
- else:
- # 如果不存在,加载txt文件并创建FAISS向量库
- LogUtil.info("FAISS index does not exist. Loading txt file and creating index...")
-
- loader = AutoDirectoryLoader(doc_path, glob="**/*.txt")
- docs = loader.load()
-
- LogUtil.info(f"Loaded documents num:{len(docs)}")
-
- # 从文档创建向量库
- # 文本分割
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap)
- documents = text_splitter.split_documents(docs)
- LogUtil.info(f"Text splits num :{len(documents)}", )
-
- # 创建向量存储
- vectorstore = FAISS.from_documents(documents, embeddings)
- LogUtil.info("create db ok.")
-
- # 保存向量库到本地
- vectorstore.save_local(cfg.faiss_index_path)
-
- LogUtil.info("Index saved to local ok.")
-
- # 将索引搬到 GPU 上
- res = faiss.StandardGpuResources()
- gpu_index = faiss.index_cpu_to_gpu(res, 0, vectorstore.index)
- vectorstore.index = gpu_index
-
- retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})
- test_ask="宴桃园豪杰三结义有谁参加了?"
- # 调试查看结果
- retrieved_docs = retriever.invoke(test_ask)
- for doc in retrieved_docs:
- print('++++++单纯向量库提取++++++++')
- print(doc.page_content)
-
- compression_retriever = ContextualCompressionRetriever(
- base_compressor=reranker, base_retriever=retriever
- )
-
- response = compression_retriever.get_relevant_documents(test_ask)
-
- print("============================================compression_retriever")
- print(response)
- print("---------------------end")
-
-
- # 定义Prompt模板
- prompt_template = """
- 问题:{question}
- 相关信息:
- {retrieved_documents}
- 请根据以上信息回答问题。
- """
-
- prompt = PromptTemplate(
- input_variables=["question", "retrieved_documents"],
- template=prompt_template,
- )
-
-
- # 创建LLM模型
- llm = Ollama(model="qwen2:7b")
-
-
- def format_docs(all_docs):
- txt = "\n\n".join(doc.page_content for doc in all_docs)
- print('+++++++++使用bce_embedding + bce-reranker 上下文内容++++++')
- print(txt)
- return txt
-
-
- rag_chain = (
- {"retrieved_documents": compression_retriever | format_docs, "question": RunnablePassthrough()}
- | prompt
- | llm
- | StrOutputParser()
- )
-
- r = rag_chain.invoke(test_ask)
- print("++++++加 LLM模型处理最终结果++++++++")
- print(r)
-
-
在上面代码中我准备了一些文档,上传到向量库,其中就有三国演义的,并提出了问题:宴桃园豪杰三结义有谁参加了?运行后回答也与文档一致,测试结果正确,并在不同的环节输出相应的结果,如下图:
第一步,直接向量库检索,相近最近的10条内容如下:
经过 bce-embedding与bce_reranker两在模型的处理,结果也是准确的
再提交给LLM处理后的效果
本地环境:win10系统,本地安装了ollama 并使用的是阿里最新的qwen2:7b,其实qwen:7b测试结果也是准确的。另外还使用了bce-embedding作为嵌入模型,之前测试使用过Lam2+nomic-embed-text做了很多测试发现中文无论怎么调试,都不是很理想,回答的问题总是在胡说八道的感觉。RAG应用个人感觉重点资料输入这块也很重要,像图片里的文字非得要ocr技术,这一点发现有道的qanything做得非常好,以后看来要花点时间查看qanything的源代码好好恶补一下自己这一块。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。