当前位置:   article > 正文

ACL 2023 | DaMSTF: 面向领域适应的领域对抗性学习增强自我训练框架

领域自适应 对抗

62a0127cbc3a928da3772040f2303e72.gif

本文介绍一篇被 ACL 2023 录用的领域自适应工作,领域自适应问题在 NLP 的应用部署中受到了广泛的关注。在这些领域自适应的研究中,自训练方法是一类重要的方法。通过将模型的预测作为未标记数据的伪标签,自训练用目标域中的伪实例引导模型。然而,伪标签中的标签噪声(预测误差)会对自训练方法的有效性带来挑战。

为了减少伪标签中的标签噪声,以前的方法往往只使用可靠的伪实例,即具有高预测置信度的伪实例,来训练模型。尽管这些策略有效地减少了标签噪声,但它们也会丢失难样本 (hard examples)。

在本文中,我们提出了一种新的面向领域适应的自我训练框架,即领域对抗性学习增强自我训练框架(DaMSTF)。首先,DaMSTF 通过引入元学习来估计每个伪实例的重要性,可以同时减少标签噪声并保留困难示例。同时,为了保证元学习模块在调整样本权重方面的有效性,我们设计了一个元构造函数来构造元验证集,从而提高元验证集的质量。

此外,我们发现元学习过程中容易出现训练梯度消失的问题,从而倾向于收敛到次优解。为此,我们采用领域对抗学习作为启发式的神经网络初始化方法,这可以帮助元学习模块收敛到更好的最优值。总的来说,本文的贡献包含以下几个方面:

  • 本研究提出了一个新的自训练框架来实现域适应,它通过引入元学习同时减少标签噪声和保留难样本。

  • 本研究提出了一个元构造函数来构造元验证集,这保证了元学习模块的有效性。

  • 本研究从理论上分析了元学习模块中的训练指导消失问题,并提出用领域对抗 学习模块来解决这个问题。

  • 理论上,本研究分析了所提出的 DaMSTF 方法在解决领域适应方面的错误率上界。实验上,DaMSTF 在基准数据集上超越了的所有基线方法。

0a152b3e28bdc692f07e25366481fe42.png

论文标题:

DaMSTF: Domain Adversarial Learning Enhanced Meta Self-Training for Domain Adaptation

论文地址:

https://aclanthology.org/2023.acl-long.92.pdf

代码地址:

https://github.com/LuMelon/DaMSTF_ACL_2023.git

2e4059c93103fb3e2a3a7b22509825b2.png

自训练方法和领域自适应

d2358c4137778f7a7140156529b3eb8e.jpeg

▲ 图1. 自训练示意图

自训练方法的流程图如图 1 所示,它由“伪标签”阶段和“模型再训练”阶段的一系列循环组成。在伪标签阶段,我们将模型的预测作为未标记数据的伪标签目标域。基于这些伪标记实例,在模型再训练阶段对当前模型进行再训练。重复这两个阶段,便可将一个源领域上训练得到的模型迁移到目标领域,从而实现领域自适应。这个自训练过程可以形式化为优化以下目标函数:

829102b37792a72d0a08009b509d814e.png

上式中, 是源领域上的监督损失, 是目标领域上的基于伪标签的弱监督损失, 是平衡 和 的系数。

d48471f4dbddb3ce16a1b6e4747ad7ce.png

DaMSTF方法介绍

926eec8442f6fcb864b4c633a433f344.png

▲ 图2. DaMSTF 方法概览。图中红色箭头表示训练过程中模型的流动方向,蓝色和绿色箭头分 别目标领域和源领域的数据流动方向。

如图 2 所示,DaMSTF 继承了自训练的基本框架,即由 “伪标注阶段” 和 “模型重训练阶段” 的一系列迭代构成:

(i) 在伪标注阶段,DaMSTF 预测目标领域中的无标注数据,并将预测作为伪标签;

(ii) 将这些伪标注数据传送到元构造器,具有高预测置信度的样本被用于扩展元验证集,其它的则被用于构造元训练集;

(iii) 在模型再训练阶段,DaMSTF 首先在域对抗训练模块中训练模型,这一过程会对齐特征空间并初始化模型参数;

