当前位置:   article > 正文

0基础搞AI-NL2SQL数据集处理脚本(用于LLM-fine-tune)

nl2sql数据集

        消失了好久好久,这次换了一家公司,然后又在忙于秋招,因此很久没有更新,最近事情也告一段落,因此终于有空回来水博客,今天给大家带来最近的工作,NL2SQL数据集,我们的工作是利用代码生成大模型(类似CodeFuse系列,CodeLlama系列)进行fine-tune,通过用户query和query涉及的数据库表的Schema作为输入,使用fine-tune后的LLM进行推理来得到最后的生成SQL,当然为了工作的方便,所以我们试图将所有的开源数据集进行整合,因此在此处的NL2SQL数据集中,提供了经过模型翻译的Wiki_SQL数据集,Cspider数据集,Du_SQL数据集,如果有大佬有追一科技的数据集请告诉我,需要一些帮助,接下来首先给出NL2SQL数据集的处理脚本:

1、数据集生成(耗时13h,把8w条WikiSQL翻译了)

Data_deal_Script.py

  1. """
  2. codeer:Jinzhangli
  3. function:数据集处理和构建
  4. relation:2035877994@qq.com
  5. time:2023/11/21 15:23
  6. """
  7. import json,re
  8. class Cspider_Data_make:
  9. def Cspider_Schema_load_deal(self):
  10. Schema={}
  11. All_DB=self.Cspider_Data_load("Data/Cspider/tables.json")
  12. for i in range(len(All_DB)):
  13. DB={}
  14. column_names=All_DB[i]["column_names"]
  15. table_names=All_DB[i]['table_names']
  16. for j in range(len(table_names)):
  17. DB["_".join(re.split(" ",table_names[j]))]=[column_names[k][1] for k in range(len(column_names)) if column_names[k][0]==j]
  18. Schema[All_DB[i]["db_id"]]=DB
  19. return Schema
  20. def Cspider_Data_load(self,file_path:str):
  21. dict_data=json.loads(open(file_path,"r",encoding="utf-8").read())
  22. return dict_data
  23. def Cspider_Schema_pipe(self,db_name:str,Table_list:list):
  24. All_Schema=self.Cspider_Schema_load_deal()
  25. result=[]
  26. Table_list=[i for i in Table_list if i not in ["("]]
  27. for i in range(len(Table_list)):
  28. result.append(All_Schema[db_name][Table_list[i]])
  29. return result
  30. def Table_get(self,SQL_token:list)->list:
  31. Table_list=[SQL_token[i] for i in range(len(SQL_token)) if SQL_token[i-1] in ["from","join"]]
  32. return Table_list
  33. def Dict_deal(self,one_dict:dict)->dict:
  34. query=one_dict["question"]
  35. SQL=one_dict["query"]
  36. db_name=one_dict["db_id"]
  37. return {"query":query,"SQL":SQL,"table_name":"","column_name":"","db_name":db_name}
  38. def Cspider_Datas_Get(self,Cspider_data):
  39. Result=[]
  40. for i in range(len(Cspider_data)):
  41. if i not in [3097,3153]:
  42. print("=========正在处理第"+str(i)+",总共有"+str(len(Cspider_data))+"个=========")
  43. one_dict = self.Dict_deal(Cspider_data[i])
  44. Table_list = list(set(self.Table_get(Cspider_data[i]["query_toks"])))
  45. result = self.Cspider_Schema_pipe(one_dict["db_name"], Table_list)
  46. one_dict["table_name"] = Table_list
  47. one_dict["column_name"] = result
  48. Result.append(one_dict)
  49. return Result
  50. def Csipder_main(self):
  51. Cspider_train_data = self.Cspider_Data_load("Data/Cspider/train.json")
  52. Cspider_dev_data=self.Cspider_Data_load("Data/Cspider/dev.json")
  53. Cspider_Result=self.Cspider_Datas_Get(Cspider_train_data)+self.Cspider_Datas_Get(Cspider_dev_data)
  54. return Cspider_Result
  55. class wikiSQL_Data_make:
  56. def wiki_load(self,file_path):
  57. file_str=open(file_path,"r",encoding="utf-8").readlines()
  58. Dict_Data=[eval(file_str[i]) for i in range(len(file_str))]
  59. return Dict_Data
  60. def wiki_deal(self,data_path,table_path):
  61. Dict_data=self.wiki_load(data_path)
  62. Table_data=self.wiki_load(table_path)
  63. Wiki_Result,Index=[],0
  64. Table_dict={Table_data[i]["id"]:[Table_data[i]["header"],Table_data[i]['caption']]
  65. for i in range(len(Table_data)) if "caption" in Table_data[i].keys()}
  66. for i in range(len(Dict_data)):
  67. table_id=Dict_data[i]["table_id"]
  68. all_table=Table_dict.keys()
  69. if table_id in all_table:
  70. #print("正在处理第" + str(Index) + ",总共有" + str(len(Dict_data)) + "个")
  71. Index+=1
  72. query=Dict_data[i]["question"]
  73. table_name="_".join(re.split(" ",Table_dict[Dict_data[i]["table_id"]][1]))
  74. SQL=Dict_data[i]["sql"]
  75. column_name=Table_dict[Dict_data[i]["table_id"]][0]
  76. for j in range(len(column_name)):
  77. column=[]
  78. if "/" in column_name[j] and "(" not in column_name[j]:
  79. column_name[j]=re.split("/",column_name[j])[0]
  80. elif "(" in column_name[j]:
  81. for k in column_name[j]:
  82. if k!="(":
  83. column.append(k)
  84. else:
  85. column_name[j]=re.split(" ","".join(column))
  86. if column_name[j][-1]=="":
  87. column_name[j]="_".join(column_name[j][0:-1])
  88. else:
  89. column_name[j] = "_".join(column_name[j])
  90. break
  91. elif " " in column_name[j]:
  92. column_name[j]="_".join(re.split(" ",column_name[j]))
  93. elif type(column_name[j])==list:
  94. column_name[j]=column_name[j][0]
  95. SQL=self.SQL_make(SQL,column_name,table_name)
  96. one_dict={"query": query, "SQL": SQL, "table_name": table_name, "column_name":column_name, "db_name": ""}
  97. Wiki_Result.append(one_dict)
  98. return Wiki_Result
  99. def SQL_make(self,SQL_token,column_name,table_name):
  100. agg_Action, conds_Acction= ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'],['=', '>', '<', 'OP']
  101. SQL="SELECT "+agg_Action[SQL_token["agg"]]+" ( "+column_name[SQL_token["sel"]]+" ) "+"FROM "+table_name
  102. if len(SQL_token["conds"])==1:
  103. if type(SQL_token["conds"][0][2])!=str:
  104. SQL_token["conds"][0][2]=str(SQL_token["conds"][0][2])
  105. SQL_token["conds"][0][1]=conds_Acction[SQL_token["conds"][0][1]]
  106. SQL_token["conds"][0][0]=column_name[SQL_token["conds"][0][0]]
  107. SQL+=" WHERE "+" ".join(SQL_token["conds"][0])
  108. else:
  109. conds_list=SQL_token["conds"]
  110. for i in range(len(conds_list)):
  111. if type(conds_list[i][2])!=str:
  112. conds_list[i][2]=str(conds_list[i][2])
  113. conds_list[i][0]=column_name[conds_list[i][0]]
  114. conds_list[i][1]=conds_Acction[conds_list[i][1]]
  115. for i in range(len(conds_list)):
  116. if i==len(conds_list)-1:
  117. SQL+="and "+" ".join(conds_list[i])
  118. elif i==0:
  119. SQL+="WHERE "+" ".join(conds_list[i])+" "
  120. else:
  121. SQL+="and "+" ".join(conds_list[i])+" "
  122. return SQL
  123. def wiki_main(self):
  124. Wiki_Result=self.wiki_deal("Data/WikiSQL/train.json","Data/WikiSQL/train_tables.json")
  125. return Wiki_Result
  126. class DuSQL_Data_make:
  127. def DuSQL_load(self,file_path):
  128. DuSQL_data=json.loads(open(file_path,"r",encoding="utf-8").read())
  129. return DuSQL_data
  130. def Schema_deal(self,DuSQL_schema:list[dict]):
  131. Schema_dict={}
  132. for i in range(len(DuSQL_schema)):
  133. table_names=DuSQL_schema[i]["table_names"]
  134. column_names=DuSQL_schema[i]["column_names"]
  135. Schema_dict[DuSQL_schema[i]["db_id"]]={table_names[j]:[column_names[k][1] for k in range(len(column_names)) if column_names[k][0]==j] for j in range(len(table_names))}
  136. return Schema_dict
  137. def TableGetFromSQL(self,SQL):
  138. SQL_List=re.split(" ",SQL)
  139. Table=list(set([SQL_List[i] for i in range(len(SQL_List)) if i!=0 and SQL_List[i-1] in ["from","join"]]))
  140. return Table
  141. def Query_SQL_Schema(self,DUSQL_data:list[dict],DuSQL_Schema):
  142. Result=[]
  143. for i in range(len(DUSQL_data)):
  144. print("=========正在处理第" + str(i) + ",总共有" + str(len(DUSQL_data)) + "个=========")
  145. SQL=DUSQL_data[i]["sql_query"]
  146. query=DUSQL_data[i]["question"]
  147. db_name=DUSQL_data[i]["db_id"]
  148. table=self.TableGetFromSQL(SQL)[0]
  149. column=DuSQL_Schema[db_name][table]
  150. Result.append({"query":query,"SQL":SQL,"table_name":table,"column_name":column,"db_name":db_name})
  151. return Result
  152. def DuSQL_main(self):
  153. DuSQL_data=self.DuSQL_load("Data/DuSQL/sample-data.json")
  154. DUSQL_Schema=self.DuSQL_load("Data/DuSQL/db-schema.json")
  155. DUSQL_Schema=self.Schema_deal(DUSQL_Schema)
  156. DuSQL_Result=self.Query_SQL_Schema(DuSQL_data,DUSQL_Schema)
  157. return DuSQL_Result

