当前位置:   article > 正文

NL2SQL_nl2sql学习(5):model1代码学习(详细注释)_一枚小白的日常的博客-csdn博客

nl2sql学习(5):model1代码学习(详细注释)_一枚小白的日常的博客-csdn博客

用途

1、NL2C++? NL2Python?
2、后端查询

什么是pointer network

  • 主要用于解决组合优化问题(TSP,Convex Hull),实际上是Seq2Seq中encoder和decoder的扩展,主要解决的问题是输出字典长度不固定问题(或说离散序列)

    • 传统Seq2Seq解决一些翻译问题,输出向量的长度一般就是字典的长度(一般是个超参数,提前订好了)

    • 就比如求二维凸包的问题(Convex Hull),属于n个seq,输出m个seq。但是在n不一样的时候就不work了

    • 而输出字典可变的问题就是下面的公式,就是把输出中attention的加权求和改成了条件概率的形式:

      • 普通的attention

        • img
      • 其实就是简单的修改了一下,修改后

      • u j i = v T t a n h ( W 1 e j + W 2 d i ) j ∈ ( 1 , . . . , n ) p ( C i ∣ C 1 , . . . , C i − 1 , P ) = s o f t m a x ( u i ) u_j^i=v^Ttanh(W_1e_j+W_2d_i)\qquad j\in(1,...,n)\\ p(C_i|C_1,...,C_{i-1},P)=softmax(u^i) uji=vTtanh(W1ej+W2di)j(1,...,n)p(CiC1,...,Ci1,P)=softmax(ui)

    • img

Seq2SQL: Generating structured queries from natural language using reinforcement learning

  • 这个方法结合Seq2Seq和RL,内容与Seq2Seq应用于Chatbot类似
    • 将输入语句encode后再decode成结构化的SQL语言输出
    • RL在Seq2SQL最后一个模块的应用
  • 这篇文章还推出了WikiSQL的数据集,后面相关算法大都会对这个数据集来评价

img

  • 大框架非常的简单

具体一点

首先提出了一个Augmented pointer network

Augmented pointer network

首先是输入的构造。

输入序列是由列 x j c x_j^c xjc(column,包含了那列的单词),SQL词典 x s x^s xs(SQL一些独特的单词),还有查询的问题 x q x^q xq组成的
x = [ < c o l > ; x 1 c ; x 2 c ; . . . ; x n c ; < s q l > ; x s ; < q u e s t i o n > ; x q ] x=[<col>;x_1^c;x_2^c;...;x_n^c;<sql>;x^s;<question>;x^q] x=[<col>;x1c;x2c;...;xnc;<sql>;xs;<question>;xq]
然后对于输入序列token-by-token的处理

其他的就是上面pointer network的构造

模型重要组件

  • img

  • 有三个部分:聚合分类器(Aggregation classifier),select列的pointer,where条件解码器

聚合分类器

将用户输入的语句分类成是select count/max/min(这些是SQL中的聚合类型) 等统计相关的约束条件

这一部分就用上面那个Augmented pointer network来完成分类任务

  • 把用户的输入序列得到attention score对hidden state生成一个表征向量,然后再通过一个BP网络来多分类

select column

  • 这一部分就是判断用户要用哪个column

1、先把每个列 x j c x_j^c xjc放进一个LSTM里面得到一个encode

2、然后和上面的"聚类分类器"的方法类似,也得到一个类似表征向量的东西,再通过一个BP网络来多分类

where clause

其实也可以像Augmented pointer network用一个pointer decoder来类似的训练,但是由于这一块的等价的表述方法很多,最终生成的SQL可能和目标不同

  • 所以就考虑用RL
    • img

ACC

执行准确率从35.9%提高到59.4%,逻辑形式准确率从23.4%提高到48.3%

SQLNet: Generating Structured Queries from Natural Language without Reinforcement Learning

  • 这两篇竟然是同一年的(2017)论文,并在WikiSQL数据集上提升了9-13个点
    • 提出了一种方法处理"序列到集合"的生成问题
    • 提出了一种全新的注意力结构,称为columns attention

采用草图的方式

不用设计对于未知的输出语法的语义解析模型,这里采用草图的方式(和SQL语法高度一致)

  • 因此我们只用在草图上填空,而不是预测输出语法和内容
  • 草图要被设计的足够用而不影响泛化
  • 草图已经捕获了要预测的东西的依赖关系,所以最终预测的值之和草图上要填的值有关,这样避免了上面那篇论文在Seq2Seq中的“秩序问题”( “order-matters”)

草图的规范形式与依赖关系参考图