(iv) 然后,模型在元学习模块中进行训练, 过程由 “元训练” 步骤和 “元验证” 步骤的一系列循环所组成,它会一边优化超参数(样本权重)一边优化模型参数。之后,DaMSTF 重新开始另一次自训练迭代。

在这里,元构造器是结合元学习和自训练的重要纽带。一方面,由于固有的标签噪声,传统的模型训练方法无法利用那些具有高预测熵的伪标注数据。在这种情况下,元构造器使用它们来构造元训练集,因为元学习模块可以容忍元训练集中的标签噪声。另一方面,在元学习中,那些具有低预测熵的伪标注样本不能为改进模型提供额外信息,但是它们包含较少的标签噪声。在这种情况下,元构造器使用它们来构造或扩展元验证集,用以验证模型,这可以提高元验证集的质量。

378096096ef0d331637e51881564d9c1.png

理论分析

3.1 训练指导消失问题的分析

05d935f9304baf3ea24e57b25ed7314c.png

根据定理 1,如果模型在元验证集上的梯度很小,即 很小时,对于每个样本 而言都会很小,即元学习过程无法再提供新的训练指导,这个问题在本章中被称之为训练指导消失的问题。在 DaMSTF 中,训练指导消失问题的带来的挑战体现在以下几个方面:

首先,元验证集比元训练集小得多,因此模型在元验证集上比在元训练集上收敛得更快。考虑到神经网络的优化是非凸的,如果模型在元验证集上收敛得太早,它可能会收敛到次优。在这种情况下,模型在元验证集上的梯度非常小,它会导致训练指导消失问题。

其次, 中的样本是预测熵较小的样本。由于对伪标注样本的监督正是模型的预测,较低的预测熵导致较低的风险损失。这时,从风险损失反向传播的梯度可以忽略不计,这也会导致训练指导消失问题。

3.2 DaMSTF 泛化能力的分析

本文从理论上分析了 DaMSTF 在实现领域适应方法的有效性。Theorem2 给出了 DaMSTF 方法在领域适应场景下的错误率上界。

08ba39661aa23be5133011d2630b79f4.png

3ebcdce5016c65f655350f6a2680a045.png

基于定理 2,本研究认为 DaMSTF 中算法设计的有效性体现在以下几个方面:

首先,扩展元验证集可以减少定理 2 中的第二项,即。根据定理 3,小于 ,因为 和 中的输入集合都是 的输入集合的子集。因此,扩展元验证集可以减少 的上限。

其次,由于 在每次自训练迭代中变化,DaMSTF 可以利用目标领域中无标注数据的多样性。因此,在整个训练过程中会接近于 。

此外,通过选择具有最低预测熵的示例, 上的错误率会低于 上的预期错误 率,即 。因此,元构造函数中的数据选择过程中减少了定理 2 中的 。

8381b9a59a9a605ca42a0c6b4fac178c.png

解决方案

4.1 领域自适应性能

为了验证元自训练的有效性,本文在两个基准数据集(即 TWITTER 上的 BiGCN 和亚马逊上的 BERT)上进行了无监督和半监督领域适应实验。实验结果列于表 1、表 2。

2c5bdabd9c5aa1d099705fd91558911b.png

如表 1、表 2 所示,DaMSTF 在所有基准数据集 上都优于所有基线方法。在谣言检测任务上,DaMSTF 平均超过最佳基线方法(CST 用于无监督域自适应, WIND 用于半监督域自适应)近 5%。在情感分类任务上,DaMSTF 也优于其他方法。在无监督域适应场景下,DaMSTF 平均超过最佳基线方法(亚马逊数据集上的 DANN)近 2%。在半监督域自适应场景下,DaMSTF 平均超过 Amazon 数据集上最好的基线方法 Wind 2.28%。

4.2 无标注数据规模的影响

