Import the libraries
import torch
import textwrap

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
Select the course text, split it into chunks, and build a Chroma database
pdf_file_path = "/home/data/naiwen/naiwen_code/llama7b/chatbot_llama2-main/finance.pdf"
pdf_loader = PyPDFLoader(pdf_file_path)

# Split the PDF into 300-character chunks with no overlap.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
splitted_docs = text_splitter.split_documents(pdf_loader.load())
print(splitted_docs[0])

# Embed the chunks with a local Chinese sentence-embedding model on the GPU.
chinese_embedding_name = "/home/data/naiwen/naiwen_code/llama7b/chatbot_llama2-main/text2vec-base-chinese/"
embeddings = HuggingFaceEmbeddings(
    model_name=chinese_embedding_name,
    model_kwargs={"device": "cuda"},
)

# Store the embedded chunks in a persistent Chroma collection.
collection_name = 'llama2_demo'
db = Chroma(
    collection_name=collection_name,
    embedding_function=embeddings,
    persist_directory='./',
)
db.add_documents(splitted_docs)
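Because persist_directory is set, the collection can be written to disk and reloaded in a later session instead of re-embedding the PDF every run. A minimal sketch, assuming the pre-0.4 Chroma client that this LangChain version wraps (0.4+ persists automatically and has no persist() method):

# Flush the collection to disk (older Chroma clients only).
db.persist()

# In a later session, reattach to the saved collection; no add_documents needed.
db = Chroma(
    collection_name=collection_name,
    embedding_function=embeddings,
    persist_directory='./',
)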
Define the question
test_query = input("\nEnter a query: ")
Retrieve the chunks related to the question from the database
# Return the chunks whose embeddings are closest to the query embedding.
search_docs = db.similarity_search(test_query)
print(search_docs)
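similarity_search returns the top four chunks by default. To inspect how close the matches actually are, the with-score variant can be used; this is a quick sanity check, not part of the pipeline below:

# k controls how many chunks come back; scores are distances (lower = closer).
for doc, score in db.similarity_search_with_score(test_query, k=4):
    print(round(score, 3), doc.page_content[:60])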
Load the pretrained model
model_path = "/home/data/naiwen/naiwen_code/llama7b/DocQA/models/Chinese-Llama-2-7b-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load the 7B model in 4-bit with float16 compute, sharded across available GPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
)

generation_config = GenerationConfig.from_pretrained(model_path)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,  # match the dtype the model was loaded with
    device_map='auto',
    max_length=2048,
    # Greedy decoding is the default (do_sample=False), so temperature=0 is
    # effectively ignored; it just signals that output should be deterministic.
    temperature=0,
    top_p=0.95,
    repetition_penalty=1.15,
    generation_config=generation_config,
)
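A quick smoke test before wiring the pipeline into LangChain; the prompt string here is just an illustrative placeholder:

# Generate a short completion to confirm the 4-bit model loads and runs.
out = pipe("你好，请介绍一下你自己。")
print(out[0]["generated_text"])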
Set the default system prompt and the Q&A format
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

# Wrap an instruction in the Llama 2 chat format:
# [INST]<<SYS>> system prompt <</SYS>> instruction [/INST]
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template

instruction = "What is the temperature in Melbourne?"
print(get_prompt(instruction))
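The printed string is the instruction wrapped as [INST]<<SYS>> ... <</SYS>> ... [/INST], the chat template the Llama 2 chat models were fine-tuned on. Note that the RetrievalQA chain below builds its own prompt, so get_prompt here only illustrates the format.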
llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})

# Helper to pretty-print long model outputs at 100 characters per line.
def parse_text(text):
    wrapped_text = textwrap.fill(text, width=100)
    print(wrapped_text + '\n\n')

from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# QA prompt: answer only from the retrieved context, in Chinese, and admit
# not knowing rather than guessing. {history} carries the conversation memory.
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, \
just say that you don't know; don't try to make up an answer. You must answer the question in Chinese.
{context}
{history}
Question: {question}
Helpful Answer:"""
prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
memory = ConversationBufferMemory(input_key='question', memory_key='history')
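To see exactly what the LLM will receive, the template can be rendered by hand with placeholder values (the strings below are illustrative only):

# The 'stuff' chain pastes the retrieved chunks into {context}, the memory
# buffer into {history}, and the user query into {question}.
print(prompt.format(
    context="<retrieved chunks go here>",
    history="",
    question="什么是通货膨胀?",
))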
Build the retriever and the QA chain
from langchain.retrievers.multi_query import MultiQueryRetriever

# Have the LLM rephrase the query several ways and merge the retrieved results.
retriever_from_llm = MultiQueryRetriever.from_llm(retriever=db.as_retriever(), llm=llm)

from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever

# Compress each retrieved chunk down to the passages relevant to the query.
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever_from_llm)

# Log the alternative queries MultiQueryRetriever generates; basicConfig
# attaches a handler, without which the DEBUG messages would not appear.
import logging
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.DEBUG)

retri_docs = compression_retriever.get_relevant_documents(test_query)
print(retri_docs)

# Assemble the QA chain: 'stuff' concatenates the retrieved chunks into {context}.
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=compression_retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt, "memory": memory},
)
Generate the answer
res = qa(test_query)
print(res["result"])