当前位置:   article > 正文

重读经典(CLIP下):《Learning Transferable Visual Models From Natural Language Supervision》_linear probe分类

linear probe分类

上文链接:重读经典(CLIP上):《Learning Transferable Visual Models From Natural Language Supervision》

5. 实验

现在我们已经知道 CLIP 是如何进行预训练的以及作者为什么选用对比学习来预训练 CLIP,接下来我们就来读一下这个长达十几页的实验部分。实验部分一开始作者先介绍了一下什么是 zero-shot transfer?因为这个才是 CLIP 这篇文章的核心和精华所在。作者上来还介绍了他们为什么要这么做?

总结下来就是一句话:之前那些自监督或者无监督的方法主要研究的是特征学习的能力,目标是去学一种泛化性比较好的特征,比如说我们讲过的 MOCO、SimCLR 、DINO,但即使学到了很好的特征如果想应用到下游任务的时候还是需要有标签的数据去做微调,所以这里面就还牵扯各种问题,比如说下游任务不好去收集数据、比如说有 distribution shift 的问题,如何能够训练一个模型接下来就不再训练或者不再微调了呢?这就是作者研究 zero-shot 迁移的研究动机,一旦借助文本训练好了一个又大又好的模型之后就可以用文本作为引导去很灵活的做 zero-shot 的迁移学习,至少在分类上效果都非常的好。

在这里插入图片描述


至于怎么用 CLIP 去做 zero-shot 的迁移,现在回到图1里去再复习一下。当 CLIP 预训练好之后其实就有两个编码器,一个是图像编码器,一个是文本编码器。他们都已经训练好了,这时候任意给一张照片,比如说有狗的照片,通过图片编码器就会得到一个图片的特征,文本的输入是感兴趣的标签有哪些。比如说飞机、汽车、狗和鸟,暂且就用这四个词当做所有的标签。首先这四个词通过 prompt engineering 会变成这一个句子,也就是说飞机就会变成这是一张飞机的照片,汽车就会变成这是一张汽车的照片,四个单词就变成了四个句子。当有了四个句子之后通过文本编码器就会得到四个文本的特征,这时候拿着四个文本的特征去跟图像的特征去算 Cosine similarity 最后得到的相似度还会再通过一层 softmax 得到一个概率分布,哪个概率最大就是说哪个相似度最高,所对应的那个句子大概率就是在描述这张照片,也就是说那个句子里所包含的物体其实就是这个图片里应该有的物体了。

那如果我们现在想一下要给 ImageNet 里所有的图片去做这种推理,去测试一下模型的效果如何,这里会生成多少个句子呢?应该是1000个句子,因为 ImageNet 有1000个类,所以说对应的就会生成1000个句子。简单的来说就是说每张图片相当于都用1000个句子,这样去问是不是飞机,是不是汽车,是不是狗,是不是人?问1000次,然后跟哪个文本最接近,那就是哪一类。当然在文本端算特征的时候并不是这种顺序进行的,是可以批次进行的,所以说 CLIP 的推理还是非常高效的。

在这里插入图片描述


作者在3.1.3里就去做了一下和之前最相似工作 Visual N-Grams 的对比,在表1里展示了一个非常惊人的结果。就是 Visual N-Grams 之前在 ImageNet 的效果只有 11.5 的准确率,而 CLIP 已经达到了 76.2,而 76.2 的性能直逼之前的原版的 Res-50,而且现在完全没有用任何一张张训练图片的情况下,直接 zero-shot 的迁移就能得到 76.2。表一 CLIPVisual N-Grams 的直接对比可以看到, CLIP 是大幅度超越 Visual N-Grams,当然 ImageNet 上提升是最明显的,但是在剩下两个数据集上也都非常厉害。但是作者这里也指出了这并不是一个公平的对比,因为很多东西都不一样,CLIP 的数据集比之前的方法大了10倍,而且用的视觉上的模型比之前要多100倍的计算,所以说相当于在训练上他们用了超过1000倍的资源去训练这个模型,而且在模型的架构上 CLIP 用的是 Transformer,在2017年 Visual N-Grams 这篇论文发表的时候其实 Transformer 这篇论文还没出现,所以作者这里其实想强调的就是这不是一个公平的对比,这也是对之前的工作 show respect