img

  • 左边的定义了一些草图的查询规范

    • $AGG表示聚合类型(比如:SUM、MAX等)
    • $COLUMN表示表的列名
    • $VALUE必须来自输入自然语言的子串
    • $OP表示比较符号(比如:=、<、>)
    • 结尾的*就是可重复
  • 右边的图就是一个依赖关系参考图了,可以综合更复杂的查询问题,把这个扩大

column attention(Seq to Set)

以where语句为例

  • 不是生成出一串列名(就是上面的 x j c x^c_j xjc),我们只是简单的预测哪些列出现在这个集合中

    • P w h e r e c o l ( c o l ∣ Q ) c o l 是 列 名 , Q 是 问 题 传 统 的 解 决 方 法 是 P w h e r e c o l ( c o l ∣ Q ) = σ ( u c T E c o l + u q T E Q ) ( σ 是 s i g m o i d , E c o l 是 列 名 的 e m b e d d i n g , E Q 是 问 题 , u c 和 u q 是 两 个 可 训 练 的 向 量 ; 他 们 的 维 度 都 是 h i d d e n    s t a t e 的 维 度 ) P_{wherecol}(col|Q)\qquad col是列名,Q是问题\\ 传统的解决方法是P_{wherecol}(col|Q)=\sigma(u_c^TE_{col}+u_q^TE_Q)\\ (\sigma是sigmoid,E_{col}是列名的embedding,E_Q是问题,\\u_c和u_q是两个可训练的向量;他们的维度都是hidden~~state的维度) Pwherecol(colQ)colQPwherecol(colQ)=σ(ucTEcol+uqTEQ)(σsigmoidEcolembeddingEQucuqhidden  state)

    • E c o l 、 E Q E_{col}、E_Q EcolEQ,他们是由两个BiLSTM(一个在col上,一个在Q上)的最后一个hidden state,并没有参数共享

在这里插入图片描述

  • 然而传统的方法有点问题:
    • 对于 E Q E_Q EQ,不能预测对于特定列的特定信息
    • 比如上面的:在语句select的时候number和列No.较相关,在语句where的时候playerplayer较相关
img

