赞
踩
在本文中,我们将探讨如何使用 SQLCoder-7B(我们将在 Amazon SageMaker 上部署的大型语言模型 (LLM))和 LangChain 来执行自然语言查询 (NLQ)。
我们将了解如何使用 LangChain 创建一个管道,提示 LLM 生成 SQL 查询,从 PostgreSQL 数据库检索数据,并将结果作为上下文传递给 LLM 以获得最终响应。
SQLCoder 是大型语言模型 (LLM) 的集合,用于从自然语言高效生成 SQL 查询。
我们将使用 SQLCoder-7B,它基于 Mistral-7B 并针对 SQL 查询生成进行了微调。
根据其创建者的说法"SQLCoder-7B 在自然语言到 SQL 任务中的表现优于 GPT-3.5 Turbo 和其他流行的开源模型。此外,在对特定数据库模式进行微调时,它甚至超过了 GPT-4"。
首先,让我们使用 Amazon RDS 配置 PostgreSQL 数据库。
我们将使用下面的 Terraform 代码片段来完成这项任务:
- # ------------------------------------------------------------------------------
- # RDS Security group
- # ------------------------------------------------------------------------------
- resource "aws_security_group" "db_sg" {
- name_prefix = local.db_security_group_name_prefix
- vpc_id = local.vpc_id
-
- ingress {
- from_port = local.db_port
- to_port = local.db_port
- protocol = "tcp"
- cidr_blocks = [local.my_ip_address]
- }
-
- egress {
- from_port = 0
- to_port = 0
- protocol = "-1"
- cidr_blocks = ["0.0.0.0/0"]
- }
- }
-
- # ------------------------------------------------------------------------------
- # RDS
- # ------------------------------------------------------------------------------
- module "db" {
- source = "terraform-aws-modules/rds/aws"
- identifier = local.db_identifier
-
- engine = "postgres"
- engine_version = "15.4"
- family = "postgres15"
-
- instance_class = local.db_instance_class
- allocated_storage = local.db_allocated_storage
-
- db_name = local.db_name
- username = local.db_username
- port = local.db_port
-
- create_db_subnet_group = true
- vpc_security_group_ids = [aws_security_group.db_sg.id]
- subnet_ids = local.db_subnet_ids
-
- }

