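The sample data for this step can include the database schema and/or a few rows of data. It's important to convert all of it into a single character string, because it will become part of the larger text prompt sent to GPT-3.5.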
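The code below imports the data file into R, uses sqldf to see what the SQL schema would look like if the data frame were a SQL database table, uses dplyr's filter() to drop rows without a Region and sample_n() to pull three random sample rows, and converts both the schema and the sample rows to strings. Disclaimer: ChatGPT wrote the base R apply() portion of the code that turns the data into single strings (I usually use purrr for tasks like these).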
````r
library(rio)
library(dplyr)
library(sqldf)
library(glue)

# Import the CSV and drop rows that have no Region
states <- rio::import("https://raw.githubusercontent.com/smach/SampleData/main/states.csv") |>
  filter(!is.na(Region))

# Ask SQLite (via sqldf) for the table schema, then collapse it into one
# string: tab-separated fields, one line per database column
states_schema <- sqldf("PRAGMA table_info(states)")
states_schema_string <- paste(apply(states_schema, 1, paste, collapse = "\t"), collapse = "\n")

# Pick three random rows and collapse them the same way
states_sample <- dplyr::sample_n(states, 3)
states_sample_string <- paste(apply(states_sample, 1, paste, collapse = "\t"), collapse = "\n")

# Build the prompt text that will be sent to GPT
create_prompt <- function(schema, rows_sample, query, table_name) {
  glue::glue("Act as if you're a data scientist. You have a SQLite table named {table_name} with the following schema:
```
{schema}
```
The first rows look like this:
```{rows_sample}```
Based on this data, write a SQL query to answer the following question: {query}. Return the SQL query ONLY. Do not include any additional explanation.")
}
````
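The apply() calls above work, but apply() first converts the data frame into a character matrix, and as.matrix() formats numeric columns to a common width, so shorter numbers can pick up leading spaces. Below is a base R sketch of an alternative, using a hypothetical rows_to_string() helper and a two-row toy data frame. It pastes the columns together directly, so each value is converted with as.character() and is not padded; the apply_version line reproduces the original approach for comparison.

```r
# Hypothetical helper: one line per row, values separated by tabs,
# lines joined with newlines (same output shape as the apply() code)
rows_to_string <- function(df) {
  # paste() works column-wise and converts values with as.character();
  # unname() keeps column names from being matched to paste() arguments
  lines <- do.call(paste, c(unname(as.list(df)), sep = "\t"))
  paste(lines, collapse = "\n")
}

# Two-row toy data frame shaped like the states data
sample_df <- data.frame(
  State = c("Delaware", "Arizona"),
  Pop_2020 = c(989948L, 7151502L),
  Region = c("South", "West")
)

cat(rows_to_string(sample_df))

# The original approach: as.matrix() pads 989948 to " 989948" here
apply_version <- paste(apply(sample_df, 1, paste, collapse = "\t"), collapse = "\n")
```

With the real data, rows_to_string(states_sample) could stand in for the apply() line. A stray leading space is unlikely to confuse GPT, so this is mainly about tidier prompts.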
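You can test the prompt first by cutting and pasting it into either ChatGPT or the OpenAI API playground. ChatGPT doesn't charge for use, but you can't adjust its results. The playground lets you set parameters such as temperature, meaning how "random" or creative the response should be, as well as which model to use. For SQL code, set the temperature to 0.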
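To make "temperature" a little more concrete: a language model picks each next token from a probability distribution, and temperature is commonly described as dividing the model's raw scores (logits) by T before applying a softmax. The R sketch below uses made-up scores for three candidate tokens; it is a textbook illustration, not OpenAI's internal implementation. Dividing by exactly 0 is undefined, so the function requires T > 0; a temperature of 0 is generally treated as always picking the most likely token.

```r
# Illustrative only: temperature-scaled softmax over made-up scores
softmax_with_temperature <- function(logits, temperature) {
  stopifnot(temperature > 0)
  scaled <- logits / temperature
  # Subtract the maximum before exp() for numerical stability
  weights <- exp(scaled - max(scaled))
  weights / sum(weights)
}

# Hypothetical scores for three candidate next tokens
logits <- c(SELECT = 2.0, WITH = 1.0, UPDATE = 0.1)

p_warm <- softmax_with_temperature(logits, temperature = 1)
p_cool <- softmax_with_temperature(logits, temperature = 0.1)
round(p_warm, 3)
round(p_cool, 3)
```

With these scores, SELECT gets roughly 66% of the probability at temperature 1 and nearly all of it at 0.1, which is why a low temperature gives more consistent SQL.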
Next, save a natural-language question in the variable my_query, use the create_prompt() function to build a prompt, and see what happens when you paste that prompt into the API playground:
````r
> my_query <- "What were the highest and lowest Population changes in 2020 by Division?"
> my_prompt <- create_prompt(states_schema_string, states_sample_string, my_query, "states")
> cat(my_prompt)
Act as if you're a data scientist. You have a SQLite table named states with the following schema:
```
0 State TEXT 0 NA 0
1 Pop_2000 INTEGER 0 NA 0
2 Pop_2010 INTEGER 0 NA 0
3 Pop_2020 INTEGER 0 NA 0
4 PctChange_2000 REAL 0 NA 0
5 PctChange_2010 REAL 0 NA 0
6 PctChange_2020 REAL 0 NA 0
7 State Code TEXT 0 NA 0
8 Region TEXT 0 NA 0
9 Division TEXT 0 NA 0
```
The first rows look like this:
```Delaware 783600 897934 989948 17.6 14.6 10.2 DE South South Atlantic
Montana 902195 989415 1084225 12.9 9.7 9.6 MT West Mountain
Arizona 5130632 6392017 7151502 40.0 24.6 11.9 AZ West Mountain```
Based on this data, write a SQL query to answer the following question: What were the highest and lowest Population changes in 2020 by Division?. Return the SQL query ONLY. Do not include any additional explanation.
````
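The schema lines in that prompt come straight from SQLite's PRAGMA table_info, whose fields are cid, name, type, notnull, dflt_value and pk; only name and type are likely to help the model. The sketch below, using a hypothetical compact_schema_string() helper and a mocked two-row schema, keeps just those two fields and double-quotes each name so a multi-word column such as State Code stays unambiguous. Whether a shorter schema changes GPT's answers has not been tested here.

```r
# Mocked two-row version of sqldf("PRAGMA table_info(states)") output;
# SQLite reports cid, name, type, notnull, dflt_value and pk
mock_schema <- data.frame(
  cid = 0:1,
  name = c("State", "State Code"),
  type = c("TEXT", "TEXT"),
  notnull = c(0L, 0L),
  dflt_value = c(NA, NA),
  pk = c(0L, 0L)
)

# Hypothetical helper: keep only name and type, one column per line,
# wrapping each name in double quotes (standard SQL identifier quoting)
compact_schema_string <- function(schema) {
  paste(sprintf('"%s" %s', schema$name, schema$type), collapse = "\n")
}

cat(compact_schema_string(mock_schema))
```

With the real data, compact_schema_string(states_schema) could be passed to create_prompt() in place of states_schema_string.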
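(Figure: The prompt entered into the OpenAI API playground, and the SQL code it generated.)

Running that generated SQL on the data with sqldf returns: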
```r
> sqldf("SELECT Division, MAX(PctChange_2020) AS Highest_PctChange_2020, MIN(PctChange_2020) AS Lowest_PctChange_2020 FROM states GROUP BY Division;")
Division Highest_PctChange_2020 Lowest_PctChange_2020
1 East North Central 4.7 -0.1
2 East South Central 8.9 -0.2
3 Middle Atlantic 5.7 2.4
4 Mountain 18.4 2.3
5 New England 7.4 0.9
6 Pacific 14.6 3.3
7 South Atlantic 14.6 -3.2
8 West North Central 15.8 2.8
9 West South Central 15.9 2.7
```
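Because GPT can write plausible but wrong SQL, it can be worth cross-checking a result with code you wrote yourself. The sketch below computes the same per-group maximum and minimum as the GROUP BY query using base R's tapply(); group_extremes() is a hypothetical helper and toy is made-up data. It passes na.rm = TRUE so that missing values are skipped, as SQL's MAX() and MIN() skip NULLs.

```r
# Hypothetical helper: per-group max and min of one column,
# mirroring SELECT g, MAX(v), MIN(v) FROM t GROUP BY g
group_extremes <- function(df, group_col, value_col) {
  highest <- tapply(df[[value_col]], df[[group_col]], max, na.rm = TRUE)
  lowest <- tapply(df[[value_col]], df[[group_col]], min, na.rm = TRUE)
  data.frame(
    group = names(highest),
    highest = as.vector(highest),  # as.vector() drops the array attributes
    lowest = as.vector(lowest)
  )
}

# Made-up data for illustration
toy <- data.frame(
  Division = c("A", "A", "B"),
  PctChange_2020 = c(1.5, -0.2, 3.0)
)

check <- group_extremes(toy, "Division", "PctChange_2020")
print(check)
```

With the real data, group_extremes(states, "Division", "PctChange_2020") should give the same numbers as the sqldf() output if the generated SQL is correct; only the column names differ.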
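It's much more convenient to send data to OpenAI and get results back programmatically rather than cutting and pasting into a web interface. Several R packages work with the OpenAI API. The code below uses the openai package to send a prompt to the API, store the API response, extract the part of the response containing the text with the requested SQL code, print that code, and run the SQL on the data.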
```r
> library(openai)
> my_results <- openai::create_chat_completion(model = "gpt-3.5-turbo", temperature = 0, messages = list(
+   list(role = "user", content = my_prompt)
+ ))
> # The response text is in the message.content column of the choices data frame
> the_answer <- my_results$choices$message.content
> cat(the_answer)
SELECT Division, MAX(PctChange_2020) AS Highest_Population_Change, MIN(PctChange_2020) AS Lowest_Population_Change
FROM states
GROUP BY Division;
> sqldf(the_answer)
Division Highest_Population_Change Lowest_Population_Change
1 East North Central 4.7 -0.1
2 East South Central 8.9 -0.2
3 Middle Atlantic 5.7 2.4
4 Mountain 18.4 2.3
5 New England 7.4 0.9
6 Pacific 14.6 3.3
7 South Atlantic 14.6 -3.2
8 West North Central 15.8 2.8
9 West South Central 15.9 2.7
```
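The prompt asks GPT to return SQL only, but chat models sometimes wrap code in Markdown fences anyway, and sqldf() would then fail on the backticks. A minimal base R guard, assuming the fences use the usual three-backtick form with an optional language tag, is sketched below; clean_sql() is a hypothetical helper, not part of the openai package.

```r
# Hypothetical helper: strip Markdown code fences (three backticks plus an
# optional language tag such as "sql") and surrounding whitespace
clean_sql <- function(response_text) {
  # "`{3}" matches three backticks; "[A-Za-z]*" matches a language tag
  without_fences <- gsub("`{3}[A-Za-z]*", "", response_text)
  trimws(without_fences)
}

# Build a fenced reply without writing literal fences in this example
fence <- strrep("`", 3)
raw_reply <- paste0(
  fence, "sql\n",
  "SELECT Division, MAX(PctChange_2020) FROM states GROUP BY Division;\n",
  fence
)

cat(clean_sql(raw_reply))
```

Used with the code above, the call would become sqldf(clean_sql(the_answer)); a reply without fences passes through unchanged apart from trimmed whitespace.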
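To use the OpenAI API, you need an OpenAI API key. For this package, the key should be stored in a system environment variable named OPENAI_API_KEY. Note that the API isn't free to use, but I ran this project more than a dozen times in one day before handing it to my editor, and my total account usage came to 1 cent.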
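The openai package's functions read the key with Sys.getenv("OPENAI_API_KEY") by default. One common way to set it is a line such as OPENAI_API_KEY=your-key-here in your ~/.Renviron file (usethis::edit_r_environ() opens that file), followed by restarting R. The small check below, built around a hypothetical has_openai_key() helper, warns before any API call instead of letting the request fail with an authentication error.

```r
# Hypothetical helper: TRUE if the environment variable holding the
# OpenAI key is set to a non-empty value
has_openai_key <- function(var = "OPENAI_API_KEY") {
  # Sys.getenv() returns "" for unset variables
  nzchar(Sys.getenv(var))
}

# Warn before making any API call if the key is missing
if (!has_openai_key()) {
  message("OPENAI_API_KEY is not set. Add it to ~/.Renviron and restart R.")
}
```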
Finally, here is the code for a simple Shiny app that wraps these steps in a web interface: the user types a question, and the app builds the prompt, sends it to GPT, and displays the generated SQL along with the results of running it. It reads states.csv from the app's directory.

````r
library(shiny)
library(openai)
library(dplyr)
library(sqldf)

# Load hard-coded dataset
states <- read.csv("states.csv") |>
  dplyr::filter(!is.na(Region) & Region != "")
states_schema <- sqldf::sqldf("PRAGMA table_info(states)")
states_schema_string <- paste(apply(states_schema, 1, paste, collapse = "\t"), collapse = "\n")
states_sample <- dplyr::sample_n(states, 3)
states_sample_string <- paste(apply(states_sample, 1, paste, collapse = "\t"), collapse = "\n")

# Turn the user's question into a prompt for GPT
get_prompt <- function(query, schema = states_schema_string, rows_sample = states_sample_string, table_name = "states") {
  my_prompt <- glue::glue("Act as if you're a data scientist. You have a SQLite table named {table_name} with the following schema:
```
{schema}
```
The first rows look like this:
```{rows_sample}```
Based on this data, write a SQL query to answer the following question: {query} Return the SQL query ONLY. Do not include any additional explanation.")
  print(my_prompt) # log the prompt to the R console
  return(my_prompt)
}

ui <- fluidPage(
  titlePanel("Query state database"),
  sidebarLayout(
    sidebarPanel(
      textInput("query", "Enter your query", placeholder = "e.g., What is the total 2020 population by Region?"),
      actionButton("submit_btn", "Submit")
    ),
    mainPanel(
      uiOutput("the_sql"),
      br(),
      br(),
      verbatimTextOutput("results")
    )
  )
)

server <- function(input, output) {
  # Create the prompt from the user query to send to GPT
  the_prompt <- eventReactive(input$submit_btn, {
    req(input$query, states_schema_string, states_sample_string)
    get_prompt(query = input$query)
  })

  # Send prompt to GPT, get SQL, run SQL, print results
  observeEvent(input$submit_btn, {
    req(the_prompt()) # text to send to GPT
    # withProgress() shows a Shiny progress bar while the code inside runs;
    # variables assigned inside it are still available afterward
    withProgress(message = "Getting results from GPT", value = 0, {
      my_results <- openai::create_chat_completion(model = "gpt-3.5-turbo", temperature = 0, messages = list(
        list(role = "user", content = the_prompt())
      ))
      the_gpt_sql <- my_results$choices$message.content
      # Convert the SQL's line breaks to HTML for display
      sql_html <- gsub("\n", "<br />", the_gpt_sql)
      sql_html <- paste0("<p>", sql_html, "</p>")
      # Run SQL on data to get results
      gpt_answer <- sqldf(the_gpt_sql)
      setProgress(value = 1, message = "GPT results received") # tell the user the results are in
    })
    # Print SQL and results; renderPrint() prints vectors and data frames alike
    output$the_sql <- renderUI(HTML(sql_html))
    output$results <- renderPrint(gpt_answer)
  })
}

shinyApp(ui = ui, server = server)
````
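The server code above has two weak spots: the SQL is inserted into HTML without escaping, so a query containing < or > may display incorrectly, and an error from sqldf() (for example, invalid SQL from GPT) is not caught inside the observer, which can end the user's Shiny session. The base R sketch below shows one way to handle both; sql_to_html() and safe_query() are hypothetical helpers, and htmltools::htmlEscape() is an existing alternative for the escaping step.

```r
# Hypothetical helper: escape HTML special characters, then turn newlines
# into <br /> tags (escaping first so the tags themselves survive)
sql_to_html <- function(sql) {
  escaped <- gsub("&", "&amp;", sql, fixed = TRUE)
  escaped <- gsub("<", "&lt;", escaped, fixed = TRUE)
  escaped <- gsub(">", "&gt;", escaped, fixed = TRUE)
  paste0("<p>", gsub("\n", "<br />", escaped, fixed = TRUE), "</p>")
}

# Hypothetical helper: run a query function and return an error message
# string instead of raising an error when the SQL is invalid
safe_query <- function(sql, run) {
  tryCatch(
    run(sql),
    error = function(e) paste("Query failed:", conditionMessage(e))
  )
}

cat(sql_to_html("SELECT State\nFROM states\nWHERE Pop_2020 < 1000000"))

# A stand-in query function that always fails, to show the error path
safe_query("SELEC oops", run = function(q) stop("syntax error"))
```

In the server, these could replace the two sql_html lines and the sqldf() call, roughly as sql_html <- sql_to_html(the_gpt_sql) and gpt_answer <- safe_query(the_gpt_sql, run = sqldf::sqldf); that wiring is untested here.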
