当前位置:   article > 正文

天池NL2SQL Top15方案_nl2sql技术方案

nl2sql技术方案

目录

代码地址

Part0: 参赛成绩

Part1: 代码环境

环境配置步骤如下:

Part2: 预处理

一. 数值类型转化

二. 训练集数据清洗与分类

Part3:模型介绍

目录

代码地址

Part0: 参赛成绩

Part1: 代码环境

环境配置步骤如下:

Part2: 预处理

一. 数值类型转化

二. 训练集数据清洗与分类

Part3:模型介绍

Part4: 后处理

Part5: 模型效果评估

Part6: TODO

数字通用前后缀挖掘

同义词解决方案

BUGFIX: (FIXED)

部分冗余逻辑重写

附录:代码树

回复 "nl2sql" 获取数据集, 有问题请留言


 

Part4: 后处理

Part5: 模型效果评估

Part6: TODO

数字通用前后缀挖掘

同义词解决方案

BUGFIX: (FIXED)

部分冗余逻辑重写

附录:代码树


代码地址

https://github.com/yscoder-github/nl2sql-tianchi

Part0: 参赛成绩

  • 平台昵称: yscoder
  • 参赛形式: 个人
  • 复赛排名: 15

Part1: 代码环境


环境配置步骤如下:


1. 深度学习相关环境

配置详情

  • 显卡: 1080ti
  • OS: Ubuntu
  • Driver Version: 418.56
  • CUDA Version: 10.1
  • cudnn version

2. Python相关环境

  1. conda create --name nl2sql-yscoder1 python=3.6
  2. source activate
  3. conda activate nl2sql-yscoder1
  4. pip install --upgrade pip
  5. pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements

root@481c7f8087a2

执行训练任务(训练入口):

  1. cd code
  2. sh start_train.sh

执行推理任务(推理入口):

  1. cd code
  2. sh start_test.sh

3. 模型与训练文件信息

  • Bert中文预训练模型(哈工大版) ,该模型存储位置: ./submit/chinese-bert_chinese_wwm_L-12_H-768_A-12 下载地址

  • Bert-finetune模型,该模型为经过finetune之后的,适配于当前nl2sql任务的训练模型, 该模型存储路径 ./data/weights/nl2sql_finetune.weights 训练好的模型链接: https://pan.baidu.com/s/11tQVSZl9e6VBLPp85c9JRQ 提取码: 4ceg

  • 数据集位置

Part2: 预处理


一. 数值类型转化

 官方提供的数据集中考虑到用户输入的多样性,所以question包含了各种展现形式的数值类型.模型处理时首先对question进行了格式统一,具体如下:

  1. 百分数转换,例如百分之10转化为10%,百分之十转换为10%, 百分之一点五转化为1.5%

  2. 大写数字转阿拉伯数字,例如十二转换为12二十万转化为200000; 一点二转化为1.2

  3. 年份转换,如将12年转换为2012年

数值类型转换过程当中用到了大量正则匹配,主要匹配出数字可能出现的位置,以利于后续对训练集做标记.
具体代码位置: question_prepro.py
功能汇总函数: trans_question_acc


二. 训练集数据清洗与分类

具体代码文件: new_mark_acc_ensure.py check_input_feature.py


Part3:模型介绍

Todo (这两天把图做好 )

部分训练日志

注意,如果执行此代码报错,需要修改一下Keras的backend/tensorflow_backend.py,将sparse_categorical_crossentropy函数中原本是

   logits = tf.reshape(output, [-1, int(output_shape[-1])]) 

的那一行改为

logits = tf.reshape(output, [-1, tf.shape(output)[-1]])

模型解决的特殊问题:

  • AB型问题 例如:初一欢乐寒假这本书多少钱。 这个句子当中的初一与欢乐寒假是相邻的两个不同列的候选值
  • 一个候选值对应多列, 例如 eps2011与eps2012均大于10的股票有哪些? 这里的10对应了表中的两个列

Part4: 后处理

后处理主要包括如下几块:

  1. 数字中的部分数位缺失,例如200000 模型之预测到200,根据数字的性质,可以对缺失的数位做补齐处理
  2. question中数值单位和表中单位进行统一. 例如question当中票房的单位是"亿",而在相关表中该列的单位为"百万". 本数据集中数值相关的单位存储在表的列名称或者表的title中.

代码文件: post_treat.py


Part5: 模型效果评估

模型效果评估部分,主要采用官方baseline中的方法,并进行了一定封装.主要用于对预测各个部件的准确性进行评估,并存储预测的错误结果以用于后续分析.
功能所在文件为calc_acc.py
主要函数为check_part_acc


Part6: TODO

数字通用前后缀挖掘

