当前位置:   article > 正文

PRN(20220826):Learning to Prompt for Continual Learning (CVPR 2022)[理解不了篇]

PRN(20220826):Learning to Prompt for Continual Learning (CVPR 2022)[理解不了篇]
@inproceedings{wang2022learning,
  title={Learning to prompt for continual learning},
  author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={139--149},
  year={2022}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

本文要介绍的论文是Google AI发表在CVPR 2022上的一篇关于深度可持续学习方法的论文。我是被论文摘要最后一句吸引的,论文摘要最后一句说:“令人惊讶地是,L2P(论文提出方法代称)在没有记忆回放单元的前提下,取得了与记忆回放方法相当的持续学习效果。在这里插入图片描述
看完论文,我的确有很强的惊讶感,作者最后也没有给大家解释这个惊讶。我不信世间真有妖法,几经思索得到一个解释。大家有兴趣降妖的么,一起呀!

一、背景:妖法施在哪个环节呀?

现有的深度可持续学习方法其实就是缓解深度学习灾难性遗忘方法,前者是从目的角度,后者是从途径角度。当前解决深度学习灾难性遗忘的三种主要途径:1)参数正则化;2)模型集成;3)记忆巩固。(详见PRN(20200908):Frosting Weights for Better Continual Training

这些途径要么通过直接存储重要的历史数据(类似于强化学习采用的replay memory)、要么间接地利用网络保存历史数据映射关系(像知识蒸馏、Self-refreshing Memory Approaches等方法),再要么就是获取网络权值与历史数据的重要性(如突触智能)。

我们发现,后面的网络映射、网络参数途径是通过生成反应历史映射关系的虚拟数据或限制对历史任务重要的参数的变化来间接地使用了历史信息。

我们要控制一个模型始何调整,可以通过以下两种方式:

  • 喂给它特定的训练数据
  • 直接限制其参数的变化(正则化)来控制

然而,google AI的这篇论文的方法像是构造了一只无形的手,以一种很不直观的方式对网络进行了控制,来达到了可持续学习的目的。它在训练新任务时,没有利用历史信息来唤醒网络对历史数据的记忆,也没有对参数的调整进行约束,那它是如何做到没发生灾难性遗忘的呢?

引用下图来说明L2P方法的大致流程:下图左边是基于历史数据回放的可持续学习方法(此处只提了存储历史数据的一种方法,它只是记忆巩固途径中的一种显示利用历史数据的方法,像知识蒸馏则一种隐式利用历史数据的记忆巩固方法),右图则是L2P的数据流图。根据数据流图所示,L2P方法实现可持续学习的核心是一个可训练的Prompt pool(提示池)模块,该模块是受自然语言处理领域的 prompting techniques (提示技术)的启发。就是增加这样一个可学习的提示池模块,既不对历史数据进行存储,也不对后续的分类网络进行正则化,如何能够有效呢?
在这里插入图片描述

prompting techniques: prompt有提示、提示符的意思,在NLP中,提示可以通过手动定义,例如给输入语句加上不完整的句子,如“I love this movie.”这句输入后,再加上手动定义的提示"The movie is ___“,构成新的输入"I love this movie. The movie is ___”;也可以通过自动生成,例如,输入一个语句为 x x x,利用函数 x p = f p ( x ) x_p=f_p(x) xp=fp(x)生面它的提示,并构造新的输入 [ x , x p ] [x, x_p] [x,xp]。这样的做法可以更加自然的利用预训练模型,因为可以通过选取合适的prompt,来控制下游要解决的不同任务。

如果你没有作者在摘要中的那种惊讶感,很有可能说明你是领域内待入门者(也可能是大神哈)。深度可持续学习领域的一般研究者都会惊讶——妖法施在哪个环节呀?作为一名刚刚入门的可持续学习的爱好者的我,经过仔细思索,得到一个自己的理解,不一定是正解,欢迎交流。

二、原理

网上已经有一些博客对该论文进行了简要的介绍,甚至还有一篇很肝的全文翻译。了解L2P方法的整个过程是没有什么问题的。本文就对L2P方法进行简单介绍。

引原文中的图来说明L2P的方法的流程。首先,根据当前的输入状态 x x x从提示池中通过查询状态与各提示的键值的匹配程度,选择前N个最匹配的提示(假设提示池总共有M个提示,满足 M > N M>N M>N)。然后,将状态以及这N个提示组成新的输入。其中,Embedding层与Transformer Encoder是提前训练好的,在整个可持续学习过程中保持不变,Prompt poolClassifier在训练过程中是可学习的。
在这里插入图片描述
例如,对于二维图像输入 x ∈ R H × W × C x\in \mathbb{R}^{H\times W \times C} xRH×W×C,一个预训练的vision transformer(ViT) f = f r ○ f e f=f_r○f_e f=frfe(包括两部分:Embedding f r f_r frTransformer Encoder f e f_e fe,不包括Classifier)。将图像重塑(reshape)成二维的小图块(也被称为token) x p ∈ R L × ( S 2 . C ) x_p\in \mathbb{R}^{L\times (S^2.C)} xpRL×(S2.C),其中 L L L是图块的个数, S S S是图块的大小, C C C是原始图像的通道数。将 x p x_p xp中第一个token定义为 [ c l a s s ] [class] [class],是序训练模型Embedding的一部分(这部分具体见下面一段介绍)。通过预训练过的embedding f e : R L × ( S 2 . C ) → R L × D f_e:\mathbb{R}^{L\times (S^2.C)}\rightarrow \mathbb{R}^{L\times D} fe:RL×(S2.C)RL×D,将多层的小图块映射为嵌入特征 x e = f e ( x ) ∈ R L × D x_e=f_e(x)\in \mathbb{R}^{L\times D} xe=fe(x)RL×D,其中 D D D为嵌入特征的维数。提示池中的提示 p ∈ R N × D p\in \mathbb{R}^{N \times D} pRN×D,此刻状态从提示池获取了 N N N个提示,构成提示集 P e ∈ R N × D P_e\in \mathbb{R}^{N\times D} PeRN×D,提示集与嵌入特征一起构成新的输入 x p = [ P e ; x e ] x_p=[P_e;x_e] xp=[Pe;xe]。然后将 x p x_p xp输入预训练过的Transformer Encoder网络,得到 f r ( x p ) f_r(x_p) fr(xp)输入到需要训练的Classifier层,得到最终的分类结果。

借用官方博客中的动态图,可以清晰的了解L2P方法在多任务可持续学习任务中的学习过程。
在这里插入图片描述
最后,说说第一个token:[class]。在另一篇论文中,作者把图片输入划分成一系统小图块,作为Embedding层的输入,此外,作者在这些小图块的前面,也就是第一个token位置放置了一个与小图块大小一样都为 R 1 × D \mathbb{R}^{1\times D} R1×D的向量,并将其命名为[class]。论文中说,[class]一般是由原始未切割的完整图像作为输入,经过另一个embedding层得到,并且该embedding层的参数是随整个学习任务一起学习的。在L2P方法,假设这个embedding层与上文提到的embedding层都包括在预训练模型中,并保持不变。

在这里插入图片描述

三、算法伪代码

L2P方法的伪代码如下图所示:
在这里插入图片描述
我们重点关注损失函数:
L x = L ( g ϕ ( f τ a v g ( x p ) ) , y ) + λ ∑ k s i ∈ K x γ ( q ( x ) , k s i ) (1) \mathcal{L}_x=\mathcal{L}(g_{\phi}(f_{\tau}^{avg}(x_p)), y)+\lambda \sum_{k_{s_i}\in K_x}\gamma(q(x), k_{s_i})\tag{1} Lx=L(gϕ(fτavg(xp)),y)+λksiKxγ(q(x),ksi)(1)

上式右边第一项是关于Classifier的参数学习,第二项是关于提示池中的相关提示键值(key)与提示值(value)的学习。我们发现,第二项中并不包含分类器的参数,对第一项并不构成正则化的功能。L2P方法中的键(key)-值(value)是一张映射表,训练过程损失函数里包含 P e P_e Pe,由于这部分的参数是固定的,可以得到损失函数关于 P e P_e Pe的梯度,也即 ∇ P \nabla_P P。但是,key没有在网络模型中, ∇ K \nabla_K K是什么形式呢?我想象不出来,只能大胆假设, ∇ K \nabla_K K就是key值与原始图像的嵌入特征的差值。

  • K K K的更新式为 K ← K − α ∇ K L B K\leftarrow K-\alpha \nabla_K \mathcal{L}_B KKαKLB,其中 L B \mathcal{L}_B LB为累积损失值,直观的理解就是,如果当前整体预测准确,加强这些选中的key,使之更加接近当前的输入的嵌入特征。

  • P P P的更新式为 P ← P − α ∇ P L B P\leftarrow P-\alpha \nabla_P \mathcal{L}_B PPαPLB,则是利用网络中固化的权值来调节 P P P(这部分就要靠丰富的想象力了)。

那么是什么神秘的力量让分类器参数在可持续学习任务中没发生灾难性遗忘呢?在第五节将给出本文的猜想,并给出一个简单的论证过程。

四、实验结果

实验设置:
1)类别增量设置,在推理过程中任务标签是未知的;
2)学习域增量设置,输入域随时间变化;
3)任务不可知设置,没有明确的任务边界。