在这里插入图片描述


那接下来我们就来说一下大家比较感兴趣的 3.1.4 节,也就是 prompt engineeringensembling,基于 prompt 学习方法最近非常火,不论是在 nlp 还是在 cv 领域都非常的火,如果你还在找研究方向其实这是一个不错的选择。因为它主要是在做微调或者直接做推理的时候用的一种方法,而不是在预训练阶段,所以不需要那么多的计算资源,但是因为效果好所以说它的影响力还是非常大的。

prompt 是什么呢? prompt 翻译过来有很多的含义,但是在这里其实对应的中文含义是 提示,起到一个提示的作用,也就是之前一直说的文本的引导作用。那为什么要做 prompt engineeringprompt Ensembling 呢?作者这里举了两个比较常见的例子:

  • 第一个就是作者这里说的 Polysemy,多义性也就是一个单词可以同时拥有很多的含义。如果在做文本和图片的匹配的时候每次只用一个单词,也就是标签对应的单词去做文本的特征抽取,那就有可能会面临到这种问题。比如说在 ImageNet 数据集里同时包含有两个类,一个类是 construction crane,一个是 crane,在相应的语境下这两个 crane 所对应的意义是不一样的。在建筑工地环境下指的是起重机,而作为一个动物它指的是鹤,就丹顶鹤那种鹤类动物。那这个时候就会有歧义性出现了,算的相似度很有可能就是错的。当然别的数据集也有这种问题,比如说 oxford-IIIT Pet 里面有一类叫 boxer,那其实如果只从动物的角度去理解,boxer 其实就指的是狗的一个种类,但是对于一个什么都不知道的文本编码器来说如果只把 boxer 输给它,很有可能就把这个词当成是一种运动员了,就是打拳击的,抽取的特征肯定就不对了。总之就是说如果只用一个单词去做 prompt 的话会经常出现歧义性这种问题。
  • 那另外一个问题之前也说过很多次了,就是在做预训练的时候匹配的文本一般都是一个句子,很少是一个单词,可是如果在做推理的时候每次都进来的是一个单词,那这里面可能就会存在 distribution gap 的问题。抽出来的特征可能就不是很好,所以基于这两个问题作者就想了一个非常简单的一个方式去做 prompt template

有一个提示的模板,也就是说把标签放在这里,然后把它变成一个句子。所有的标签都变成了这是一个什么什么的图片?这样首先已经是一个句子了,就不太会出现 distribution gap 的问题,其次因为说的是这是一张什么的图片,所以说后面这个标签一般多指的是名词,上面歧义性的问题有时候就会解决了。所以用上了这个提示模板之后,作者发现准确度一下就提升了1.3%,而且这个模板非常的简单粗暴,所有的标签都是改成这样一种形式的句子。那当然 prompt Engineering,不光是提供这么一个提示的模板,他还可以做很多事情。作者就发现如果提前知道一些信息,对 zero-shot 的推理是非常有帮助的,比如说当你知道你现在做的就是 oxford-IIIT Pets 数据集的时候,它里面的类别肯定都是动物,所以你给出的提示就不光是说 a photo of,还可以在后面再加一句话,一种宠物的类别,那这样一下就把解的空间缩小的很小了,很容易得到一个正确的答案。那同样的道理,对于食物数据集就可以说这是一种食物,对于飞机数据集就可以说这是一种飞机,对于 zero-shot 的迁移都会非常的有用。作者这里还举例当你去做 OCR 这个任务,也就是说给定一个图片去里面找文字数字这种任务。如果你在你想找的那个文本或者数字上打上双引号,模型就更明白你的意思了。他可能就知道你就是想找双引号里面的这个内容,这个结果也是相当有意思的。

最后作者就讲了一下 prompt ensembling,就是说多用一些提示的模板做多次推理,然后最后把结果综合起来。那作者这里到底用了多少个提示样本去做这个 Ensemble 呢?在论文里呢说用了80个。

在这里插入图片描述