如前文 Theorem 2 所示,第二项 会在整个自训练过程中逼近。从这个角度来说,增加未标注数据集的大小可以减小 ,从而降低误差上界,提高性能。为了验证这一点,本实验分别提供了原始 中 0%、5%、. . .、100% 的未标注数据用于 DaMSTF 中的训练,这些未标注数据集分别表示为 Du (0%), Du (5%)、. . .、Du (100%)。本实验是在 TWITTER 数据集上的“Ott”事件上进行的,实验结果如图 3 所示。


a4311317e87fb10fcca7838da53556f7.png

▲ 图3. 无标注数据集 DTu 的大小的影响

从图 3 中可以观察到,在训练过程中使用的未标记数据较少时,模型表现较差。具体地,使用 来训练 DaMSTF 时仅获得 0.701 的 F1 分数,这比使用 来训练 DaMSTF 时的 0.843 少了 0.142。

此外,可以看到,当给出的无标注数据的比例从 0% 增加到 50% 时会持续提高 F1 分数,但是这一比例在超过 50% 之后就会达到饱和。这一现象可以用统计理论中的大数定律来解释 [26]。因为给出的 50% 的无标注数据的分布会趋近于所有无标注数据的整体分布,即 接近于 ,所以 也会近似于 ,这就导致了性能增长的饱和。

4.3 样本权重值的有效性

为了研究元学习模块的有效性,本研究将对元学习上可视化不同伪标注样本上的优化实例权重。图 4 可视化了实例权重、伪标签的正确性和伪标签的置信度之间的相关性。

88654a47df9937901048b0e09c3a7104.jpeg

▲ 图 4 权重在不同的伪标注样本上的分布。曲线的高度代表概率密度。不同的子图显示了具有不同预测置信度的伪标注样本,右边缘的标题是置信区间。在每个置信区间中,黄色曲线表示正确伪标注样本上的权重 值分布,而蓝色曲线表示错误伪标注样本的权重值分布。

图 4 是水平方向的小提琴图,其中每条曲线代表实例权重的分布。不同子图中的伪标注样本有着不同的预测置信度,其中右侧的图注是每个子图的预测置信度区 间。一般来说,子图 (v) 中的伪标注样本是困难的伪标注样本,因为模型对它们的预测置信度最低。需要注意的是,概率密度在整个集合中进行了归一化处理,即所有曲线下的面积之和等于 1.0。从图 4 中的结果可以得到以下结论:

首先,元学习模块可有效降低标签噪声。在不同的置信区间,特别是在 [0.5-0.6] 和 [0.6-0.7] 中,蓝色曲线的峰值小于 0.2,这意味着错误的伪标注样本主要分配 了较低的实例权重。因此,减少了错误伪标注样本的不利影响。

其次,较大的实例权重被分配给具有低置信度的正确伪标注样本。具体来说,大实例权重(即 >0.5)主要 出现在底部的两个子图中,因此大实例权重主要分配给预测置信度低于 0.7 的正确伪标注样本。因此,元学习模块在挖掘硬伪示例方面也很有效。

38062dc665edfa14010d4ed0f5dc7e00.png

总结

本文面向领域适应问题,提出使用领域对抗学习和元学习来改进传统的自训练框架,改进后的框架命名为 DaMSTF (Domain Adversarial Learning enhanced Meta-Self-Training Framework, 领域对抗学习增强的元自训练框架)。DaMSTF 的核心在于使用元学习来对模型进行重训练,它可以在学习的过程中自动识别每个样本的价值,达到同时降低标签噪声并保留伪标注难样本 的效果。

为了更好地结合元学习和自训练,本研究提出了一个元构造器来对伪标注数据进行整理。同时,本研究还提出了一个领域对抗模块来防止元学习中的训练指导消失问题,从而为元学习过程提供更好的初始模型参数。这个领域对抗学习模块的另一个好处是它能在自训练的过程中对齐特征空间,因此能在领域适应场景下增强模型的性能。在两个基准数据集上对两种流行模型 BiGCN 和 BERT 的实验验证表明,DaMSTF 比以往的跨领域迁移方法更有效。

更多阅读

1fc1b50f376c84b297fda1cfc82f8d7a.png

ef6997092356fe60bf77bf3731705f3b.png

bc89a61080cb8e390b695b17c3d15b20.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/197112
推荐阅读
相关标签