L2P方法与其它可持续学习方法的对比,作者用了预测准确率与遗忘性两个指标来评价(从论文给出的数据上看,效果还是不错的):
在这里插入图片描述

五、妖法施在哪儿[猜想失败]

截取论文图中的左半部分,你会发现这个过程与KNN过程一样:根据相似性,选择K个与输入状态最相似的样本,并根据这K个样本来得到输入状态的所属类别。如果将对键-值的学习过程也进行分析,我们会发现,L2P方法对提示池的学习与自组织映射网络的训练非常象。
在这里插入图片描述
如下图所示,是一个二维的自组织映射网络(SOM)示意图。在SOM训练过程中,一个输入会激活网络节点中与之最匹配的节点,然后利用该状态根据赫布学习规则对激活的节点及其邻近节点进行更新。SOM是一个无监督的聚类学习方法。
在这里插入图片描述

Hebbian learning(赫布学习) :赫布理论描述了突触可塑性的基本原理,即突触前神经元向突触后神经元的持续重复的刺激可以导致突触传递效能的增加。这一理论由唐纳德.赫布于1949年提出,又被称为赫布定律。赫布理论可以用于触释“联合学习”,在这种学习中,由对神经元的重复刺激,使得神经元之间的突触强度增加。这样的学习方法被称为赫布学习。赫布理论也成为非监督学习的生物学基础。