在 3.1.5 节作者就大范围的在27个数据集上去衡量了一下 CLIPzero-shot 的迁移的效果,27个数据集的对比结果放到了下图里。图比较的双方一个是做 zero-shotCLIP,另外一个就是在 Res50 的特征上去做这种 linear probe。这里 Res50 是在 ImageNet 上用有监督的方式训练好的一个模型,从中去抽特征然后在这些下游任务上去添加新的分类头,然后在新的分类头上去做 linear probe 的微调,作者这里是把这种方式当成了基线。如果 CLIP 模型比它表现的好就列在上面是绿色的,那全都是加号,意思就是说 CLIP 模型相比于这个基线提升了多少。那对于下面这些蓝色的就是说 CLIP 模型相对于基线要降低了多少的性能。

首先我们可以看到的是在16个数据集上也就是在大多数数据集上 Zero-ShotCLIP 模型都超越了之前这种用有监督预训练好的 Res50,这个结果是非常惊人的,真的证实了 zero-shot 的迁移也是有效的,是可以广泛进行应用的,而不是光在 ImageNet或者某些数据集上有作用。然后作者这里还做了一些总结就是对于普通的给物体进行分类的数据集来说 CLIP 一般都表现的比较好。像这种车、食物、动作、或者是像 CIFAR10CIFAR100ImageNet 这种普通的物体分类的数据集。CLIP 就能很好的去做 zero-shot 的迁移,这个也比较好理解,因为如果你这个图片中有一个可以描述出来的物体,那对应的文本里应该也有这个物体的描述,所以他俩就会匹配的非常好,CLIP 模型对这种物体也会比较敏感。但是对于更难的一些任务、更难的一些数据集
比如说 DTD,这种对纹理进行分类的数据集或者说 CLEVRCounts 这种去给图片中的物体计数的任务,对于 CLIP 来说就非常的难而且太过抽象了,所以 CLIP 模型在这些数据之上的表现就不好。因为对于这种难的任务完全不给他任何标签信息的话也有点强人所难了。所以作者最后也为 CLIP 打报了个不平,说对于这种特别难的任务如果只做这种 zero-shot 的迁移可能不是那么的合理,有可能去做 few-shot 的迁移会更合理,因为对于难的任务比如说去给肿瘤做分类需要特定领域知识的任务,即使是对于我们人来说如果我们没有先验知识也没法分类正确的。那为什么要去强求 CLIP去在 Zero-Shot 的时候能正确分类呢?

在这里插入图片描述


既然作者提到了 few-shot,作者觉得对于这些更难的数据集 few-shot 会比 zero-shot 的衡量更合理。下图里作者就做了一下 zero-shot CLIPfew-shot CLIP 以及和之前 few-shot 的那些方法的一些比较。

具体来说对于之前的方法就是用不同的方法去进行模型的预训练,然后当训练好这个模型以后就把这个模型的参数冻住,只从里面去抽特征做 linear probe,既然是 linear probe,就需要训练最后那个分类头了,所以就需要下游数据集里有标签的数据。那对于 CLIP 模型也一样,作者这里就是把 CLIP 里面图片的编码器拿出来冻住,然后去做 linear probe

下图的横坐标是说数据集里每一个类别里到底用了多少训练样本,0就是说的 zero-shot 什么都不用。别的这些方法因为没有和自然语言相结合,所以他们就没办法做 zero-shot,所以最低也得从 one-shot 开始,就是至少也得用一个训练样本,然后从 one-shottwo-shotfour-shot,最后到 sixteen-shot。纵坐标说的是 average score,平均的分类准确度。这里是做了20个数据集的平均准确度,意味着这里的每一条曲线其实是20个数据集结果的一个合并,而并不单单是一个数据集上的结果。那从这张图里可以观察到好几个有趣的结论:

  • 首先就是这条蓝色的曲线,对应的其实是 bit 模型,它主要就是为迁移学习而量身定做的。所以它算是迁移学习或者说 few-shot 的迁移学习里表现最好的工作之一。那在这里 bit 的模型是在 ImageNet21k 上去做预训练的,数据集也比较大。所以说这条蓝色的曲线非常具有代表性,是一个很强的 baseline,但是我们可以看到 zero-shot CLIP 不用任何训练样本直接就和最好的 bit 打成平手了,可见利用自然语言的威力。
  • 那第二个比较有意思点就是这条紫色的曲线,就是对 CLIP 里的图片编码器去做 few-shotlinear probe。我们可以发现在当训练样本只有1、2或者4的时候,这种用了训练样本的 few-shot 的方式还不如直接去做 zero-shotCLIP,也就再次证明了用文本去做引导的多模态学习是多么的强大。
  • 那最后一个观察就是随着训练样本的增多,few-shot 的学习的 CLIP 模型效果是最好。不仅超越了之前的这些方法,验证了 CLIP 模型的强大,同时还超越了这种 zero-shot CLIP,验证了作者刚才的说法。对于难的数据集来说,有一些训练样本还是非常有必要的。

