当前位置:   article > 正文

RAG - QA + Qwen + dashscope

RAG - QA + Qwen + dashscope


转载改编自:qwen_doc_search_QA_based_on_dashscope.ipynb
https://github.com/modelscope/modelscope/blob/master/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb


一、关于项目


二、准备


1、安装依赖包

# install required packages
!pip install dashvector dashscope
!pip install transformers_stream_generator python-dotenv
  • 1
  • 2
  • 3

2、准备数据集

这里使用的是 中文突发事件语料库,由 上海大学-语义智能实验室 提供
https://github.com/shijiebei2009/CEC-Corpus

# prepare news corpus as knowledge source
!git clone https://github.com/shijiebei2009/CEC-Corpus.git
  • 1
  • 2

数据集内容:

../datasets/CEC-Corpus$ tree
.
├── CEC
│   ├── 交通事故
│   │   ├── 101国道密云段现惨祸客车农用车相撞致6人亡.xml
│   │   ├── 104国道浙江温岭段发生翻车事故致2死2伤.xml 
│   │   ├── ... 
│   │   └── 黑龙江五常发生特大交通事故6人死亡.xml
│   ├── 地震
│   │   ├── 上海:高层建筑普遍有震感但不会造成危害.xml 
│   │   ├── ...
│   │   ├── 重庆市区有明显震感电线杆在摇晃.xml
│   │   └── 青海发生6.3级地震震区人口密度低尚无人员伤亡.xml
│   ├── 恐怖袭击
│   │   ├── 4月7日凌晨5时,近300名穿着“警察”制服.xml
│   │   ├── 世界杯险遇恐怖袭击警方发现犯罪组织欲炸桥.xml
│   │   ├── ... 
│   │   └── 阿尔及利亚汽车炸弹爆炸11人死31人伤.xml
│   ├── 火灾
│   │   ├── 上海永嘉路老式洋房突发火灾好心市民合力救出被困老太.xml
│   │   ├── 云南丽江束河古镇昨凌晨失火.xml
│   │   ├── ... 
│   │   └── 马尼拉华人区住宅发生火灾一华人老妇被烧伤.xml
│   └── 食物中毒
│       ├── 上海一家公司70多名员工食物中毒.xml 
│       ├── ... 
│       └── 龙岗一小食店发生一起疑似中毒事件.xml
├── raw corpus (332 文件)
│   └── allSourceText
│       ├── 101国道密云段现惨祸客车农用车相撞致6人亡.txt  
│       ├── ...
│       ├── 黑龙江鸡西市20多名小学生疑似食物中毒.txt
│       └── 龙岗一小食店发生一起疑似中毒事件.txt
└── README.md

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

3、准备 API-Key

如果没有,可以点击申请;
阿里云需要实名认证后才能刚申请,人脸识别也很快。


三、代码实现

import dashscope
import os
from dotenv import load_dotenv
from dashscope import TextEmbedding
from dashvector import Client, Doc

# get env variable from .env
# please make sure DASHSCOPE_KEY is defined in .env
load_dotenv()
dashscope.api_key = os.getenv('DASHSCOPE_KEY')


# initialize DashVector for embedding's indexing and searching
dashvector_client = Client(api_key='{your-dashvector-api-key}')

# define collection name
collection_name = 'news_embeddings'

# delete if already exist
dashvector_client.delete(collection_name)

# create a collection with embedding size of 1536
rsp = dashvector_client.create(collection_name, 1536)
collection = dashvector_client.get(collection_name)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

准备数据

def prepare_data_from_dir(path, size):
    # prepare the data from a file folder in order to upsert to DashVector with a reasonable doc's size.
    batch_docs = []
    for file in os.listdir(path):
        with open(path + '/' + file, 'r', encoding='utf-8') as f:
            batch_docs.append(f.read())
            if len(batch_docs) == size:
                yield batch_docs[:]
                batch_docs.clear()

    if batch_docs:
        yield batch_docs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

