import json
import logging
import textwrap

import torch
from langchain import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import pipeline
# Load the source PDF and split it into chunks of at most ~300 characters
# for embedding and retrieval.
pdf_file_path = "/home/data/naiwen/naiwen_code/llama7b/chatbot_llama2-main/finance.pdf"
pdf_loader = PyPDFLoader(pdf_file_path)
# chunk_overlap=0: adjacent chunks share no text.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
splitted_docs = text_splitter.split_documents(pdf_loader.load())
# Fail with a clear message instead of an opaque IndexError when nothing was
# extracted (e.g. a PDF whose pages carry no text layer).
if not splitted_docs:
    raise ValueError(f"No text chunks could be extracted from {pdf_file_path}")
# Show the first chunk as a sanity check of the split.
print(splitted_docs[0])
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# Sentence-embedding model loaded from a local text2vec-base-chinese
# directory and run on the GPU.
chinese_embedding_name = "/home/data/naiwen/naiwen_code/llama7b/chatbot_llama2-main/text2vec-base-chinese/"
embeddings = HuggingFaceEmbeddings(
    model_kwargs={"device": "cuda"},
    model_name=chinese_embedding_name,
)
# Chroma vector store backed by the current working directory, so the
# collection is meant to survive between runs of this script.
collection_name = 'llama2_demo'
db = Chroma(
    collection_name=collection_name,
    embedding_function=embeddings,
    persist_directory='./',
)
# Index the chunks only when the persisted collection is still empty.
# Re-adding on every run would store another copy of every chunk each time
# and crowd similarity-search results with duplicates.
# NOTE: edits to the PDF are therefore not picked up until the collection
# (or the persist directory) is cleared.
if not db.get(include=[])["ids"]:
    db.add_documents(splitted_docs)
# Read the user's question, then show the raw nearest-neighbour chunks
# Chroma returns for it (before any LLM-based retrieval).
test_query = input("\nEnter a query: ")
search_docs = db.similarity_search(query=test_query)
print(search_docs)
# Tokenizer (slow, non-"fast" implementation) and causal LM loaded with
# 4-bit weights in float16, placed across available devices (device_map='auto').
# NOTE(review): newer transformers releases prefer
# quantization_config=BitsAndBytesConfig(load_in_4bit=True) over load_in_4bit.
model_path = "/home/data/naiwen/naiwen_code/llama7b/DocQA/models/Chinese-Llama-2-7b-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map='auto',
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
# Decoding defaults shipped with the checkpoint; the explicit generation
# keyword arguments below override the matching fields.
generation_config = GenerationConfig.from_pretrained(model_path)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # Keep in step with the float16 dtype the model was loaded with above;
    # the previous bfloat16 value contradicted it.
    torch_dtype=torch.float16,
    device_map='auto',
    # Total length cap: counts prompt tokens as well as generated ones.
    max_length=2048,
    # Deterministic (greedy) decoding. The earlier temperature=0 expressed
    # this intent, but a sampling temperature must be > 0. Asking for greedy
    # decoding explicitly also makes top_p irrelevant.
    do_sample=False,
    repetition_penalty=1.15,
    generation_config=generation_config,
)
# Llama-2 chat prompt markers: the instruction sits between [INST] and
# [/INST], and the system prompt between <<SYS>> and <</SYS>> inside it.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# The opening assignment line was missing, leaving the text below as bare
# statements and an unterminated string literal.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""


def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Wrap *instruction* in the Llama-2 chat template.

    Args:
        instruction: The user message placed just before E_INST.
        new_system_prompt: Text placed between B_SYS and E_SYS; defaults to
            DEFAULT_SYSTEM_PROMPT.

    Returns:
        B_INST + B_SYS + new_system_prompt + E_SYS + instruction + E_INST.
    """
    system_block = B_SYS + new_system_prompt + E_SYS
    return B_INST + system_block + instruction + E_INST


instruction = "What is the temperature in Melbourne?"
# A bare expression only displays in a notebook; print it so the script shows it.
print(get_prompt(instruction))
# Expose the transformers pipeline to LangChain as an LLM object.
# NOTE(review): model_kwargs presumably only matters when the wrapper builds
# its own pipeline (HuggingFacePipeline.from_model_id); decoding here appears
# to be governed by `pipe`. Confirm against the installed LangChain version.
llm = HuggingFacePipeline(
    model_kwargs={'temperature': 0},
    pipeline=pipe,
)
def parse_text(text, width=100):
    """Print *text* word-wrapped to *width* columns, followed by blank lines.

    Despite its name this is a display helper: it parses nothing and
    returns None.

    Args:
        text: The string to display.
        width: Maximum line width passed to textwrap.fill. The default of 100
            matches the previously hard-coded value.
    """
    print(textwrap.fill(text, width=width) + '\n\n')
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# Prompt for the "stuff" QA chain: {context} takes the retrieved chunks,
# {history} the conversation kept by `memory`, {question} the user's query.
# Keep the space before the trailing backslash: the backslash-newline joins
# the two source lines, and without it the prompt read "answer,just say".
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, \
just say that you don't know, don't try to make up an answer. Must use Chinese to answer the question.
{context}
{history}
Question: {question}
Helpful Answer:"""
prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
# Conversation buffer exposed to the prompt as "history"; input_key names the
# chain input that holds the user's turn.
memory = ConversationBufferMemory(input_key='question', memory_key='history')
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers.multi_query import MultiQueryRetriever

# Stage 1: the LLM generates alternative phrasings of the query and the
# Chroma retriever is queried with each of them.
retriever_from_llm = MultiQueryRetriever.from_llm(llm=llm, retriever=db.as_retriever())
# Stage 2: the LLM trims each retrieved chunk down to its query-relevant part.
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=retriever_from_llm,
    base_compressor=compressor,
)
# Surface the queries MultiQueryRetriever generates, which it logs through the
# "langchain.retrievers.multi_query" logger. Raising that logger's level alone
# prints nothing: with no handler configured, logging's last-resort handler
# only emits WARNING and above. basicConfig() installs a root handler so
# the logger's lower-level messages are actually printed.
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.DEBUG)
# Run the full multi-query + compression retrieval once for inspection.
# NOTE: qa(...) below performs the same retrieval (and its LLM calls) again.
retri_docs = compression_retriever.get_relevant_documents(test_query)
print(retri_docs)
# Retrieval-augmented QA: the compressed chunks are "stuffed" into `prompt`.
# chain_type_kwargs are forwarded to the underlying documents chain, which is
# where `memory` supplies {history}.
qa = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=compression_retriever,
    chain_type='stuff',
    chain_type_kwargs={"memory": memory, "prompt": prompt},
    return_source_documents=True,
)
res = qa(test_query)
# With return_source_documents=True the result also carries the chunks used;
# only the generated answer is printed here.
print(res["result"])