在这里插入图片描述


说完了 zero-shot 又说完了 few-shot 那接下来很自然的就是要看一下如果下游任务用全部的数据 CLIP 的效果会如何。如果下游任务用全部的数据有很多种方式去衡量模型学到的特征好不好,最常见的两种方式一种就是 linear probe,另外一种就是 fine-tuningLinear Probe 就是把预训练好的模型冻住,然后在上面再训练一个分类头,那对于 fine-tune 来说就是把整个网络都放开,直接去做端到端的学习。微调一般是更灵活的,而且在下游数据集比较大的时候微调往往是要比 linear probe 的效果要好很多。但是在 CLIP 这篇文章里作者说就是要用 linear probe,不用 fine-tune,这里就列了几个原因:

  • 第一个原因就是 CLIP 就是为了用来研究跟数据集无关的预训练方式的,那如果下游数据集足够大然后我整个网络全都放开,在这个数据集上去做微调的话很有可能预训练的那个模型并不好,但是在微调的过程中经过不断的优化,最后的效果也很好。这样就无法分辨预训练的方法或者预训练的模型到底好不好了。而 linear probe 就是用线性的分类头的方式,就不太灵活。整个网络大部分都是冻住的,只有最后这一层 fc 是可以训练的,那这个可学习的空间就比较小,所以如果预训练的模型没有训练好的话,即使在下游任务上训练的再久也很难优化到一个特别好的结果。所以说能更准确的反映出预训练模型的好坏。
  • 另外一个作者选用 linear probe 的原因就是因为 linear probe 不太用调参。因为 CLIP 这篇论文做了大量的实验、涉及了大量的数据集,如果是做端到端的微调就有太多可以调的超参和设计方案了。比如如果这个下游数据集特别大而且数据标注也质量比较高的话,你就希望学习率可能会大一点。因为你想要预训练的模型尽可能的去拟合下游的数据集,但如果下游任务的数据集特别小你可能就得用特别特别小的学习率。因为很有可能稍一学习模型就过拟合了,总之你得为每个数据集都去量身定做,就你得去搜超参才能在各个数据集上表现好一点。但如果你是做 linear probe ,模型主体都已经冻住了,就是抽特征,唯一要学的就是最后那个分类层,所以可以调的参数非常少,而且不论对于什么数据集或者不论对于什么任务。只要是分类,整个测试的流程就是一个正规化的流程,什么都不用改,大大简化了方法之间的对比。

接下来作者就在下图里展示一下结果,作者这里对比了很多的方法,从他们自己的 CLIP 然后到有监督的 EfficientNet,还有用了伪标签的 EfficientNet,还有弱监督的在 instagram 上训练的模型,还有最近大火的对比学习自监督学到的模型以及还有一些经典的有监督学习的这些基线模型。左右两张图其实画的都是一个意思,横坐标都是说对于一张图片来说,做一遍前向过程需要用多少的计算量,然后纵坐标就是说在很多数据集上最后的平均准确率,准确率越高、用的时间越短,效果呢就越好。在这两张图里越靠左上角点的模型在精度和速度上的 trade-off 就做的更好。

首先我们先看右图,把27个数据上的效果平均了一下,然后这里我们可以看到 CLIP 里对应的红色五角星和红色的空心的五角星效果都是最好的,比剩下所有的这些模型的效果都要好。这就再次证明了 CLIP 模型的强大之处,这不光是 zero-shot 还有 few-shot、在用全部的数据去做训练的时候 CLIP 照样吊打别的所有模型。