由于时间关系,原有方案在将文本中的"中文数字"转换为阿拉伯数字时, 用了一种1字前缀+1字后缀的方式来匹配"中文数字" ,例如:

'概天','到月', '于元', '在年', '于平', '足平', '过股', '过套', '招位', '前', '前中', '前名', '前个', '于的'

上面的这几种词对也是通过对数据集进行处理而得出的。表面上看,虽然上面的几种词对可以将一些数字匹配出来,但是这种1字前缀+1字后缀的方式会将一些专有名词中的中文数字进行误换. 所以后期需要做数字通用前后缀挖掘,即挖掘出更好的n字前缀+n字后缀

有兴趣的可以看看是否可以将所有的数值相关的训练材料从阿拉伯数字转换为中文数字,最后评估下模型效果

同义词解决方案

例如如何将问题中的鹅场和腾讯关联起来

BUGFIX: (FIXED)

  1. File "nl2sql_main.py", line 816, in <module>
  2. callbacks=[evaluator]
  3. File "/home/yinshuai/anaconda3/envs/nl2sql-yscoder/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
  4. return func(*args, **kwargs)
  5. File "/home/yinshuai/anaconda3/envs/nl2sql-yscoder/lib/python3.6/site-packages/keras/engine/training.py", line 1418, in fit_generator
  6. initial_epoch=initial_epoch)
  7. File "/home/yinshuai/anaconda3/envs/nl2sql-yscoder/lib/python3.6/site-packages/keras/engine/training_generator.py", line 251, in fit_generator
  8. callbacks.on_epoch_end(epoch, epoch_logs)
  9. File "/home/yinshuai/anaconda3/envs/nl2sql-yscoder/lib/python3.6/site-packages/keras/callbacks.py", line 79, in on_epoch_end
  10. callback.on_epoch_end(epoch, logs)
  11. File "nl2sql_main.py", line 798, in on_epoch_end
  12. acc = self.evaluate()
  13. File "nl2sql_main.py", line 806, in evaluate
  14. return evaluate(valid_data, valid_tables)
  15. File "nl2sql_main.py", line 721, in evaluate
  16. R = nl2sql(question, table) #
  17. File "nl2sql_main.py", line 555, in nl2sql
  18. if entity_start_pos_list[0] != v_start:
  19. IndexError: list index out of range
  20. 2209it [00:25, 86.28it/s]

部分冗余逻辑重写

附录:代码树

  1. ├── code
  2. │ ├── calc_acc.py
  3. │ ├── check_input_feature.py
  4. │ ├── config.py
  5. │ ├── dbengine.py
  6. │ ├── hand_set.py
  7. │ ├── new_mark_acc_ensure.py
  8. │ ├── nl2sql_main.py
  9. │ ├── post_treat.py
  10. │ ├── question_prepro.py
  11. │ ├── start_test.sh
  12. │ ├── start_train.sh
  13. │ └── utils.py
  14. ├── data
  15. │ ├── chinese-bert_chinese_wwm_L-12_H-768_A-12
  16. │ │ ├── bert_config.json
  17. │ │ ├── bert_model.ckpt.data-00000-of-00001
  18. │ │ ├── bert_model.ckpt.index
  19. │ │ ├── bert_model.ckpt.meta
  20. │ │ └── vocab.txt
  21. │ ├── logs
  22. │ │ ├── evaluate_pred.json
  23. │ │ ├── where_cnt_error.log
  24. │ │ ├── where_col_error.log
  25. │ │ ├── where_oper_error.log
  26. │ │ └── where_val_error.log
  27. │ ├── prepare_data
  28. │ │ ├── col_in_text
  29. │ │ ├── new_q_correct
  30. │ │ ├── new_q_exactly_match
  31. │ │ ├── new_q_exactly_more_strict_match
  32. │ │ ├── new_q_need_col_sim
  33. │ │ ├── new_q_no_num_similar
  34. │ │ ├── new_q_one_vs_more_col
  35. │ │ └── new_q_text_contain_similar
  36. │ ├── train
  37. │ │ ├── train.db
  38. │ │ ├── train.json
  39. │ │ └── train.tables.json
  40. │ ├── val.db
  41. │ ├── valid
  42. │ │ ├── val.db
  43. │ │ ├── val.json
  44. │ │ └── val.tables.json
  45. │ └── weights
  46. │ └── nl2sql_finetune.weights
  47. ├── Dockerfile
  48. ├── requirements
  49. ├── img
  50. │ └── nl2sql_model_old.png
  51. ├── README.md
  52. └── run.sh

回复 "nl2sql" 获取数据集, 有问题请留言

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/433502
推荐阅读
相关标签
  

闽ICP备14008679号