用于翻译的数据接口,这里用了通义千问14B

OutAPI.py

  1. """
  2. codeer:Jinzhangli
  3. function:接入外部API服务
  4. relation:2035877994@qq.com
  5. time:2023/11/30 15:49
  6. """
  7. import requests,json
  8. def Qwen14BChat(text,history):
  9. url="http://172.16.158.247:9899/Qwen14B"
  10. data=json.dumps({"prompt":text,"history":history})
  11. response=requests.post(url=url,data=data)
  12. response=eval(response.text)
  13. return response

接下来是主控脚本,Tune_main.py

  1. """
  2. codeer:Jinzhangli
  3. function:主控文件
  4. relation:2035877994@qq.com
  5. time:2023/11/30 18:05
  6. """
  7. import json
  8. from Data_Deal_Script import *
  9. from OutAPI import *
  10. def LearningDataJson_build():
  11. wikiSQL_Data = wikiSQL_Data_make()
  12. print("开始处理WIKI_SQL")
  13. WIKI_SQL = wikiSQL_Data.wiki_main()
  14. #英文数据集翻译
  15. for i in range(len(WIKI_SQL)):
  16. print("====翻译第"+str(i)+"个句子====")
  17. WIKI_SQL[i]["query"] = Qwen14BChat("请帮我将以下文本翻译为中文,只输出结果,不要任何解释\n"+WIKI_SQL[i]["query"],[])["response"]
  18. print(WIKI_SQL[i]["query"])
  19. Cspider_Data = Cspider_Data_make()
  20. Dusql_Data = DuSQL_Data_make()
  21. print("开始处理DU_SQL")
  22. DU_SQL = Dusql_Data.DuSQL_main()
  23. print("开始处理Cspider")
  24. Cspider = Cspider_Data.Csipder_main()
  25. Result=DU_SQL+Cspider+WIKI_SQL
  26. with open("result.json", "w", encoding="utf-8") as json_file:
  27. json.dump(Result,json_file,ensure_ascii=False)