左边这张图其实是为了跟之前的工作做一个公平对比,因为之前有个工作提出了12个数据集的一个集合。很多人就在这12个数据集上去比这个平均的效果,在这张图里我们可以看到用更大的模型 Vision TransformerCLIP 效果还是很好,但是用残差网络的 CLIP 就比别的方法要差了。但作者想说的意思是这12个数据集跟 ImageNet 的关系很大,就是说他们之间的关联性很高。如果你那些模型在 ImageNet 数据集上做过有监督的预训练,那肯定效果就会特别好,超过 CLIP 也就不足为奇了。

在这里插入图片描述


接下来我们就看第四章节,和人的表现对比。这里作者其实就找了5个人然后让他们去看 oxford IIT Pets 数据集里的测试图片,这里面其实就都是一些宠物了。主要就是一些细分,狗和猫的种类,一共是37类。那对于 zero-shot 的情况作者也做了强调,说这些参加实验的人是不能去网上进行搜索的,而且他们也没有给这些参赛者看过这些狗和猫对应的种类的示例图片。然后作者他们还做了 One-ShotTwo-Shot 的实验,One-Shot 就是给所有的这些参赛者每个种类看一张示例图片,告诉他这个种类的狗和猫到底长什么样,然后 Two-Shot 的就是看两张。

那我们来看看表2里对比的结果到底如何,我们可以看到 zero-shot CLIPzero-shot human 表现要好得多。那毕竟对于一个普通人来说,如果他对狗和猫的种类不是那么了解的话,那确实很难去做出这种 Zero-Shot 的分类。然后一旦给这个参赛者看过一张示例图片以后,准确度一下就从50多提升到70多了,说明人的学习速度还是非常快的。但如果给这些参赛者每个种类再多看一张照片,就是这里的 Two-Shot 的结果你会发现准确度并没有提升。这个倒是挺有意思的,这说明人之所以在给一张图片以后准确率能提升这么多,主要是他可以把常识、之前的这种先验知识和这个示例图片联合在一起去做判断。那如果他没有系统的学过这些知识,就是说他的先验知识并没有增加的情况下,那你再多给他看一到两张图片其实也是于事无补的,他并不会真的随着这个数据的增加准确率提升。

在这里插入图片描述
然后作者还做了一个有趣的实验,把这些宠物的种类都列了出来,看了一下针对每一个类模型的表现到底如何。蓝色的点对应的就是 zero-shot CLIP,橘黄色的点就对应的是 One-Shot 的人,绿色就对应的 Zero-Shot 的人。这里其实真实的数据点是这些点,然后这里的直线其实只是拟合过后的直线,只是为了方便读者看的。作者这里想说的意思就是说对于 CLIP 模型来说,难的类对于人来说也很难,对于 CLIP 来说简单的类对于人来说也简单。

因为从图里来看,这些分类准确率比较高的这些类 CLIP 和人的准确率都高。那对于分类准确度低的这些类来说 CLIP 的表现也都低,这个也是比较有意思,那很有可能还是这种现实生活中的这个数据分布有关系。那人类对于这种常见的宠物种类自然就知道,就在他的常识范围里面,那 CLIP 因为这些在数据集里肯定也出现过,所以说他的这个分类准确度也高。

在这里插入图片描述


6. 模型局限性