现在,让我们用电子商务数据集中的数据填充新创建的数据库,您可以从 Kaggle 下载 CSV 文件。
下载数据集后,使用以下命令连接到 RDS 数据库:
psql -h <RDS_ENDPOINT> -p <RDS_PORT> -U <DATABASE_USERNAME> -d <DATABASE_NAME> -W
根据提示输入密码。连接后,使用以下 SQL 命令创建销售表:
- CREATE TABLE sales (
- invoiceno VARCHAR(255),
- stockcode VARCHAR(255),
- description VARCHAR(255),
- quantity INT,
- invoicedate TIMESTAMP,
- unitprice DECIMAL(10, 2),
- customerid INT,
- country VARCHAR(50)
- );
接下来,使用以下命令将 CSV 文件中的数据复制到销售表中:
\COPY sales(invoiceno, stockcode, description, quantity, invoicedate, unitprice, customerid, country) FROM '/path/to/data.csv' DELIMITER ',' CSV HEADER;
最后,运行一个简单的计数查询,验证数据是否已成功加载:
SELECT COUNT(*) FROM sales;
如果你已经按照我之前的文章部署了自己的私人 LLM 聊天机器人,只需更新代码,按如下方式部署模型 defog/sqlcoder-7b:
- locals {
- hugging_face_model_id = "defog/sqlcoder-7b"
- }
现在,您的数据库已经建立并填充了数据,模型也已部署,我们将创建一个简单的脚本,使用数据库连接和 Amazon SageMaker Endpoint 创建 SQLDatabaseChain。
如果您对检索增强生成(RAG)的概念不了解,请参阅我之前的文章《使用亚马逊 Bedrock 和 LangChain 创建上下文感知的 LLM 聊天机器人》。
使用下面的 Python 脚本可以对数据库中存储的数据执行自然语言查询:
- import boto3
-
- import json
- from langchain.sql_database import SQLDatabase
- from langchain_experimental.sql import SQLDatabaseChain
- from langchain.llms.sagemaker_endpoint import SagemakerEndpoint, LLMContentHandler
- from typing import Dict
- from sqlalchemy.exc import ProgrammingError
-
- # RDS configuration
- RDS_DB_NAME = "<RDS_DB_NAME>"
- RDS_ENDPOINT = "<RDS_ENDPOINT>"
- RDS_USERNAME = "<RDS_USERNAME>"
- RDS_PASSWORD = "<RDS_PASSWORD>"
- RDS_PORT = "<RDS_PORT>"
- RDS_URI = f"postgresql+psycopg2://{RDS_USERNAME}:{RDS_PASSWORD}@{RDS_ENDPOINT}:{RDS_PORT}/{RDS_DB_NAME}"
- db = SQLDatabase.from_uri(
- RDS_URI,
- include_tables=["sales"],
- sample_rows_in_table_info=2,
- )
-
- # Sagemaker configuration
- SAGEMAKER_ENDPOINT_NAME = "<SAGEMAKER_ENDPOINT_NAME>"
- MAX_TOKENS = 1024
-
- class ContentHandler(LLMContentHandler):
- content_type = "application/json"
- accepts = "application/json"
-
- def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
- input_str = json.dumps({"inputs": prompt.strip(), "parameters": model_kwargs})
- return input_str.encode("utf-8")
-
- def transform_output(self, output: bytes) -> str:
- response_json = json.loads(output.read().decode("utf-8"))
- response = response_json[0]["generated_text"].strip().split("\n")[0]
- return response
-
- content_handler = ContentHandler()
-
- sagemaker_client = boto3.client("runtime.sagemaker")
- llm = SagemakerEndpoint(
- client=sagemaker_client,
- endpoint_name=SAGEMAKER_ENDPOINT_NAME,
- model_kwargs={
- "max_new_tokens": MAX_TOKENS,
- "return_full_text": False,
- },
- content_handler=content_handler,
- )
-
- # Chain
- db_chain = SQLDatabaseChain.from_llm(
- llm,
- db,
- verbose=True,
- )
-
- while True:
- user_input = input("Enter a message (or 'exit' to quit): ")
-
- if user_input.lower() == "exit":
- break
-
- try:
- results = db_chain.run(user_input)
- print(results)
- except (ProgrammingError, ValueError) as exc:
- print(f"\n\n{exc}")

运行此脚本时,它会提示用户输入信息,然后这些信息将通过 SQLDatabaseChain 传递。
此脚本仅供演示之用,可进一步定制。
在与 LLM 交互时,LangChain 可灵活自定义提示,以获得更好的效果。
定制脚本的另一种方法是向 LLM 提供详细的表定义。
这样做可以提供有关被查询表结构的额外上下文,从而帮助 LLM 生成更准确、更相关的 SQL 查询。
让我们看看引擎盖下发生了什么。
我运行了脚本,并输入了问题 "最畅销的产品是什么?
在这里,我们可以看到 LangChain 生成了一个提示,其中包含表模式和表中的 2 行。
然后,该提示请求亚马逊 SageMaker 端点根据给定上下文创建 SQL 查询。
结果,模型返回了生成的 SQL 查询,如下图所示:
LangChain 在数据库中执行了查询,得到了以下结果:
有了这些结果,LangChain 再次请求 LLM 模型,在提示符中填写 SQLResult 并请求回答。
模型给出了最终答案,如下图所示:
从这一成功结果来看,脚本似乎能够处理自然语言查询。
要进一步测试脚本,可以考虑尝试其他复杂查询。
必须考虑使用只读用户连接数据库,因为 LLM 有可能生成插入、删除和更改等数据处理语言(DML)查询。
这种集成为各种应用打开了大门,从数据分析到能够回答复杂数据库相关查询的聊天机器人。如果您在集成过程中遇到任何难题,或有本指南未涵盖的特殊要求,请随时联系我。我将竭诚为您服务!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。