2、基于Swift框架的加载LoRA微调

接下来是LLM微调脚本(基于Swift框架)

首先安装阿里巴巴Swift框架

  1. git clone https://github.com/modelscope/swift.git
  2. cd swift
  3. pip install -e .

然后进入Clone下来的Swift文件夹

cd ../swift/examples/pytorch/llm

使用llm下自带的脚本,也可以自己写,我比较懒直接os.system()来修改

  1. import os
  2. command="""
  3. CUDA_VISIBLE_DEVICES=0 \
  4. python llm_sft.py \
  5. --model_type qwen-14b \
  6. --model_cache_dir /home/gpu-user1/JinzhangLi/Qwen-14B \
  7. --sft_type lora \
  8. --template_type default-generation \
  9. --dtype bf16 \
  10. --output_dir output \
  11. --dataset dureader-robust-zh \
  12. --train_dataset_sample -1 \
  13. --num_train_epochs 1 \
  14. --max_length 2048 \
  15. --quantization_bit 4 \
  16. --bnb_4bit_comp_dtype bf16 \
  17. --lora_rank 8 \
  18. --lora_alpha 32 \
  19. --lora_dropout_p 0. \
  20. --lora_target_modules ALL \
  21. --gradient_checkpointing true \
  22. --batch_size 1 \
  23. --weight_decay 0. \
  24. --learning_rate 1e-4 \
  25. --gradient_accumulation_steps 16 \
  26. --max_grad_norm 0.5 \
  27. --warmup_ratio 0.03 \
  28. --eval_steps 100 \
  29. --save_steps 100 \
  30. --save_total_limit 2 \
  31. --logging_steps 10 \
  32. --use_flash_attn false \
  33. --push_to_hub false \
  34. --hub_model_id qwen-14b-qlora \
  35. --hub_private_repo true \
  36. --hub_token 'your-sdk-token' """
  37. os.system(command)
3、数据集样式和链接(根据自己使用的框架微调,不出意外,后面数据集还会变大)

最后给出搞定后的NL2SQL数据集(当然数据集还得调整,只是将数据格式整理如下)

  1. {
  2. "query": "创刊时间不早于1989年10月10日的期刊,按出版刊数降序排列给出期刊的名称以及语言",
  3. "SQL": "select 名称 , 语言 from 期刊 where 创刊时间 >= '1989-10-10' order by 出版刊数 desc",
  4. "table_name": "期刊",
  5. "column_name": ["词条id", "名称", "语言", "类别", "主办单位", "创刊时间", "国家", "出版刊数"],
  6. "db_name": "期刊"
  7. }

如想获取数据,请访问我们在modelscope的开源地址

Text2SQL-英文-150K · 数据集 (modelscope.cn)

Text2SQL-中文-180K · 数据集 (modelscope.cn)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/433496
推荐阅读
相关标签
  

闽ICP备14008679号