那接下来我们就来看一看作者写了哪些局限性和不足之处:

  • 第一点作者说 CLIP 在很多数据集上平均下来来看,是可以和一个比较简单的基线模型打成平手的,也就是我们反复说的那个 ImageNet 上训练的 Res50 的模型。但是在大多数数据集上 Res50 的模型其实根本就不是 state of the art,如果跟现在最好的模型比起来差很远。作者这里说文章中也做了实验,就是如果加大数据集而且加大模型,就是你去扩大这个规模,CLIP 的性能是还能继续提高的。但是如果你想把这十几个点的差距都弥补上,作者预估还要在现在训练 CLIP 计算量的基础上要乘以 1000 倍计算量。那这个代价就太大了,作者这里说即使对 OpenAI 来说用现有的硬件条件他们也是没办法训练。所以如果你想走扩大 CLIP 规模的这种方式去弥补十几个点的差距,CLIP 在所有数据集上都达到 SOTA 的效果的话那肯定需要有新的方法,然后在计算和数据的高效性上进行进一步的提高。
  • 那另外一个局限之处作者就是说 CLIP 在有些数据集上的 Zero Shot 的结果也并不好。比如在某些细分类的数据集上 CLIP 的效果也是低于有监督训练 Res50 的基线网络的,另外不光是细分类的任务,CLIP 还无法处理特别抽象的概念或者说更难的任务 (比如说去数一数图片里到底有多少个物体) 或者说在监控视频里区分当前这一帧是异常还是非异常。因为 CLIP 模型虽然很擅长去分类物体,但是他完全不了解什么叫异常什么叫安全。作者最后总结说他们坚信还有很多领域 CLIP Zero Shot 的性能其实是跟瞎猜一样的,也就是说在很多情况下 CLIP 都不行,并不是一个万能的方法。
  • 第三个局限性作者在这里说 CLIP 虽然说泛化做的很好,对于很多自然图像的分布偏移模型还是相对稳健。但是如果你在做推理的时候这个数据真的跟你训练的数据差的非常远,数据真的已经 out of distribution ,那 CLIP 的模型泛化照样也很差。作者这里举了一个例子,就是在 MNIST 的数据集上 CLIP 的准确率只有88%。作者就深入的研究了一下,就用各种去重的方法去看看他们搜集的4个亿图片的数据集里到底有没有跟 MNIST 的相似的图片。结果发现非常神奇的是即使他们的训练数据集有4亿个训练样本,但是就是没有跟 MNIST 的数据长的像的。这其实也就从侧面反映了 CLIP 模型也没什么大不了的,他跟普通的深度学习的模型一样都非常的脆弱。
  • 然后第四个局限性作者这里就说虽然 CLIP 可以去做 Zero Shot 的分类任务,但他还是从你给定的那些类别里去做的选择。那相比而言一种更灵活的方式就直接去生成图像的标题,这样的话一切都是这个模型在处理,所有都自动化,可以给你生成一个新的输出的,而不是像 CLIP 一样你得给他一个新的类别,然后他告诉你跟这个图片类似不类似。所以作者这里还是不忘 OpenAI 的老本行,还是想把一切都 gpt 化,都做成生成式的模型。那可惜受限于计算资源的问题,他们没办法去训练一个图像题目生成的基线网络,作者说以后可能会有这么一个简单的想法,就是说把对比学习的目标函数和生成式的目标函数合在一起。那这样的话就有可能把两个方法的优势结合在一起,就是既有了对比学习训练模型的高效性又有了生成式模型的灵活性。
  • 接下来作者又讨论了第五个局限性,就是说 CLIP 对数据的利用并不是很高效,他跟别的深度学习里的网络一样需要大量大量的数据去投喂。作者这里形象的描述了一下,他们这个数据集到底有多大,在他们训练的过程中他们一共训练了32个 Epoch,那每个 Epoch 可要过4亿个图片,所以说最后一共就相当于是跑了128亿张图片。那如果我们这个 DataLoader 的速度是每秒钟出一张图片,那这个模型要把所有的这些图片全看完就需要花405年的时间。所以作者感叹说这用的数据实在是太多了,如果能减少一下数据用量,那当然是极好的。那怎么减少这个数据用量呢?简单一点的方式当然就是做数据增强了,那另外最近还有两种比较常见的方式,一种就是用自监督的方式,另一种就是用伪标签的方式这两种方式都能比监督学习有更好的数据利用效率。
  • 作者接下来说的第六个局限性还跟数据有关,但是是跟下游任务的测试数据集有关。他的意思是说虽然我整篇文章都在说 zero shot,在我们整个研发 CLIP 的过程之中我们为了能跟别人去做公平的比较,也为了得到一些回馈。所以我们往往是在整个测试数据集上去不停的做测试,比如说 CLIPImageNet 上的分这么高并不是第一次训练出来分就这么高的。他肯定是测试了很多变体,做了很多超参的调整,最后才定下的这套网络结构和这套超参数。而在整个研发的过程中其实每次都用 ImageNet 的测试集去做了指导,所以这里面已经无形之中就已经带入了偏见了,而且并不是真正的 zero shot 的情况。另外作者还说他们整篇文章里不停的用到27个数据集去做测试,但其实数据集千千万万那为什么只选这27个呢?这27个也不一定就具有代表性,所以整个 CLIP 的研发过程也是跟这27个数据集息息相关的。那最后作者总结了一下,就是说如果能真的再创建一个新的数据集,而这个数据集就是用来测试各种各样的 Zero Shot 的迁移的能力的那就太好了。如果只是像他们现在一样简单的重复使用已有的做有监督训练的数据集,就难免会有局限性。
  • 第七个局限性就是 OpenAI 经常说的局限性了,因为他们的数据都是从网上爬的。不论是图片还是文字,这些爬下来的图片文本对基本是没有经过清洗的、就是既没有被过滤过、也没有被审查过。所以这就导致最后学得的 CLIP 模型很有可能就带了一些社会上的偏见,比如说性别、肤色、宗教,作者这里还专门写了一个第七章去讨论了一下 CLIP 模型有可能带来的这种巨大的社会影响力以及他模型里可能隐藏的偏见。有可能会带来的不当的使用。
  • 最后作者还提到了另外一个局限性,虽然整篇论文他们都在宣传 CLIP 到底有多么的灵活,其实还是有局限性的。因为很多很复杂的任务或者是很复杂的概念其实即使你用语言也无法描述的,如果你能在做下游任务做泛化的时候提供一些训练样本还是非常有帮助的,但可惜 CLIP 模型的提出并不是为了 Few shot 的情况而提出的,也不是为了优化的,所以就导致了一个非常奇怪的现象。就是当给 CLIP 提供了一些训练样本,结果反而还不如直接用 zero-shot,这个就很耐人寻味了。你不给他提供训练样本他反而效果很好,你给他提供一些训练样本他反而效果还差了,那这个跟我们人的学习呢就截然不同了。所以说之后的工作还有很多,怎么能让CLIP 既在这种 Zero Shot 的情况下工作的很好,也能在给他提供一些训练样本的时候,Few Shot 的做的也很好。