所以改变一下问题:我们计算 E Q ∣ c o l E_{Q|col} EQcol
w = s o f t m a x ( v ) v i = ( E c o l ) T W H Q i ( i ∈ { 1 , . . . , L } ) ( H Q 是 d × L 的 矩 阵 , 第 i 列 表 示 询 问 的 第 i 个 t o k e n 的 在 L S T M 输 出 的 h i d d e n   s t a t e 就 是 H Q i L 是 询 问 的 长 度 , w 是 注 意 力 权 重 , W 是 可 训 练 的 d × d 的 矩 阵 ) w=softmax(v)\\ v_i=(E_{col})^TWH_Q^i\qquad(i\in\left\{1,...,L\right\})\\ (H_Q是d×L的矩阵,第i列表示询问的第i个token的\\在LSTM输出的hidden~ state就是H_Q^i\\ L是询问的长度,w是注意力权重,W是可训练的d×d的矩阵) w=softmax(v)vi=(Ecol)TWHQi(i{1,...,L})(HQd×LiitokenLSTMhidden stateHQiLwWd×d)
所以有
E Q ∣ c o l = H Q w E_{Q|col}=H_Qw EQcol=HQw
然后我们把 E Q ∣ c o l E_{Q|col} EQcol替换掉上面的 E Q E_Q EQ就有
P w h e r e c o l ( c o l ∣ Q ) = σ ( u c T E c o l + u q T E Q ∣ c o l ) P_{wherecol}(col|Q)=\sigma(u_c^TE_{col}+u_q^TE_{Q|col}) Pwherecol(colQ)=σ(ucTEcol+uqTEQcol)
作者发现在 s i g m o i d sigmoid sigmoid之前多加一层网络可以提高1.5%的性能,所以最终形式为
P w h e r e c o l ( c o l ∣ Q ) = σ ( ( u a c o l ) t a n h ( U c c o l E c o l + U q c o l E Q ∣ c o l ) ) P_{wherecol}(col|Q)=\sigma((u_a^{col})tanh(U_c^{col}E_{col}+U_q^{col}E_{Q|col})) Pwherecol(colQ)=σ((uacol)tanh(UccolEcol+UqcolEQcol))

  • 然后的问题就是需要选择哪些列

    • 最简单的方法就是设定一个阈值 r ∈ ( 0 , 1 ) r\in(0,1) r(0,1),所有 P w h e r e c o l ( c o l ∣ Q ) ≥ r P_{wherecol}(col|Q)\geq r Pwherecol(colQ)r被选,但这样效果不好

    • 所以本文是用网络预测k个列会被选,然后取top-k。文章在这里简化一下,因为对于大多数问题这个k是有上限的,假设这个上限是 N N N(文章在估算的时候简单取了个4),那么最后就是一个N+1的分类问题(0~N)

    • KaTeX parse error: Expected '}', got '#' at position 11: P_{\text{#̲}col}(K|Q)=soft…

      • KaTeX parse error: Expected '}', got '#' at position 12: U_1^{\text{#̲}col},U_2^{\tex…是可训练的(N+1)×d和d×d的矩阵, s o f t m a x ( . . ) i softmax(..)_i softmax(..)i表示第i维的softmax输出
where中的OP slot
  • 预测三种操作(=,>,<)

P o p ( i ∣ Q , c o l ) = s o f t m a x ( U 1 o p t a n h ( U c o p E c o l + U q o p E Q ∣ c o l ) ) i P_{op}(i|Q,col)=softmax(U_1^{op}tanh(U_c^{op}E_{col}+U_q^{op}E_{Q|col}))_i Pop(iQ,col)=softmax(U1optanh(UcopEcol+UqopEQcol))i

  • 这里就是 U 1 o p U_1^{op} U1op是3×d的,其他都差不多
where中的 value slot

这里预测来自询问的字串(substring),这里用Seq2Seq来实现

  • 这里的decoder也是用pointer network,只不过attention的时候用的是column attention

P v a l ( i ∣ Q , c o l , h ) = s o f t m a x ( a ( h ) ) a ( h ) i = ( u v a l ) T t a n h ( U 1 v a l H Q i + U 2 v a l E c o l + U 3 v a l h ) i ∈ { 1 , . . . , L } P_{val}(i|Q,col,h)=softmax(a(h))\\ a(h)_i=(u^{val})^Ttanh(U_1^{val}H_Q^i+U_2^{val}E_{col}+U_3^{val}h)\qquad i\in\left\{1,...,L\right\} Pval(iQ,col,h)=softmax(a(h))a(h)i=(uval)Ttanh(U1valHQi+U2valEcol+U3valh)i{1,...,L}

  • 这个 a ( h ) i a(h)_i a(h)i的计算和上面的 E Q ∣ c o l E_{Q|col} EQcol类似

  • 然后用最可能的val值来做生成的第i个token

select语句

和where语句的很类似
P s e l c o l ( i ∣ Q ) = s o f t m a x ( s e l ) i s e l i = ( u a s e l ) T t a n h ( U c s e l E c o l i + U q s e l E Q ∣ c o l i ) i ∈ { 1 , . . . , C } P_{selcol}(i|Q)=softmax(sel)_i\\ sel_i=(u_a^{sel})^Ttanh(U_c^{sel}E_{col_i}+U_q^{sel}E_{Q|col_i})\qquad i\in\left\{1,...,C\right\} Pselcol(iQ)=softmax(sel)iseli=(uasel)Ttanh(UcselEcoli+UqselEQcoli)i{1,...,C}
对于聚合器,与where的OP有类似的操作
P a g g ( i ∣ Q , c o l ) = s o f t m a x ( U a g g t a n h ( U a E Q ∣ c o l ) ) i P_{agg}(i|Q,col)=softmax(U^{agg}tanh(U_aE_{Q|col}))_i Pagg(iQ,col)=softmax(Uaggtanh(UaEQcol))i

损失

每个slot的损失函数都类似如下
l o s s ( c o l , Q , Y ) = − ( ∑ j = 1 C ( α y j log ⁡ P w h e r e c o l ( c o l j ∣ Q ) + ( 1 − y j ) log ⁡ ( 1 − P w h e r e c o l ( c o l j ∣ Q ) ) ) loss(col,Q,Y)=-(\sum_{j=1}^C(\alpha y_j\log P_{wherecol}(col_j|Q)+\\(1-y_j)\log {(1-P_{wherecol}(col_j|Q))}) loss(col,Q,Y)=(j=1C(αyjlogPwherecol(coljQ)+(1yj)log(1Pwherecol(coljQ)))

  • y是表示那些列出现了的向量, α \alpha α是一个平衡正负数据的超参数(作者选的3)

一些细节

对于参数的共享

作者发现上面的多个LSTM(多个slot),用不同的网络权重,相同的embedding性能较好

embedding

  • Seq2SQL中:把embedding固定为一个glove效果较好
  • SeqNet:glove随着训练变化能提升两个点

在这里插入图片描述

缺点

  • 对中文不友好
  • where_val使用seq2seq出来的,与真实值差距大

在这里插入图片描述

  • where处的准确率较低(虽然where也是最复杂的地方)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/350503?site
推荐阅读
相关标签
  

闽ICP备14008679号