赞
踩
传统自训练(self-training)方法采用固定的探索式算法,在不同数据集上表现不一。
本文采用强化学习框架学习数据选择策略,提供更可靠的数据。
处理零样本文本分类通常有两个主要的方法,目前工作主要在第1点,忽略了第2点:
直接用传统的自训练方式可能会遇到一些问题:
本文主要贡献:
自训练有两个缺陷:
模型框架如下:
首先在训练集上训练基础的文本匹配模型,然后在测试集上预测。策略网络在预测的结果中进行样本的挑选,策略网络的奖励来源于匹配模型在验证集上的效果。若当前策略网络采取了正确的策略,挑选出了高质量的样本,那么模型期望会在验证集上获得较好的 performance,则会获得正向的奖励;相反若策略网络采取错误的策略,则模型获得较差的结果和负向的奖励。
对于基础的文本匹配模型,本文采用了预训练模型 BERT,BERT 的输入为句子和类别文本的拼接输出为该句子和类别的匹配分数,如图所示。
state:当前状态包括两部分:[CLS]对应的向量表示
c
x
,
y
∗
c_{x,y^*}
cx,y∗,以及预测的confidence分数
p
x
,
y
∗
p_{x,y*}
px,y∗
action:agent 需要判断是否选择当前实例
(
x
,
y
∗
)
(x, y^*)
(x,y∗)
reward:根据验证集的匹配效果计算 reward,计算公式如下:
r
k
=
(
F
k
s
−
μ
s
)
σ
s
+
λ
⋅
(
F
k
u
−
μ
u
)
σ
u
r_{k}=\frac{\left(F_{k}^{s}-\mu^{s}\right)}{\sigma^{s}}+\lambda \cdot \frac{\left(F_{k}^{u}-\mu^{u}\right)}{\sigma^{u}}
rk=σs(Fks−μs)+λ⋅σu(Fku−μu)
其中:
policy Network:使用多层感知机作为挑选策略网络,输入为state,输出为是否挑选当前实例的概率(action 的概率),计算公式如下,
z
t
=
ReLU
(
W
1
T
c
x
,
y
∗
+
W
2
T
p
x
,
y
∗
+
b
1
)
P
(
a
∣
s
t
)
=
softmax
(
W
3
T
z
t
+
b
2
)
其中:
整个模型的伪代码如图:
数据集采用 EMNLP19年的工作:Yin W, Hay J, Roth D. Benchmarking zero-shot text classification: Datasets, evaluation and entailment approach[J]. In EMNLP 2019.
包括3个数据集:话题、情感、情景,另外再加电商数据集。去除掉多标签数据,只考虑单标签数据。
文本匹配baseline方法:
本文方法:
Generalized 方法:类标签来自看不见的类和看到的类
结果:平均提升了15.4%
non-generalized 方法:类标签来自看不到的类
结果:平均提升了5.4%
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。