抛开具体的网络结构,单从外在的形式上看,我们可以把L2P的提示池学习过程约等于一个聚类,按照这种思路分析,那么整个过程就是:根据当前输入状态,从提示池选择N个提示集,属于同一类别的输入状态或是相似的状态,激活的是同样的(或者是非常相似的)提示集。

进行更进一步的简化,我们可以粗暴的把这个提示集用one-hot编码的类标签来替代。

下面就用这样本个简单例子来说明,假定输入 x ∈ R 2 x\in \mathbb{R}^{2} xR2,输出为 y ∈ R 2 y\in \mathbb{R}^{2} yR2,说明输入为二维向量,总共有二个类别。

我们做进一步简化,就直接以One-hot标签作为输入,以One-hot标签作为输出。

在这里插入图片描述

  • 初始化网络,初始化权值:
import torch
from torch.autograd import Variable
x0 = Variable(torch.tensor([[1.0, 0.0]]), requires_grad=False)
x1 = Variable(torch.tensor([[0.0, 1.0]]), requires_grad=False)
w1 = Variable(torch.zeros(2, 4), requires_grad=True)
w2 = Variable(torch.zeros(4, 2), requires_grad=True)
torch.nn.init.normal_(w1, mean=0, std=1)
torch.nn.init.normal_(w2, mean=0, std=1)
alpha = 0.001
print(x0, x1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 用第一个类别的数据( x = [ 1.0 , 0.0 ] , y = [ 1.0 , 0.0 ] x=[1.0, 0.0], y=[1.0, 0.0] x=[1.0,0.0],y=[1.0,0.0])训练网络:
for i in range(100):
    l1 = x0.mm(w1)
    y_pre = l1.mm(w2)
    loss = (y_pre-x0).pow(2).sum()
    print(loss)
    print(y_pre)    
    loss.backward()
#     print(w1.grad, w2.grad)
    w1.data -= alpha*w1.grad.data
    w2.data -= alpha*w2.grad.data
    w1.grad.data.zero_()
    w2.grad.data.zero_()
#     print(w1, w2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

最后网络的预测输出如下:
在这里插入图片描述

  • 在已训练模型基础上,继续用第二个类别的数据( x = [ 0.0 , 1.0 ] , y = [ 0.0 , 1.0 ] x=[0.0, 1.0], y=[0.0, 1.0] x=[0.0,1.0],y=[0.0,1.0])训练网络:
for i in range(100):
    l1 = x1.mm(w1)
    y_pre = l1.mm(w2)
    loss = (y_pre-x1).pow(2).sum()
    print(loss)
    print(y_pre)    
    loss.backward()
#     print(w1.grad, w2.grad)
    w1.data -= alpha*w1.grad.data
    w2.data -= alpha*w2.grad.data
    w1.grad.data.zero_()
    w2.grad.data.zero_()
#     print(w1, w2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

最后网络预测输出如下:
在这里插入图片描述

  • 用最后的模型来预测 x = [ 1.0 , 0.0 ] x=[1.0, 0.0] x=[1.0,0.0]的输出:
l1 = x0.mm(w1)
y_pre = l1.mm(w2)
print(y_pre)
  • 1
  • 2
  • 3

预测结果如下:
在这里插入图片描述

上面一系统过程,简单说就是,最开始用 x = [ 1.0 , 0.0 ] , y = [ 1.0 , 0.0 ] x=[1.0,0.0], y=[1.0, 0.0] x=[1.0,0.0],y=[1.0,0.0]训练后,模型对 x = [ 1.0 , 0.0 ] x=[1.0,0.0] x=[1.0,0.0]的预测输出为 y ^ = [ 0.9964 , − 0.0841 ] \hat{y}=[ 0.9964, -0.0841] y^=[0.9964,0.0841],与直值非常接近,误差为0.0071,在经过对 x = [ 0.0 , 1.0 ] , y = [ 0.0 , 1.0 ] x=[0.0,1.0], y=[0.0, 1.0] x=[0.0,1.0],y=[0.0,1.0]的训练后(模拟类别增量式学习任务),对 x = [ 1.0 , 0.0 ] x=[1.0,0.0] x=[1.0,0.0]的预测输出为 [ − 2.3276 , 0.0446 ] [-2.3276, 0.0446] [2.3276,0.0446],与直值相差很远。

还是产生了灾难性遗忘,即使用很强的提示——数据的类别标签替代数据的状态作为输入,也还是产生了灾难性遗忘。

后来回顾了BP的推导过程,灾难性遗忘的产生并不会因为你用什么形式的输入替代原始数据,只要是前后两批数据分布不同,你又没有利用任何历史信息来回忆的话,这个新数据集就会把那些能让你在之前数据集上输出接近真值的参数修改的面目全非。为什么,因为当前的任务是对新数据集的更好拟合,如果没有一个历史的代理人监督,当前的任务会自私地以拟合好当前数据集为第一要务。于是灾难性遗忘就产生了。

六、总结

Google AI提出来的这个新的可持续学习方法,说是通过增加一个可学习的提示池就能提升模型的可持续学习能力。本人表示很神奇,因为这种方法没有显示地利用历史信息时刻唤醒模型对历史知识的记忆,也没隐式的利用历史信息限制那些对历史知识重要的模型权重的更改。从这个提示池中,我隐约看到了聚类的影子,以为窥探到了方法的机理。最后,通过本用来证明猜想的论证过程,却再次证明自己的猜想根本不对,即:用包含提示数据类别的信息与数据状态一起作为输入,能够提升可持续学习的能力。这是一个错误的猜想!!!

那妖法到底施在哪儿了呢?

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号