赞
踩
vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析,该部分内容为接入各类ai方法。
这段代码定义了一个 GoogleGeminiChat
类,继承自 VannaBase
。
import os from ..base import VannaBase # 从父模块导入 VannaBase 类 class GoogleGeminiChat(VannaBase): def __init__(self, config=None): VannaBase.__init__(self, config=config) # 调用父类的构造函数 # 默认的温度值,可以通过 config 覆盖 self.temperature = 0.7 # 如果配置中包含 temperature,则覆盖默认值 if "temperature" in config: self.temperature = config["temperature"] # 设置模型名称,如果配置中没有指定,则使用默认值 "gemini-1.0-pro" if "model_name" in config: model_name = config["model_name"] else: model_name = "gemini-1.0-pro" self.google_api_key = None # 初始化 API 密钥变量 # 如果配置中提供了 api_key 或环境变量中有 GOOGLE_API_KEY,则使用它 if "api_key" in config or os.getenv("GOOGLE_API_KEY"): """ 如果 Google api_key 通过配置提供 或设置为环境变量,则分配它。 """ import google.generativeai as genai # 导入 google.generativeai 库 genai.configure(api_key=config["api_key"]) # 使用提供的 API 密钥进行配置 self.chat_model = genai.GenerativeModel(model_name) # 初始化生成模型 else: # 使用 VertexAI 进行身份验证 from vertexai.preview.generative_models import GenerativeModel # 导入 VertexAI 的生成模型类 self.chat_model = GenerativeModel("gemini-pro") # 初始化生成模型 def system_message(self, message: str) -> any: return message # 返回系统消息 def user_message(self, message: str) -> any: return message # 返回用户消息 def assistant_message(self, message: str) -> any: return message # 返回助手消息 def submit_prompt(self, prompt, **kwargs) -> str: # 使用生成模型生成内容 response = self.chat_model.generate_content( prompt, generation_config={ "temperature": self.temperature, # 使用配置的温度值 }, ) return response.text # 返回生成的文本
类的初始化:
VannaBase
的构造函数。self.temperature
,可以通过配置覆盖。model_name
,可以通过配置覆盖,默认使用 "gemini-1.0-pro"
。google.generativeai
库进行配置并初始化生成模型。如果没有提供 API 密钥,则使用 VertexAI 进行身份验证并初始化生成模型。消息处理方法:
system_message
、user_message
和 assistant_message
方法都是简单地返回传入的消息。这些方法可以在实际应用中进行扩展,以处理不同类型的消息。提交提示:
submit_prompt
方法使用生成模型生成内容。它接受一个提示(prompt)并生成相应的文本。生成的配置包括温度值。这段代码定义了一个 Hf
类,继承自 VannaBase
。
import re # 导入正则表达式模块 from transformers import AutoTokenizer, AutoModelForCausalLM # 从 transformers 库中导入自动分词器和因果语言模型 from ..base import VannaBase # 从父模块导入 VannaBase 类 class Hf(VannaBase): def __init__(self, config=None): # 从配置中获取模型名称,例如 "meta-llama/Meta-Llama-3-8B-Instruct" model_name = self.config.get("model_name", None) # 从预训练模型中加载分词器 self.tokenizer = AutoTokenizer.from_pretrained(model_name) # 从预训练模型中加载因果语言模型,设置数据类型和设备映射为自动 self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="auto", device_map="auto", ) def system_message(self, message: str) -> any: # 返回包含角色和内容的字典,表示系统消息 return {"role": "system", "content": message} def user_message(self, message: str) -> any: # 返回包含角色和内容的字典,表示用户消息 return {"role": "user", "content": message} def assistant_message(self, message: str) -> any: # 返回包含角色和内容的字典,表示助手消息 return {"role": "assistant", "content": message} def extract_sql_query(self, text): """ 提取第一个 SQL 语句,从 'select' 开始,不区分大小写, 匹配到第一个分号、三个反引号或字符串结尾, 并移除提取字符串中的三个反引号(如果存在)。 参数: - text (str): 要搜索 SQL 语句的字符串。 返回: - str: 找到的第一个 SQL 语句,移除三个反引号后的结果,如果没有匹配则返回空字符串。 """ # 正则表达式模式,用于查找 'select'(忽略大小写)并捕获到分号、三个反引号或字符串结尾 pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL) # 搜索匹配项 match = pattern.search(text) if match: # 如果存在匹配项,移除匹配字符串中的三个反引号 return match.group(0).replace("```", "") else: # 如果没有匹配项,返回输入字符串 return text def generate_sql(self, question: str, **kwargs) -> str: # 使用父类的 generate_sql 方法生成 SQL 语句 sql = super().generate_sql(question, **kwargs) # 替换字符串中的 "\_" 为 "_" sql = sql.replace("\\_", "_") # 替换字符串中的 "\" 为 "" sql = sql.replace("\\", "") # 提取并返回 SQL 查询 return self.extract_sql_query(sql) def submit_prompt(self, prompt, **kwargs) -> str: # 使用分词器对提示进行处理,添加生成提示并返回张量格式 input_ids = self.tokenizer.apply_chat_template( prompt, add_generation_prompt=True, return_tensors="pt" ).to(self.model.device) # 使用模型生成输出,设置最大新标记数、结束标记 ID、是否进行采样、温度和 top-p 参数 outputs = self.model.generate( input_ids, max_new_tokens=512, eos_token_id=self.tokenizer.eos_token_id, do_sample=True, temperature=1, top_p=0.9, ) # 提取生成的响应,跳过输入部分 response = outputs[0][input_ids.shape[-1] :] # 解码生成的响应,跳过特殊标记 response = self.tokenizer.decode(response, skip_special_tokens=True) # 记录响应 self.log(response) # 返回响应 return response
类的初始化:
torch_dtype
和 device_map
为 auto
,以便自动调整数据类型和设备。消息处理方法:
system_message
、user_message
和 assistant_message
方法返回包含角色和内容的字典,用于表示不同类型的消息。提取 SQL 查询:
extract_sql_query
方法使用正则表达式从输入文本中提取第一个 SQL 语句,匹配到分号、三个反引号或字符串结尾,并移除提取字符串中的三个反引号。生成 SQL 查询:
generate_sql
方法首先调用父类的 generate_sql
方法生成 SQL 语句,然后替换字符串中的特定字符,并使用 extract_sql_query
方法提取 SQL 查询。提交提示:
submit_prompt
方法使用分词器对提示进行处理,并使用模型生成响应。生成的响应经过解码和记录后返回。import uuid # 导入用于生成唯一标识符的模块 import marqo # 导入 Marqo 库,用于处理向量存储 import pandas as pd # 导入 pandas 库,用于数据处理 from ..base import VannaBase # 从父模块导入 VannaBase 类 class Marqo_VectorStore(VannaBase): def __init__(self, config=None): # 调用父类的构造函数进行初始化 VannaBase.__init__(self, config=config) # 检查配置中是否包含 marqo_url,否则使用默认值 if config is not None and "marqo_url" in config: marqo_url = config["marqo_url"] else: marqo_url = "http://localhost:8882" # 检查配置中是否包含 marqo_model,否则使用默认模型名称 if config is not None and "marqo_model" in config: marqo_model = config["marqo_model"] else: marqo_model = "hf/all_datasets_v4_MiniLM-L6" # 创建 Marqo 客户端 self.mq = marqo.Client(url=marqo_url) # 创建三个索引:vanna-sql、vanna-ddl、vanna-doc for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]: try: # 尝试创建索引 self.mq.create_index(index, model=marqo_model) except Exception as e: # 如果索引已经存在,则捕获异常并打印错误信息 print(e) print(f"Marqo index {index} already exists") pass def generate_embedding(self, data: str, **kwargs) -> list[float]: # Marqo 不需要生成嵌入 pass def add_question_sql(self, question: str, sql: str, **kwargs) -> str: # 生成唯一标识符并添加 "-sql" 后缀 id = str(uuid.uuid4()) + "-sql" # 创建包含问题和 SQL 的字典 question_sql_dict = { "question": question, "sql": sql, "_id": id, } # 将文档添加到 "vanna-sql" 索引中 self.mq.index("vanna-sql").add_documents( [question_sql_dict], tensor_fields=["question", "sql"], ) return id def add_ddl(self, ddl: str, **kwargs) -> str: # 生成唯一标识符并添加 "-ddl" 后缀 id = str(uuid.uuid4()) + "-ddl" # 创建包含 DDL 的字典 ddl_dict = { "ddl": ddl, "_id": id, } # 将文档添加到 "vanna-ddl" 索引中 self.mq.index("vanna-ddl").add_documents( [ddl_dict], tensor_fields=["ddl"], ) return id def add_documentation(self, documentation: str, **kwargs) -> str: # 生成唯一标识符并添加 "-doc" 后缀 id = str(uuid.uuid4()) + "-doc" # 创建包含文档的字典 doc_dict = { "doc": documentation, "_id": id, } # 将文档添加到 "vanna-doc" 索引中 self.mq.index("vanna-doc").add_documents( [doc_dict], tensor_fields=["doc"], ) return id def get_training_data(self, **kwargs) -> pd.DataFrame: data = [] # 初始化一个空列表用于存储数据 # 从 "vanna-doc" 索引中检索文档 for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]: data.append( { "id": hit["_id"], "training_data_type": "documentation", "question": "", "content": hit["doc"], } ) # 从 "vanna-ddl" 索引中检索文档 for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]: data.append( { "id": hit["_id"], "training_data_type": "ddl", "question": "", "content": hit["ddl"], } ) # 从 "vanna-sql" 索引中检索文档 for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]: data.append( { "id": hit["_id"], "training_data_type": "sql", "question": hit["question"], "content": hit["sql"], } ) # 将数据转换为 DataFrame 并返回 df = pd.DataFrame(data) return df def remove_training_data(self, id: str, **kwargs) -> bool: # 根据 ID 后缀确定要删除的索引中的文档 if id.endswith("-sql"): self.mq.index("vanna-sql").delete_documents(ids=[id]) return True elif id.endswith("-ddl"): self.mq.index("vanna-ddl").delete_documents(ids=[id]) return True elif id.endswith("-doc"): self.mq.index("vanna-doc").delete_documents(ids=[id]) return True else: return False @staticmethod def _extract_documents(data) -> list: # 检查数据中是否包含 'hits' 键且其值是否为列表 if "hits" in data and isinstance(data["hits"], list): # 如果 'hits' 列表为空,则返回空列表 if len(data["hits"]) == 0: return [] # 如果 'hits' 中包含 "doc" 键,则返回其值 if "doc" in data["hits"][0]: return [hit["doc"] for hit in data["hits"]] # 如果 'hits' 中包含 "ddl" 键,则返回其值 if "ddl" in data["hits"][0]: return [hit["ddl"] for hit in data["hits"]] # 否则,返回所有命中的项目 return [ {key: value for key, value in hit.items() if not key.startswith("_")} for hit in data["hits"] ] else: # 如果 'hits' 不存在或不是列表,则返回空列表 return [] def get_similar_question_sql(self, question: str, **kwargs) -> list: # 从 "vanna-sql" 索引中搜索相似的问题 SQL,并提取文档 return Marqo_VectorStore._extract_documents( self.mq.index("vanna-sql").search(question) ) def get_related_ddl(self, question: str, **kwargs) -> list: # 从 "vanna-ddl" 索引中搜索相关的 DDL,并提取文档 return Marqo_VectorStore._extract_documents( self.mq.index("vanna-ddl").search(question) ) def get_related_documentation(self, question: str, **kwargs) -> list: # 从 "vanna-doc" 索引中搜索相关的文档,并提取文档 return Marqo_VectorStore._extract_documents( self.mq.index("vanna-doc").search(question) )
类的初始化:
marqo_url
和 marqo_model
,如果未提供则使用默认值。vanna-sql
、vanna-ddl
、vanna-doc
。生成嵌入:
generate_embedding
方法目前没有实现,因为 Marqo 不需要生成嵌入。添加问题和 SQL:
add_question_sql
方法生成一个唯一标识符,并将问题和 SQL 添加到 vanna-sql
索引中。添加 DDL:
add_ddl
方法生成一个唯一标识符,并将 DDL 添加到 vanna-ddl
索引中。添加文档:
add_documentation
方法生成一个唯一标识符,并将文档添加到 vanna-doc
索引中。获取训练数据:
get_training_data
方法从三个索引中检索文档并转换为 pandas DataFrame。删除训练数据:
remove_training_data
方法根据文档 ID 后缀确定要删除的索引中的文档。静态方法提取文档:
_extract_documents
静态方法从搜索结果中提取文档。获取相似问题的 SQL:
get_similar_question_sql
方法从 vanna-sql
索引中搜索相似的问题 SQL,并提取文档。获取相关的 DDL:
get_related_ddl
方法从 vanna-ddl
索引中搜索相关的 DDL,并提取文档。get_related_documentation
方法从 vanna-doc
索引中搜索相关的文档,并提取文档。from mistralai.client import MistralClient # 从 Mistral 库导入 MistralClient from mistralai.models.chat_completion import ChatMessage # 从 Mistral 库导入 ChatMessage 模型 from ..base import VannaBase # 从父模块导入 VannaBase 类 class Mistral(VannaBase): def __init__(self, config=None): # 如果没有提供配置,抛出 ValueError 异常 if config is None: raise ValueError( "For Mistral, config must be provided with an api_key and model" ) # 如果配置中不包含 api_key,抛出 ValueError 异常 if "api_key" not in config: raise ValueError("config must contain a Mistral api_key") # 如果配置中不包含 model,抛出 ValueError 异常 if "model" not in config: raise ValueError("config must contain a Mistral model") # 从配置中获取 api_key 和 model api_key = config["api_key"] model = config["model"] # 创建 Mistral 客户端实例 self.client = MistralClient(api_key=api_key) # 设置模型名称 self.model = model def system_message(self, message: str) -> any: # 返回一个系统消息对象 return ChatMessage(role="system", content=message) def user_message(self, message: str) -> any: # 返回一个用户消息对象 return ChatMessage(role="user", content=message) def assistant_message(self, message: str) -> any: # 返回一个助手消息对象 return ChatMessage(role="assistant", content=message) def generate_sql(self, question: str, **kwargs) -> str: # 使用父类的方法生成 SQL 查询 sql = super().generate_sql(question, **kwargs) # 将 "\_" 替换为 "_" sql = sql.replace("\\_", "_") return sql def submit_prompt(self, prompt, **kwargs) -> str: # 使用 Mistral 客户端发送聊天请求 chat_response = self.client.chat( model=self.model, messages=prompt, ) # 返回聊天响应中的消息内容 return chat_response.choices[0].message.content
类的初始化:
api_key
和 model
,如果缺少任意一个则抛出 ValueError
异常。系统消息:
system_message
方法创建并返回一个系统消息对象。用户消息:
user_message
方法创建并返回一个用户消息对象。助手消息:
assistant_message
方法创建并返回一个助手消息对象。生成 SQL 查询:
generate_sql
方法调用父类的方法生成 SQL 查询,然后替换其中的 “_” 为 “_” 并返回最终的 SQL 查询。提交提示:
submit_prompt
方法使用 Mistral 客户端发送聊天请求,并返回聊天响应中的消息内容。这段代码定义了一个名为 MockEmbedding
的类,继承自 VannaBase
。
MockEmbedding
类包含一个构造函数和一个 generate_embedding
方法,后者返回一个固定的浮点数列表。from typing import List # 从 typing 模块导入 List 类型,用于类型注解
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
class MockEmbedding(VannaBase): # 定义一个名为 MockEmbedding 的类,继承自 VannaBase
def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config
pass # 目前构造函数不做任何操作
def generate_embedding(self, data: str, **kwargs) -> List[float]: # 定义一个名为 generate_embedding 的方法,接受一个字符串参数 data 和其他可选参数,返回一个浮点数列表
return [1.0, 2.0, 3.0, 4.0, 5.0] # 返回一个固定的浮点数列表
导入 List
类型:
typing
模块导入 List
类型,用于类型注解,指定 generate_embedding
方法返回值的类型。导入 VannaBase
类:
base
模块导入 VannaBase
类。VannaBase
类可能是所有具体实现的基类,提供一些基本的功能和接口。定义 MockEmbedding
类:
MockEmbedding
类继承自 VannaBase
,用于模拟嵌入生成的功能,通常在测试或开发阶段使用。构造函数 __init__
:
config
。目前构造函数不做任何实际操作,仅包含一个 pass
语句。generate_embedding
方法:
generate_embedding
的方法,接受一个字符串参数 data
和其他可选参数,返回一个浮点数列表。[1.0, 2.0, 3.0, 4.0, 5.0]
,模拟生成的嵌入向量。这段代码定义了一个名为 MockLLM
的类,继承自 VannaBase
。
MockLLM
类包含一个构造函数和多个消息创建方法,以及一个提交提示的方法,后者返回一个固定的字符串响应。from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类 class MockLLM(VannaBase): # 定义一个名为 MockLLM 的类,继承自 VannaBase def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config pass # 目前构造函数不做任何操作 def system_message(self, message: str) -> any: # 定义一个名为 system_message 的方法,接受一个字符串参数 message,返回一个任意类型的值 return {"role": "system", "content": message} # 返回一个包含角色和内容的字典,角色为 "system" def user_message(self, message: str) -> any: # 定义一个名为 user_message 的方法,接受一个字符串参数 message,返回一个任意类型的值 return {"role": "user", "content": message} # 返回一个包含角色和内容的字典,角色为 "user" def assistant_message(self, message: str) -> any: # 定义一个名为 assistant_message 的方法,接受一个字符串参数 message,返回一个任意类型的值 return {"role": "assistant", "content": message} # 返回一个包含角色和内容的字典,角色为 "assistant" def submit_prompt(self, prompt, **kwargs) -> str: # 定义一个名为 submit_prompt 的方法,接受一个参数 prompt 和其他可选参数,返回一个字符串 return "Mock LLM response" # 返回一个固定的字符串 "Mock LLM response"
导入 VannaBase
类:
base
模块导入 VannaBase
类。VannaBase
类可能是所有具体实现的基类,提供一些基本的功能和接口。定义 MockLLM
类:
MockLLM
类继承自 VannaBase
,用于模拟大语言模型(LLM)的功能,通常在测试或开发阶段使用。构造函数 __init__
:
config
。目前构造函数不做任何实际操作,仅包含一个 pass
语句。system_message
方法:
system_message
的方法,接受一个字符串参数 message
,返回一个包含角色和内容的字典,角色为 “system”。user_message
方法:
user_message
的方法,接受一个字符串参数 message
,返回一个包含角色和内容的字典,角色为 “user”。assistant_message
方法:
assistant_message
的方法,接受一个字符串参数 message
,返回一个包含角色和内容的字典,角色为 “assistant”。submit_prompt
方法:
submit_prompt
的方法,接受一个参数 prompt
和其他可选参数,返回一个字符串。MockVectorDB
类模拟了一个简单的向量数据库,实现了一些基本的数据库操作和数据处理功能。该类主要用于测试和开发阶段,提供一个简单的实现,而无需依赖实际的数据库操作。
pandas
DataFrame 返回训练数据,便于数据的处理和操作。import pandas as pd # 导入 pandas 库,用于处理数据 from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类 class MockVectorDB(VannaBase): # 定义一个名为 MockVectorDB 的类,继承自 VannaBase def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config pass # 目前构造函数不做任何操作 def _get_id(self, value: str, **kwargs) -> str: # 定义一个私有方法 _get_id,接受一个字符串参数 value 和其他可选参数 # 将值进行哈希处理并返回 ID return str(hash(value)) # 返回值的哈希值转换成字符串形式 def add_ddl(self, ddl: str, **kwargs) -> str: # 定义一个方法 add_ddl,接受一个字符串参数 ddl 和其他可选参数 return self._get_id(ddl) # 调用 _get_id 方法并返回其结果 def add_documentation(self, doc: str, **kwargs) -> str: # 定义一个方法 add_documentation,接受一个字符串参数 doc 和其他可选参数 return self._get_id(doc) # 调用 _get_id 方法并返回其结果 def add_question_sql(self, question: str, sql: str, **kwargs) -> str: # 定义一个方法 add_question_sql,接受一个字符串参数 question 和 sql 以及其他可选参数 return self._get_id(question) # 调用 _get_id 方法并返回其结果 def get_related_ddl(self, question: str, **kwargs) -> list: # 定义一个方法 get_related_ddl,接受一个字符串参数 question 和其他可选参数 return [] # 返回一个空列表 def get_related_documentation(self, question: str, **kwargs) -> list: # 定义一个方法 get_related_documentation,接受一个字符串参数 question 和其他可选参数 return [] # 返回一个空列表 def get_similar_question_sql(self, question: str, **kwargs) -> list: # 定义一个方法 get_similar_question_sql,接受一个字符串参数 question 和其他可选参数 return [] # 返回一个空列表 def get_training_data(self, **kwargs) -> pd.DataFrame: # 定义一个方法 get_training_data,接受其他可选参数,返回一个 pandas DataFrame # 返回一个包含训练数据的 DataFrame return pd.DataFrame({'id': {0: '19546-ddl', # 训练数据 ID 1: '91597-sql', 2: '133976-sql', 3: '59851-doc', 4: '73046-sql'}, 'training_data_type': {0: 'ddl', # 训练数据类型 1: 'sql', 2: 'sql', 3: 'documentation', 4: 'sql'}, 'question': {0: None, # 问题 1: 'What are the top selling genres?', 2: 'What are the low 7 artists by sales?', 3: None, 4: 'What is the total sales for each customer?'}, 'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)', # 内容 1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;', 2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;', 3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.', 4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}}) def remove_training_data(id: str, **kwargs) -> bool: # 定义一个方法 remove_training_data,接受一个字符串参数 id 和其他可选参数 return True # 返回 True,表示删除成功
导入模块:
pandas
:用于数据处理和操作。VannaBase
:从上一级目录的 base
模块导入 VannaBase
类,作为基类。定义 MockVectorDB
类:
VannaBase
,模拟一个向量数据库的基本操作。构造函数 __init__
:
pass
语句。_get_id
方法:
_get_id
,接受一个字符串参数 value
,返回该字符串的哈希值作为 ID。add_ddl
方法:
_get_id
方法返回其哈希值作为 ID。add_documentation
方法:
_get_id
方法返回其哈希值作为 ID。add_question_sql
方法:
_get_id
方法返回问题的哈希值作为 ID。get_related_ddl
方法:
get_related_documentation
方法:
get_similar_question_sql
方法:
get_training_data
方法:
remove_training_data
方法:
True
,表示成功删除训练数据。OpenAI_Chat
类实现了与 OpenAI 模型的基本接口,提供了初始化、消息处理和提示提交等功能。该类用于与 OpenAI 模型进行通信,发送提示并接收响应,同时处理和返回响应中的文本内容。
OpenAI_Chat
类实现了与 OpenAI 模型的接口,通过 HTTP 请求与 OpenAI 模型进行通信,发送提示并接收响应。import os # 导入 os 模块,用于与操作系统交互 from openai import OpenAI # 从 openai 模块导入 OpenAI 类 from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类 class OpenAI_Chat(VannaBase): # 定义一个名为 OpenAI_Chat 的类,继承自 VannaBase def __init__(self, client=None, config=None): # 定义类的构造函数,接受可选参数 client 和 config VannaBase.__init__(self, config=config) # 调用父类 VannaBase 的构造函数 # 设置默认参数 - 可以通过 config 覆盖 self.temperature = 0.7 self.max_tokens = 500 if "temperature" in config: self.temperature = config["temperature"] # 如果 config 中包含 "temperature",则覆盖默认温度值 if "max_tokens" in config: self.max_tokens = config["max_tokens"] # 如果 config 中包含 "max_tokens",则覆盖默认最大 token 数 if "api_type" in config: raise Exception( "Passing api_type is now deprecated. Please pass an OpenAI client instead." ) # 如果 config 中包含 "api_type",抛出异常 if "api_base" in config: raise Exception( "Passing api_base is now deprecated. Please pass an OpenAI client instead." ) # 如果 config 中包含 "api_base",抛出异常 if "api_version" in config: raise Exception( "Passing api_version is now deprecated. Please pass an OpenAI client instead." ) # 如果 config 中包含 "api_version",抛出异常 if client is not None: self.client = client # 如果提供了 client 参数,则将其设置为实例属性 return if config is None and client is None: self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) # 如果没有提供 config 和 client,从环境变量中获取 API 密钥并创建 OpenAI 客户端 return if "api_key" in config: self.client = OpenAI(api_key=config["api_key"]) # 如果 config 中包含 "api_key",使用该密钥创建 OpenAI 客户端 def system_message(self, message: str) -> any: # 定义一个方法 system_message,接受一个字符串参数 message return {"role": "system", "content": message} # 返回一个包含角色和内容的字典 def user_message(self, message: str) -> any: # 定义一个方法 user_message,接受一个字符串参数 message return {"role": "user", "content": message} # 返回一个包含角色和内容的字典 def assistant_message(self, message: str) -> any: # 定义一个方法 assistant_message,接受一个字符串参数 message return {"role": "assistant", "content": message} # 返回一个包含角色和内容的字典 def submit_prompt(self, prompt, **kwargs) -> str: # 定义一个方法 submit_prompt,接受一个参数 prompt 和其他可选参数 if prompt is None: raise Exception("Prompt is None") # 如果 prompt 为 None,抛出异常 if len(prompt) == 0: raise Exception("Prompt is empty") # 如果 prompt 为空,抛出异常 # 计算消息日志中的 token 数量 # 使用 4 作为每个 token 的近似字符数 num_tokens = 0 for message in prompt: num_tokens += len(message["content"]) / 4 if kwargs.get("model", None) is not None: model = kwargs.get("model", None) # 如果 kwargs 中提供了 model 参数,则使用该模型 print( f"Using model {model} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( model=model, messages=prompt, max_tokens=self.max_tokens, stop=None, temperature=self.temperature, ) elif kwargs.get("engine", None) is not None: engine = kwargs.get("engine", None) # 如果 kwargs 中提供了 engine 参数,则使用该引擎 print( f"Using model {engine} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( engine=engine, messages=prompt, max_tokens=self.max_tokens, stop=None, temperature=self.temperature, ) elif self.config is not None and "engine" in self.config: print( f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( engine=self.config["engine"], messages=prompt, max_tokens=self.max_tokens, stop=None, temperature=self.temperature, ) elif self.config is not None and "model" in self.config: print( f"Using model {self.config['model']} for {num_tokens} tokens (approx)" ) response = self.client.chat.completions.create( model=self.config["model"], messages=prompt, max_tokens=self.max_tokens, stop=None, temperature=self.temperature, ) else: if num_tokens > 3500: model = "gpt-3.5-turbo-16k" # 如果 token 数量超过 3500,使用 gpt-3.5-turbo-16k 模型 else: model = "gpt-3.5-turbo" # 否则使用 gpt-3.5-turbo 模型 print(f"Using model {model} for {num_tokens} tokens (approx)") response = self.client.chat.completions.create( model=model, messages=prompt, max_tokens=self.max_tokens, stop=None, temperature=self.temperature, ) # 查找包含文本的第一个响应(有些响应可能没有文本) for choice in response.choices: if "text" in choice: return choice.text # 如果没有找到包含文本的响应,返回第一个响应的内容(可能为空) return response.choices[0].message.content
导入模块:
os
:用于与操作系统交互,特别是获取环境变量。OpenAI
:用于与 OpenAI API 进行交互。VannaBase
和 DependencyError
:用于继承基类和处理依赖错误。定义 OpenAI_Chat
类:
VannaBase
,实现了与 OpenAI 模型的接口。构造函数 __init__
:
消息方法:
system_message
、user_message
和 assistant_message
方法分别返回带有角色和内容的字典,用于与 OpenAI 模型的通信。submit_prompt
方法:
这个类的主要功能是初始化 OpenAI 客户端并根据提供的配置生成嵌入向量。
from openai import OpenAI # 从 openai 库导入 OpenAI 类,用于与 OpenAI API 交互 from ..base import VannaBase # 从本地模块导入 VannaBase 类,这是一个基础类 # 定义一个新的类 OpenAI_Embeddings,它继承自 VannaBase class OpenAI_Embeddings(VannaBase): def __init__(self, client=None, config=None): # 调用父类 VannaBase 的构造函数,并传递配置参数 VannaBase.__init__(self, config=config) # 如果提供了 client 参数,则使用这个参数初始化 self.client if client is not None: self.client = client return # 结束构造函数的执行 # 如果 self.client 已经存在(通过父类初始化),则结束构造函数 if self.client is not None: return # 如果没有提供 client 参数且 self.client 尚未初始化,则创建一个新的 OpenAI 客户端 self.client = OpenAI() # 如果没有提供配置参数,则结束构造函数 if config is None: return # 根据配置参数设置 OpenAI 客户端的不同属性 if "api_type" in config: self.client.api_type = config["api_type"] # 设置 API 类型 if "api_base" in config: self.client.api_base = config["api_base"] # 设置 API 基础 URL if "api_version" in config: self.client.api_version = config["api_version"] # 设置 API 版本 if "api_key" in config: self.client.api_key = config["api_key"] # 设置 API 密钥 # 定义生成嵌入的方法,接受输入数据 data 和其他可选参数 kwargs def generate_embedding(self, data: str, **kwargs) -> list[float]: # 如果配置参数中包含引擎配置,则使用该引擎生成嵌入 if self.config is not None and "engine" in self.config: embedding = self.client.embeddings.create( engine=self.config["engine"], # 使用指定的引擎 input=data, # 输入数据 ) else: # 否则使用默认的模型生成嵌入 embedding = self.client.embeddings.create( model="text-embedding-ada-002", # 使用默认模型 input=data, # 输入数据 ) # 返回生成的嵌入向量 return embedding.get("data")[0]["embedding"]
OpenSearch_VectorStore 类用于与 OpenSearch 进行交互,提供存储和检索文档、DDL(数据定义语言)和问题-SQL 对的方法。
import base64 # 导入 base64 模块,用于处理 base64 编码 import uuid # 导入 uuid 模块,用于生成唯一标识符 from typing import List # 导入 List 类型提示 import pandas as pd # 导入 pandas 模块,用于数据处理 from opensearchpy import OpenSearch # 导入 OpenSearch 模块,用于与 OpenSearch 服务交互 from ..base import VannaBase # 从本地模块导入 VannaBase 类,这是一个基础类 # 定义 OpenSearch_VectorStore 类,继承自 VannaBase class OpenSearch_VectorStore(VannaBase): def __init__(self, config=None): # 调用父类 VannaBase 的构造函数,并传递配置参数 VannaBase.__init__(self, config=config) # 初始化索引名称 document_index = "vanna_document_index" ddl_index = "vanna_ddl_index" question_sql_index = "vanna_questions_sql_index" # 从配置中获取自定义索引名称 if config is not None and "es_document_index" in config: document_index = config["es_document_index"] if config is not None and "es_ddl_index" in config: ddl_index = config["es_ddl_index"] if config is not None and "es_question_sql_index" in config: question_sql_index = config["es_question_sql_index"] # 将索引名称保存为类的属性 self.document_index = document_index self.ddl_index = ddl_index self.question_sql_index = question_sql_index print("OpenSearch_VectorStore initialized with document_index: ", document_index, " ddl_index: ", ddl_index, " question_sql_index: ", question_sql_index) # 定义默认索引设置 document_index_settings = { "settings": { "index": { "number_of_shards": 6, "number_of_replicas": 2 } }, "mappings": { "properties": { "question": {"type": "text"}, "doc": {"type": "text"} } } } ddl_index_settings = { "settings": { "index": { "number_of_shards": 6, "number_of_replicas": 2 } }, "mappings": { "properties": { "ddl": {"type": "text"}, "doc": {"type": "text"} } } } question_sql_index_settings = { "settings": { "index": { "number_of_shards": 6, "number_of_replicas": 2 } }, "mappings": { "properties": { "question": {"type": "text"}, "sql": {"type": "text"} } } } # 从配置中获取自定义索引设置 if config is not None and "es_document_index_settings" in config: document_index_settings = config["es_document_index_settings"] if config is not None and "es_ddl_index_settings" in config: ddl_index_settings = config["es_ddl_index_settings"] if config is not None and "es_question_sql_index_settings" in config: question_sql_index_settings = config["es_question_sql_index_settings"] # 将索引设置保存为类的属性 self.document_index_settings = document_index_settings self.ddl_index_settings = ddl_index_settings self.question_sql_index_settings = question_sql_index_settings # 初始化 OpenSearch 客户端 es_urls = None if config is not None and "es_urls" in config: es_urls = config["es_urls"] # 获取主机和端口配置 host = config["es_host"] if config and "es_host" in config else "localhost" port = config["es_port"] if config and "es_port" in config else 9200 ssl = config["es_ssl"] if config and "es_ssl" in config else False verify_certs = config["es_verify_certs"] if config and "es_verify_certs" in config else False # 获取认证配置 auth = (config["es_user"], config["es_password"]) if config and "es_user" in config else None # 基于 base64 的认证 headers = None if config and "es_encoded_base64" in config and "es_user" in config and "es_password" in config: if config["es_encoded_base64"]: encoded_credentials = base64.b64encode( (config["es_user"] + ":" + config["es_password"]).encode("utf-8") ).decode("utf-8") headers = {'Authorization': 'Basic ' + encoded_credentials} auth = None # 自定义 headers if config and "es_headers" in config: headers = config["es_headers"] # 获取超时和重试配置 timeout = config["es_timeout"] if config and "es_timeout" in config else 60 max_retries = config["es_max_retries"] if config and "es_max_retries" in config else 10 es_http_compress = config["es_http_compress"] if config and "es_http_compress" in config else False print("OpenSearch_VectorStore initialized with es_urls: ", es_urls, " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ", verify_certs, " timeout: ", timeout, " max_retries: ", max_retries) # 初始化 OpenSearch 客户端 if es_urls is not None: self.client = OpenSearch( hosts=[es_urls], http_compress=es_http_compress, use_ssl=ssl, verify_certs=verify_certs, timeout=timeout, max_retries=max_retries, retry_on_timeout=True, http_auth=auth, headers=headers ) else: self.client = OpenSearch( hosts=[{'host': host, 'port': port}], http_compress=es_http_compress, use_ssl=ssl, verify_certs=verify_certs, timeout=timeout, max_retries=max_retries, retry_on_timeout=True, http_auth=auth, headers=headers ) print("OpenSearch_VectorStore initialized with client over ") # 执行一个简单的查询来检查连接 try: print('Connected to OpenSearch cluster:') info = self.client.info() print('OpenSearch cluster info:', info) except Exception as e: print('Error connecting to OpenSearch cluster:', e) # 如果索引不存在,则创建索引 self.create_index_if_not_exists(self.document_index, self.document_index_settings) self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings) self.create_index_if_not_exists(self.question_sql_index, self.question_sql_index_settings) # 创建索引的方法 def create_index(self): for index in [self.document_index, self.ddl_index, self.question_sql_index]: try: self.client.indices.create(index) except Exception as e: print("Error creating index: ", e) print(f"opensearch index {index} already exists") pass # 如果索引不存在,则创建索引的方法 def create_index_if_not_exists(self, index_name: str, index_settings: dict) -> bool: try: if not self.client.indices.exists(index_name): print(f"Index {index_name} does not exist. Creating...") self.client.indices.create(index=index_name, body=index_settings) return True else: print(f"Index {index_name} already exists.") return False except Exception as e: print(f"Error creating index: {index_name} ", e) return False # 添加 DDL 文档的方法 def add_ddl(self, ddl: str, **kwargs) -> str: id = str(uuid.uuid4()) + "-ddl" # 生成唯一标识符 ddl_dict = {"ddl": ddl} response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id, **kwargs) return response['_id'] # 添加文档的方法 def add_documentation(self, doc: str, **kwargs) -> str: id = str(uuid.uuid4()) + "-doc" # 生成唯一标识符 doc_dict = {"doc": doc} response = self.client.index(index=self.document_index, id=id, body=doc_dict, **kwargs) return response['_id'] # 添加问题和 SQL 的方法 def add_question_sql(self, question: str, sql: str, **kwargs) -> str: id = str(uuid.uuid4()) + "-sql" # 生成唯一标识符 question_sql_dict = {"question": question, "sql": sql} response = self.client.index(index=self.question_sql_index, body=question_sql_dict, id=id, **kwargs) return response['_id'] # 获取相关 DDL 文档的方法 def get_related_ddl(self, question: str, **kwargs) -> List[str]: query = {"query": {"match": {"ddl": question}}} print(query) response = self.client.search(index=self.ddl_index, body=query, **kwargs) return [hit['_source']['ddl'] for hit in response['hits']['hits']] # 获取相关文档的方法 def get_related_documentation(self, question: str, **kwargs) -> List[str]: query = {"query": {"match": {"doc": question}}} print(query) response = self.client.search(index=self.document_index, body=query, **kwargs) return [hit['_source']['doc'] for hit in response['hits']['hits']] # 获取相似问题和 SQL 的方法 def get_similar_question_sql(self, question: str, **kwargs) -> List[str]: query = {"query": {"match": {"question": question}}} print(query) response = self.client.search(index=self.question_sql_index, body=query, **kwargs) return [(hit['_source']['question'], hit['_source']['sql']) for hit in response['hits']['hits']] # 获取训练数据的方法 def get_training_data(self, **kwargs) -> pd.DataFrame: data = [] # 从文档索引中获取数据 response = self.client.search(index=self.document_index, body={"query": {"match_all": {}}}, size=1000) for hit in response['hits']['hits']: data.append({"id": hit["_id"], "training_data_type": "documentation", "question": "", "content": hit["_source"]['doc']}) # 从问题和 SQL 索引中获取数据 response = self.client.search(index=self.question_sql_index, body={"query": {"match_all": {}}}, size=1000) for hit in response['hits']['hits']: data.append({"id": hit["_id"], "training_data_type": "sql", "question": hit.get("_source", {}).get("question", ""), "content": hit.get("_source", {}).get("sql", "")}) # 从 DDL 索引中获取数据 response = self.client.search(index=self.ddl_index, body={"query": {"match_all": {}}}, size=1000) for hit in response['hits']['hits']: data.append({"id": hit["_id"], "training_data_type": "ddl", "question": "", "content": hit["_source"]['ddl']}) # 返回包含所有数据的 pandas DataFrame return pd.DataFrame(data) # 删除训练数据的方法 def remove_training_data(self, id: str, **kwargs) -> bool: try: if id.endswith("-sql"): self.client.delete(index=self.question_sql_index, id=id) return True elif id.endswith("-ddl"): self.client.delete(index=self.ddl_index, id=id, **kwargs) return True elif id.endswith("-doc"): self.client.delete(index=self.document_index, id=id, **kwargs) return True else: return False except Exception as e: print("Error deleting training data: ", e) return False # 生成嵌入的方法(空方法) def generate_embedding(self, data: str, **kwargs) -> list[float]: pass # OpenSearch 不需要生成嵌入 # 示例初始化调用 # OpenSearch_VectorStore.__init__(self, config={'es_urls': "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True}) # OpenSearch_VectorStore.__init__(self, config={'es_host': "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True})
1.初始化:
配置索引名称和设置。
初始化 OpenSearch 客户端,支持多种配置选项(如认证、超时、重试等)。
创建必要的索引。
2.索引管理:
create_index_if_not_exists:检查索引是否存在,不存在则创建。
create_index:创建指定的索引。
3.文档操作:
add_ddl、add_documentation、add_question_sql:添加 DDL 文档、普通文档和问题-SQL 对。
get_related_ddl、get_related_documentation、get_similar_question_sql:检索相关的 DDL 文档、普通文档和问题-SQL 对。
get_training_data:获取所有训练数据,返回 pandas DataFrame。
remove_training_data:删除指定的训练数据。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。