赞
踩
CLIP (Contrastive Language-Image Pre-training)是一种利用对比学习进行语言、图像预训练的模型,它是OpenAI的一项工作,当时他们的工作还是开源的。对比学习思想是:将正样本拉近,将负样本拉远的这一类算法
如何利用预训练得到的clip模型来做图像分类任务呢?这里就和传图的CV模型是不一样的。
对于CV来说通常是拿到预训练的权重,针对下游任务
通过接一个任务Head来
训练任务头或者Fine-tune整个网络来得到预测输出。
CLIP则设计了一种新的方式,它首先从标签label层面入手, 对数据集可能存在的类名,使用prompt模板,比如对于飞机plane这个类,我们将飞机这个类转换为一个文本如A photo of a plane
,同理将car类,将它转换为文本A photo of a car
然后将所有类别得到的文本信息,输入预训练好的Text Encoder
, 这样对于数据集中的每一个类别都会生成一个向量。这个向量就能代表这一个类别的通用的Embedding。
最后将输入的一张image,它经过Image Encoder
会得到对应的图像Embedding,和每个类别所代表的文本Embedding来计算相似度,找到相似度最大的文本标签类别,就认为该图像就是该类别。
Clip这种图像分类的Idea是一种非常创新的想法,可以从输入和label之间同时经过网络来约束他们之间的距离,而不是只有输入到预测目标这一单向的路径。这一idea也是Lecun大佬提出world model的概念。
总结
clip 这样设计的模型它的最大的能力可以让图像和Text各自的Embedding, 能够在共享的空间下,通过输入一张图像和该图像的描述信息会得到两个相同的Embedding。这样,比如在一个语言大模型中,如果我们想加入图像信息,可以将图像通过编码映射到文本的Embedding中,因为他们之间的Embedding已经对齐了。
这样在训练好clip模型之后,可以直接将Text Encoder
和Image Encoder
拿过来使用,在构建自己的大模型的时候,就不在需要重新训练这两个Encoder了。在后续介绍的多模态大模型比如Flamingo和LLaVA,他们对于图像的编码就直接使用CLIP中训练好的Image Encoder
CLIP 模型结构其实非常简单,针对两个模态Text和Image,分别通过各自Encoder编码(Text Encoder
和Image Encoder
), 得到经过编码后的高纬向量Embedding
, 然后计算相似度
,最终使得匹配的image 和Text对的相似度尽可能大,其他没有匹配到的Embedding相似度尽可能小。
class ViT(nn.Module): def __init__(self,output_dim): # 使用来自timm的VIT模型 self.vit = timm.create_model('vit_small_patch16_224',pretrained=True,num_classes =output_dim) def forward(self,x): return self.vit(x) class TextEncoder(nn.Module): def __init__(self): super(TextEncoder,self).__init__() BERT_LOCAL_PATH ='./bert-base-uncased' self.model = BertModel.from_pretrained(BERT_LOCAL_PATH) self.tokenizer = BertTokenizer.from_pretrained(BERT_LOCAL_PATH) def forward(self,texts): encoded_input = self.tokenizer(texts, return_tensors ='pt',padding= True,truncation =True) outputs = self.model(**encoded_input) return outputs.last_hidden_state[:,0,:]
Image Encoder
: 使用的是ViT模型,为了方便起见使用timm工具来搭建,使用的是vit_small_patch16_224
版本的模型,也可以使用vit-base版本的。得到编码后的embedding维度为output_dimTextEncoder
: 文本编码器使用的是Bert模型,利用transformers
工具进行构建。由于是文本,除了模型本身之外,还需要实例化一个tokenizer
tokenizer
操作,然后将tokenizer后的encoded_input, 输入到TextEncoder
中class CLIP(nn.Module):
def __init__(self, image_output_dim, text_output_dim):
super(CLIP, self).__init__()
self.image_encoder = ViT(image_output_dim)
self.text_encoder = TextEncoder()
# 因为图像和文本emb可能维度不同(图像512,文本768),所以需要对图像和文本的emb再经过一层以将维度持平
self.W_i = nn.Parameter(torch.randn(image_output_dim, text_output_dim))
self.W_t = nn.Parameter(torch.randn(768, text_output_dim)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。