当前位置:   article > 正文

NLP炼丹技巧合集_linear warmup

linear warmup

原创:郑佳伟

NLP任务中,会有很多为了提升模型效果而提出的优化,为了方便记忆,所以就把这些方法都整理出来,也有助于大家学习。为了理解,文章并没有引入公式推导,只是介绍这些方法是怎么回事,如何使用。

一、对抗训练

近几年,随着深度学习的发展,对抗样本得到了越来越多的关注。通常,我们通过对模型的对抗攻击和防御来增强模型的稳健性,比如自动驾驶系统中的红绿灯识别,要防止模型因为一些随机噪声就将红灯识别为绿灯。在NLP领域,类似的对抗训练也是存在的。

简单来说,“对抗样本” 是指对于人类来说“看起来”几乎一样、但对于模型来说预测结果却完全不一样的样本,比如图中的例子,一只熊猫的图片在加了一点扰动之后被识别成了长臂猿。

对抗攻击”,就是生成更多的对抗样本,而“对抗防御”,就是让模型能正确识别更多的对抗样本。对抗训练,最初由 Goodfellow 等人提出,是对抗防御的一种,其思路是将生成的对抗样本加入到原数据集中用来增强模型对对抗样本的鲁棒性,Goodfellow还总结了对抗训练的除了提高模型应对恶意对抗样本的鲁棒性之外,还可以作为一种正则化,减少过拟合,提高模型泛化能力。

在CV任务中,输入是连续的RGB的值,而NLP问题中,输入是离散的单词序列,一般以one-hot向量的形式呈现,如果直接在raw text上进行扰动,那么扰动的大小和方向可能都没什么意义。Goodfellow在17年的ICLR中 提出了可以在连续的Embedding上做扰动,但对比图像领域中直接在原始输入加扰动的做法,扰动后的Embedding向量不一定能匹配上原来的Embedding向量表,这样一来对Embedding层的扰动就无法对应上真实的文本输入,这就不是真正意义上的对抗样本了,因为对抗样本依然能对应一个合理的原始输入。那么,在Embedding层做对抗扰动还有没有意义呢?有!实验结果显示,在很多任务中,在Embedding层进行对抗扰动能有效提高模型的性能。之所以能提高性能,主要是因为对抗训练可以作为正则化,一定程度上等价于在loss里加入了梯度惩罚,提升了模型泛化能力。

接下来看一下NLP对抗训练中常用的两个方法和具体实现代码。

第一种是FGM, Goodfellow在17年的ICLR中对他自己在15年提出的FGSM方法进行了优化,主要是在计算扰动的部分做了一点简单的修改。其伪代码如下:

对于每个样本x:

1.计算x的前向loss、反向传播得到梯度

2.根据Embedding矩阵的梯度计算出扰动项r,并加到当前Embedding上,相当于x+r

3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上

4.将embedding恢复为(1)时的值

5.根据(3)的梯度对参数进行更新

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

具体pytorch代码如下:

1import torch
2class FGM():
3def __init__(self, model):
4.	        self.model = model
5.	        self.backup = {
   }
67def attack(self, epsilon=1., emb_name='embedding'):
8# emb_name这个参数要换成你模型中embedding的参数名
9for name, param in self.model.named_parameters():
10if param.requires_grad and emb_name in name:
11.	                self.backup[name] = param.data.clone()
12.	                norm = torch.norm(param.grad)
13if norm != 0 and not torch.isnan(norm):
14.	                    r_at = epsilon * param.grad / norm
15.	                    param.data.add_(r_at)
1617def restore(self, emb_name='embedding'):
18# emb_name这个参数要换成你模型中embedding的参数名
19for name, param in self.model.named_parameters():
20if param.requires_grad and emb_name in name: 
21assert name in self.backup
22.	            param.data = self.backup[name]
23.	    self.backup = {
   }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

需要使用对抗训练的时候,只需要添加五行代码:

需要使用对抗训练的时候,只需要添加五行代码:
1# 初始化
2.	fgm = FGM(model)
3for batch_input, batch_label in data:
4# 正常训练
5.	    loss = model(batch_input, batch_label)
6.	    loss.backward() # 反向传播,得到正常的grad
7# 对抗训练
8.	    fgm.attack() # 在embedding上添加对抗扰动
9.	    loss_adv = model(batch_input, batch_label)
10.	    loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度
11.	    fgm.restore() # 恢复embedding参数
12# 梯度下降,更新参数
13.	    optimizer.step()
14.	    model.zero_grad()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

第二种是PGD, FGM直接通过epsilon参数一下算出了对抗扰动,这样得到的对抗扰动可能不是最优的。因此PGD进行了改进,多迭代几次,“小步走,多走几步”,慢慢找到最优的扰动。伪代码如下

对于每个样本x:

1.计算x的前向loss、反向传播得到梯度并备份

对于每步t:

2.根据Embedding矩阵的梯度计算出扰动项r,并加到当前Embedding上,相当于x+r(超出范围则投影回epsilon内)

3.t不是最后一步: 将梯度归0,根据1的x+r计算前后向并得到梯度

4.t是最后一步: 恢复(1)的梯度,计算最后的x+r并将梯度累加到(1)上

5.将Embedding恢复为(1)时的值

6.根据(4)的梯度对参数进行更新
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

可以看到,在循环中r是逐渐累加的,要注意的是最后更新参数只使用最后一个x+r算出来的梯度。具体代码如下:

1import torch
2class PGD():
3def __init__(self, model):
4.	        self.model = model
5.	        self.emb_backup = {
   }
6.	        self.grad_backup = {
   }
78def attack(self, epsilon=1., alpha=0.3, emb_name='emb.', is_first_attack=False):
9# emb_name这个参数要换成你模型中embedding的参数名
10for name, param in self.model.named_parameters():
11if param.requires_grad and emb_name in name:
12if is_first_attack:
13.	                    self.emb_backup[name] = param.data.clone()
14.	                norm = torch.norm(param.grad)
15if norm != 0 and not torch.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/361838
推荐阅读
相关标签
  

闽ICP备14008679号