7. 结论和总结

结论部分其实写的非常短,因为实在是该讨论的都讨论完了,该做的实验也全做完了,如果这里要把结果再秀一遍也写不下,所以说就只好简单的总结一下了。作者说他们的研究动机就是因为在 nlp 领域现在利用大规模的数据去预训练模型,而且用跟下游任务无关的训练方式,比如说完形填空。nlp 取得了非常革命性的成功,比如就是 OpenAI 他们自己的 gpt 的一系列工作,所以他们就想把 nlp 里的成功复制到其他的领域里去。

然后作者发现在视觉领域里用了这一套思路之后,确实效果也不错,其实是特别好。所以还顺带讨论一下有可能带来的社会影响力,接下来作者用一句话总结了他们的方法,就是在预训练阶段就有了对比学习,利用文本的提示去做 zero shot 的迁移学习。

在这里插入图片描述


8. CLIP预训练demo

这里使用 OpenAI 提供的 notebook 演示 CLIP 的效果,地址为:
https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb

安装 CLIP

pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
  • 1
  • 2

导入需要的库,PyTorch 版本在 1.7.1及以上:

import numpy as np
import torch
from pkg_resources import packaging

print("Torch version:", torch.__version__)

# Torch version: 1.12.1+cu113
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

下面是加载模型:

import clip

clip.available_models()

'''
['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

'''
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

下面是图片预处理,即 preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

然后是文本预处理,得到77个 tokens

clip.tokenize("Hello World!")

'''
tensor([[49406,  3306,  1002,   256, 49407,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], dtype=torch.int32)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

然后是输入图片和文本对,这里输入的图片文本对为8对:

import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}

original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
  
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

在这里插入图片描述

下面是得到图像和文本特征:

image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

计算余弦相似性,可以看到对角线上的相似性值最高:

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = len(descriptions)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

在这里插入图片描述
下面进行 zero-shot 分类,这里使用 CIFAR100 数据集,可能是类别比较少的原因,第一张图这里分错了:

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/450920
推荐阅读
相关标签
  

闽ICP备14008679号