
llama7_generationconfig


Import the libraries

import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from langchain.document_loaders import PyPDFLoader
from transformers import pipeline
import json
import textwrap

Select the course text, split it into chunks, and build the Chroma database

# load the course PDF and split it into 300-character chunks with no overlap
pdf_file_path = "/home/data/naiwen/naiwen_code/llama7b/chatbot_llama2-main/finance.pdf"
pdf_loader = PyPDFLoader(pdf_file_path)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
splitted_docs = text_splitter.split_documents(pdf_loader.load())
print(splitted_docs[0])

# Chinese sentence-embedding model (text2vec-base-chinese) running on GPU
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

chinese_embedding_name = "/home/data/naiwen/naiwen_code/llama7b/chatbot_llama2-main/text2vec-base-chinese/"
embeddings = HuggingFaceEmbeddings(
    model_name=chinese_embedding_name,
    model_kwargs={"device": "cuda"},
)

# embed the chunks and store them in a Chroma collection
collection_name = 'llama2_demo'
db = Chroma(
    collection_name=collection_name,
    embedding_function=embeddings,
    persist_directory='./'
)
db.add_documents(splitted_docs)
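Since the vector store is created with persist_directory='./', the embedded chunks can be reused in a later session instead of re-embedding the PDF on every run. Below is a minimal sketch of reloading the same collection; depending on the langchain/chromadb version you may need to call db.persist() first, and the sanity-check query string is only an example:

# older chromadb backends need an explicit flush; newer versions persist automatically
# db.persist()

# reload the existing 'llama2_demo' collection with the same embedding model
db_reloaded = Chroma(
    collection_name='llama2_demo',
    embedding_function=embeddings,
    persist_directory='./'
)
print(db_reloaded.similarity_search("财务", k=1))  # hypothetical sanity-check query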

Define the question

test_query = input("\nEnter a query: ")

Search the vector database for chunks related to the question

search_docs = db.similarity_search(test_query)
print(search_docs)
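similarity_search returns the closest chunks but not how close they actually are. To check retrieval quality before handing anything to the LLM, the Chroma wrapper also exposes similarity_search_with_score; a small sketch (the k=3 value is just an example):

# return the 3 closest chunks together with their distance scores
# (lower score means a closer match for the default distance metric)
scored_docs = db.similarity_search_with_score(test_query, k=3)
for doc, score in scored_docs:
    print(f"score={score:.4f}  {doc.page_content[:80]}")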

Load the trained model

# path to the 4-bit quantized Chinese Llama-2 7B model
model_path = "/home/data/naiwen/naiwen_code/llama7b/DocQA/models/Chinese-Llama-2-7b-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# load the model in 4-bit, spread across the available GPUs
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map='auto'
)

# reuse the generation settings shipped with the model
generation_config = GenerationConfig.from_pretrained(model_path)

pipe = pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    max_length=2048,
    temperature=0,
    top_p=0.95,
    repetition_penalty=1.15,
    tokenizer=tokenizer,
    generation_config=generation_config,
)
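Before wrapping the pipeline in LangChain, it can be worth calling it directly once to confirm the 4-bit model generates sensible text. A quick check; the test question here is arbitrary and not part of the original code:

# direct call to the transformers pipeline; it returns a list of dicts with "generated_text"
sample = pipe("请用一句话介绍什么是财务报表。")  # hypothetical test prompt
print(sample[0]["generated_text"])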

Set the default system prompt and the QA format

# special tokens of the Llama-2 chat format
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    # wrap the instruction and system prompt in the Llama-2 chat format
    SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template

instruction = "What is the temperature in Melbourne?"
get_prompt(instruction)

llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})

def parse_text(text):
    wrapped_text = textwrap.fill(text, width=100)
    print(wrapped_text + '\n\n')

from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# QA template: retrieved context + chat history + question
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, \
just say that you don't know, don't try to make up an answer. Must use Chinese to answer the question.
{context}
{history}
Question: {question}
Helpful Answer:"""

prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
memory = ConversationBufferMemory(input_key='question', memory_key='history')
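The RetrievalQA chain built below fills {context} with the retrieved chunks, {history} with the conversation memory, and {question} with the user query. To see exactly what the model will receive, the template can be rendered by hand; a small illustration with made-up placeholder values:

# render the template with placeholder values to inspect the final prompt
example_prompt = prompt.format(
    context="(retrieved document chunks would appear here)",
    history="",
    question="公司的营业收入是多少？"  # hypothetical question
)
print(example_prompt)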

Other runtime components: the retrievers and the RetrievalQA chain

from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever
import logging

# generate several rephrasings of the query and retrieve with each of them
retriever_from_llm = MultiQueryRetriever.from_llm(retriever=db.as_retriever(), llm=llm)

# compress the retrieved chunks so that only query-relevant passages are kept
compressor = LLMChainExtractor.from_llm(llm)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=retriever_from_llm
)

logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.DEBUG)

retri_docs = compression_retriever.get_relevant_documents(test_query)
print(retri_docs)

# "stuff" chain: put all retrieved chunks into the {context} slot of the prompt
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=compression_retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt, "memory": memory}
)

Generate the answer

res = qa(test_query)
print(res["result"])