def prepare_data_from_file(path, size):
    # prepare the data from file in order to upsert to DashVector with a reasonable doc's size.
    batch_docs = []
    chunk_size = 12
    with open(path, 'r', encoding='utf-8') as f:
        doc = ''
        count = 0
        for line in f:
            if count < chunk_size and line.strip() != '':
                doc += line
                count += 1
            if count == chunk_size:
                batch_docs.append(doc)
                if len(batch_docs) == size:
                    yield batch_docs[:]
                    batch_docs.clear()
                doc = ''
                count = 0

    if batch_docs:
        yield batch_docs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

生成 embeddings

def generate_embeddings(docs):
    # create embeddings via DashScope's TextEmbedding model API
    rsp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,
                             input=docs)
    embeddings = [record['embedding'] for record in rsp.output['embeddings']]
    return embeddings if isinstance(docs, list) else embeddings[0]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

id = 0
dir_path = 'xx/CEC-Corpus/raw corpus/allSourceText'

# indexing the raw docs with index to DashVector
collection = dashvector_client.get(collection_name)

# embedding api max batch size
batch_size = 4  

for news in list(prepare_data_from_dir(dir_path, batch_size)):
    ids = [id + i for i, _ in enumerate(news)]
    id += len(news)
    # generate embedding from raw docs
    vectors = generate_embeddings(news)
    # upsert and index
    ret = collection.upsert(
        [
            Doc(id=str(id), vector=vector, fields={"raw": doc})
            for id, doc, vector in zip(ids, news, vectors)
        ]
    )
    print(ret)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

# check the collection status
collection = dashvector_client.get(collection_name)
rsp = collection.stats()
print(rsp)
  • 1
  • 2
  • 3
  • 4

检索方法

def search_relevant_context(question, topk=1, client=dashvector_client):
    # query and recall the relevant information
    collection = client.get(collection_name)

    # recall the top k similarity results from DashVector
    rsp = collection.query(generate_embeddings(question), output_fields=['raw'],
                           topk=topk)
    return "".join([item.fields['raw'] for item in rsp.output])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

# query the top 1 results
question = '清华博士发生了什么?'
context = search_relevant_context(question, topk=1)
print(context)
  • 1
  • 2
  • 3
  • 4

2006-08-26 10:41:45
823日上午940分,京沪高速公路沧州服务区附近,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。截至发稿时,仍有一名张姓博士后研究人员尚未脱离危险。
  • 1
  • 2

# initialize qwen 7B model
from modelscope import AutoModelForCausalLM, AutoTokenizer
from modelscope import GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen-7B-Chat", revision = 'v1.0.5',trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("qwen/Qwen-7B-Chat", revision = 'v1.0.5',device_map="auto", trust_remote_code=True, fp16=True).eval()
model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat",revision = 'v1.0.5', trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

问答

# define a prompt template for the vectorDB-enhanced LLM generation
def answer_question(question, context):
    prompt = f'''请基于```内的内容回答问题。"
	\```
	{context}
	\```
	我的问题是:{question}。
	'''
	
	history = None
	print(prompt)
	response, history = model.chat(tokenizer, prompt, history=None) 
	return response 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

# test the case on plain LLM without vectorDB enhancement
question = '清华博士发生了什么?'
answer = answer_question(question, '')
print(f'question: {question}\n' f'answer: {answer}')
  • 1
  • 2
  • 3
  • 4
请基于```内的内容回答问题。"
\```

\```
我的问题是:清华博士发生了什么?。

question: 清华博士发生了什么?
answer: 清华博士是指清华大学的博士研究生。作为一名AI语言模型,我无法获取个人的身份信息或具体事件,因此无法回答清华博士发生了什么。如果您需要了解更多相关信息,建议您查询相关媒体或官方网站。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

# test the case with knowledge
context = search_relevant_context(question, topk=1)
answer = answer_question(question, context)
print(f'question: {question}\n' f'answer: {answer}')
  • 1
  • 2
  • 3
  • 4

请基于```内的内容回答问题。"
\```
	2006-08-26 10:41:45
8月23日上午9时40分,京沪高速公路沧州服务区附近,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。截至发稿时,仍有一名张姓博士后研究人员尚未脱离危险。


\```
	我的问题是:清华博士发生了什么?。

question: 清华博士发生了什么?
answer: 8月23日上午9时40分,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

2024-03-24(日)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/477306
推荐阅读
相关标签
  

闽ICP备14008679号