当前位置:   article > 正文

基于langchainsql和chatglm实现自然语言查询mysql数据库_chatglm sql

chatglm sql

首先发布一个chatglm服务,具体如下:

import os
import json

from flask import Flask
from flask import request
from transformers import AutoTokenizer, AutoModel

# system params
# Expose only GPU 0 to CUDA; set here, before the model is moved to the GPU below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Local checkpoint directory, relative to the working directory. Built with
# os.path.join so the separator is right on every OS (on Windows this is
# still ".\chatglm2-6b-int4", exactly the path the original hard-coded).
MODEL_DIR = os.path.join(".", "chatglm2-6b-int4")

# trust_remote_code=True lets transformers run the Python code shipped with
# the checkpoint; only point MODEL_DIR at a checkpoint you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_DIR, trust_remote_code=True).half().cuda()
# Switch the model to evaluation (inference) mode.
model.eval()

app = Flask(__name__)

@app.route("/chat", methods=["POST"])
def chat():
    """Answer a single prompt with the ChatGLM model.

    Expects a POST body holding a JSON object with a string ``human_input``
    field. The raw body is parsed directly, so the request's Content-Type
    header is not consulted.

    Returns:
        A JSON response ``{"response": <model reply>}`` with status 200, or
        ``{"error": <message>}`` with status 400 when the body is not a JSON
        object carrying a string ``human_input``.
    """
    def _json_response(payload, status=200):
        # ensure_ascii=False keeps non-ASCII text (e.g. Chinese replies)
        # as-is instead of \uXXXX-escaping it.
        body = json.dumps(payload, ensure_ascii=False)
        return app.response_class(body, status=status,
                                  mimetype="application/json")

    try:
        data_dict = json.loads(request.get_data())
    except ValueError:
        # Covers both malformed JSON and bodies that are not valid UTF-8
        # (JSONDecodeError and UnicodeDecodeError are ValueError subclasses).
        return _json_response({"error": "request body is not valid JSON"}, 400)

    human_input = data_dict.get("human_input") if isinstance(data_dict, dict) else None
    if not isinstance(human_input, str):
        return _json_response(
            {"error": "JSON body must be an object with a string 'human_input'"},
            400,
        )

    # history=[] means every request is answered without prior conversation turns.
    response, _ = model.chat(tokenizer, human_input, history=[])

    return _json_response({"response": response})

if __name__ == "__main__":
    # Start the Flask server on all network interfaces at port 8595, with
    # debug mode disabled.
    app.run(host="0.0.0.0", port=8595, debug=False)

然后就可以基于langchain进行查询,具体如下:

# Placeholder OpenAI API key; nothing in the visible code reads this variable.
openai_api_key = "xxxx"
import os
import openai

# !pip install langchain langchain-experimental openai -q

from langchain import OpenAI, SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
import time
import logging
import requests
from typing import Optional, List, Dict, Mapping, Any

import langchain
from langchain.llms.base import LLM
from langchain.cache import InMemoryCache

# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
# Enable the LLM cache, backed by an in-memory store.
langchain.llm_cache = InMemoryCache()


class ChatGLM(LLM):
    """LangChain LLM wrapper around a ChatGLM HTTP service.

    Each prompt is POSTed as ``{"human_input": prompt}`` to ``url`` and the
    ``response`` field of the JSON reply is returned as the completion.
    """

    # Endpoint of the model service. Annotated so it is declared as a proper
    # model field (LLM is a pydantic model).
    url: str = "http://127.0.0.1:8595/chat"

    @property
    def _llm_type(self) -> str:
        """Return the type tag of this LLM."""
        return "chatglm"

    def _construct_query(self, prompt: str) -> Dict:
        """Build the JSON request body for *prompt*."""
        query = {
            "human_input": prompt
        }
        return query

    @classmethod
    def _post(cls, url: str,
              query: Dict) -> Any:
        """POST *query* as JSON to *url* and return the HTTP response.

        Network failures and the 60-second timeout surface as the
        exceptions raised by ``requests``.
        """
        # The header name is "Content-Type" (hyphen); the former
        # "Content_Type" spelling was not a recognised header.
        _headers = {"Content-Type": "application/json"}
        with requests.Session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=_headers,
                             timeout=60)
        return resp

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[Any] = None,
              **kwargs: Any) -> str:
        """Send *prompt* to the model service and return its completion.

        Args:
            prompt: Text forwarded to the service.
            stop: Optional stop sequences; the completion is cut just before
                the earliest occurrence of any of them.
            run_manager: Accepted for compatibility with callers that pass a
                callback manager; not used.
            **kwargs: Extra call-time options; accepted and ignored.

        Returns:
            The ``response`` field of the service's JSON reply, truncated at
            the first stop sequence if one occurs.

        Raises:
            requests.HTTPError: If the service answers with a status other
                than 200.
        """
        # construct query
        query = self._construct_query(prompt=prompt)

        # post
        resp = self._post(url=self.url,
                          query=query)

        if resp.status_code != 200:
            # Fail loudly: a placeholder string returned from here would be
            # treated by the caller as genuine model output.
            raise requests.HTTPError(
                f"ChatGLM service at {self.url} returned HTTP {resp.status_code}",
                response=resp,
            )

        predictions = resp.json()["response"]

        if stop:
            # The service has no stop-sequence support, so enforce it here.
            # Empty stop strings are skipped: they would match at index 0.
            cut_points = [
                pos
                for pos in (predictions.find(token) for token in stop if token)
                if pos != -1
            ]
            if cut_points:
                predictions = predictions[:min(cut_points)]

        return predictions

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.
        """
        _param_dict = {
            "url": self.url
        }
        return _param_dict


# Alternative backend kept from the article:
# llm = OpenAI(temperature=0, openai_api_key="")
if __name__ == "__main__":
    llm = ChatGLM()
    # sqlite_db_path ='./chinook.db'
    # Template URI from the article: replace every segment (user, password,
    # host, port, database name, charset) with real values before running.
    db = SQLDatabase.from_uri("mysql://用户名:密码@ip:端口号/数据库名?charset=数据库编码")
    # Build the chain with from_llm; passing llm= straight to the
    # SQLDatabaseChain constructor is deprecated.
    db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
    # 用户问题 ("user question") is the article's placeholder: bind it to the
    # natural-language question string before running this script.
    db_chain.run(用户问题)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/837396
推荐阅读
相关标签
  

闽ICP备14008679号