赞
踩
论文地址:[2103.00020] Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)
CLIP的好处为,无监督学习,迁移性强(zero-shot transfer),并且可预测的物体类别无限制。CLIP是用文本作为监督信号来训练的可迁移的视觉模型。CLIP全称为Contrastive Language-Image Pre-training,即基于文本-图像对的预训练方法或模型。CLIP有两个模型分别提取图像和文本的特征:Text Encoder和Image Encoder,其中Text Encoder是使用的模型是text transformer模型,而Image Encoder使用的是CNN模型或vision transformer模型。 CLIP从互联网上搜集了4亿个文本-图像对(作者将其命名为WebImageText,简称为WIT),当做预训练部分的训练集。在下游任务中,可通过文本匹配图像,从而实现zero-shot transfer,并将其放在30个不同的数据集上进行测试,其表现出了强大的迁移能力。
在这里对提取的文本特征和图像特征进行对比学习,先分别对N对图像-文本对进行encoder,得到的两个特征进行线性映射到相同的维度,并进行L2归一化,然后计算文本特征和图像特征的余弦相似性(cosine similarity),最终得到的矩阵,矩阵的对角线元素为正样本(共有N个),其余非对角线元素为负样本(有N2-N个),实现该部分的伪代码如下:
# image_encoder - ResNet or Vision Transformer # text_encoder - CBOW or Text Transformer # I[n, h, w, c] - minibatch of aligned images # T[n, l] - minibatch of aligned texts # W_i[d_i, d_e] - learned proj of image to embed # W_t[d_t, d_e] - learned proj of text to embed # t - learned temperature parameter 可训练的参数 I_f = image_encoder(I) #维度为[n, d_i] 提取得到图像特征 T_f = text_encoder(T) #维度为[n, d_t] 提取得到文本特征 # 对两个特征进行线性投射,得到相同维度的特征,并进行l2归一化 I_e = l2_normalize(np.dot(I_f, W_i), axis=1) #最终得到的维度为[n,d_e] T_e = l2_normalize(np.dot(T_f, W_t), axis=1) # 计算缩放的余弦相似度:[n, n] logits = np.dot(I_e, T_e.T) * np.exp(t) # 对称的对比学习损失:等价于N个类别的cross_entropy_loss labels = np.arange(n) # 对角线元素的labels 0,1,2,3...n-1 loss_i = cross_entropy_loss(logits, labels, axis=0)#第一维的交叉熵损失函数(横向来看) loss_t = cross_entropy_loss(logits, labels, axis=1)#第一维的交叉熵损失函数(纵向来看) loss = (loss_i + loss_t)/2
作者在Text Encoder上选择了包含63M参数的text transformer,在Image Encoder上选择了5个不同大小的ResNet模型(ResNet50,ResNet101,RN50x4,RN50x16和RNx64)和3个不同大小的ViT(ViT-B/32,ViT-B/16,ViT-L/16)模型。设置的epochs大小为32,使用的是AdamW优化器,并且在训练过程中选择了一个比较大的batch_size:32768,由于数据量较大(4亿),所以训练消耗的资源较多,最终作者选择使用ViT-L/14@336的模型效果最好。
zero-shot进行图像分类,即不需要任何的训练数据,就能直接在下游任务中实现分类任务。CV中目前常用的模型是先在数据集上进行预训练,然后进行微调,而训练好的CLIP可直接进行分类任务。
执行zero-shot分类任务的具体操作:
1.根据任务的n个分类标签,得到n个描述该标签的文本,如:“dog” -> “a photo of a dog”,然后将这些文本送入到text encoder,得到n个class text embedding。
2.将要预测的图片送入到image encoder中,得到image embedding。
3.计算图像特征和n个文本特征的余弦相似度,经过softmax,就可以得到每个类别的概率,概率最大的即预测得到的类别。
分类预测部分的代码如下:
import torch import clip from PIL import Image device = "cuda" if torch.cuda.is_available() else "cpu" model, preprocess = clip.load("ViT-B/32", device=device) image = preprocess(Image.open("CLIP3.png")).unsqueeze(0).to(device)#unsqueeze(0)表示在0的位置上增加一维,值为1 text = clip.tokenize(["a diagram", "a dog", "a cat","a glasses","a duck"]).to(device) #将标记写上 with torch.no_grad(): image_features = model.encode_image(image) #在这里得到图像的feature text_features = model.encode_text(text) #在这得到text的feature logits_per_image, logits_per_text = model(image, text) #得到余弦分数 probs = logits_per_image.softmax(dim=-1).cpu().numpy() #在最后一层进行一步softmax print("Label probs:", probs) # 输出图片相对于文本的概率值,概率最大的是预测类别
zero-shot中一个比较重要的地方是prompt(文本提示)的生成。CLIP中提出了两个方法分别为Prompt engineering和ensembling。Prompt模板为"A photo of a {label}.“,而作者考虑对原context针对不同的分类任务,设定不同的context,如果是宠物的那么Prompt模板为"A photo of a {label}, a type of pet”.并且对于一个数据集还可以做一些ensemble,比如"A photo of a big {label}.“和"A photo of a small {label}.”。(在之后提出CoOp,可以直接训练产生Prompt,得到了更好的效果)。
CLIP打破了原有的CV领域模型的设计思路,其一次训练得到的参数可应用到所有的分类任务当中,其使用了文本和图片的对应关系,在30多个数据集上效果相当于CNN网络中的Resnet50,并且迁移性强,鲁棒性好。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。