CLIP stands for Contrastive Language-Image Pre-training, a pre-training approach built on contrasting text-image pairs.
CLIP is a multimodal model based on contrastive learning. Its training data consists of text-image pairs: an image together with its corresponding text description. Through contrastive learning, the model is expected to learn the matching relationship between texts and images.
Given a batch of n image-text pairs as input, the n matched image-text pairs are the positive samples (the blue diagonal entries of the output feature matrix in the figure below), while the remaining n²−n pairs are negative samples. Training then maximizes the similarity of the n positive pairs while minimizing the similarity of the n²−n negative pairs.
The similarity is the cosine similarity between the text features and the image features.
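As an illustration of this objective: for a batch of n L2-normalized image and text features, the n×n matrix of dot products is exactly the cosine-similarity matrix, and its diagonal entries correspond to the positive pairs. A minimal NumPy sketch with random features (the batch size and feature dimension here are arbitrary, not CLIP's actual values):

import numpy as np

n, d = 4, 8                                          # arbitrary batch size and feature dim
I_e = np.random.randn(n, d)                          # stand-in image features
T_e = np.random.randn(n, d)                          # stand-in text features
I_e /= np.linalg.norm(I_e, axis=1, keepdims=True)    # L2-normalize image features
T_e /= np.linalg.norm(T_e, axis=1, keepdims=True)    # L2-normalize text features

sim = I_e @ T_e.T                                    # [n, n] cosine similarities
positives = np.diag(sim)                             # the n matched image-text pairs
print(sim.shape, positives.shape)                    # (4, 4) (4,)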
CLIP consists of two models: a Text Encoder and an Image Encoder. The Text Encoder extracts text features and can be a text Transformer commonly used in NLP; the Image Encoder extracts image features and can be a standard CNN or a Vision Transformer.
# image_encoder - ResNet or Vision Transformer
# text_encoder  - CBOW or Text Transformer
# I[n, h, w, c] - input image dimensions
# T[n, l]       - input text dimensions, l is the sequence length
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t             - learned temperature parameter

# extract image features and text features separately
I_f = image_encoder(I)  # [n, d_i]
T_f = text_encoder(T)   # [n, d_t]

# linearly project both features to the same dimension d_e and L2-normalize them,
# so that both modalities live on the same scale
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities: [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
labels = np.arange(n)  # labels of the diagonal elements
loss_i = cross_entropy_loss(logits, labels, axis=0)  # image loss
loss_t = cross_entropy_loss(logits, labels, axis=1)  # text loss
loss = (loss_i + loss_t) / 2  # symmetric objective
The text is encoded as token ids in the form sot_token_id + encode(text) + eot_token_id, where sot_token_id is 49406, eot_token_id is 49407, and the tokens in between are the ids of the individual words.
The total sequence length is limited to 77: sequences shorter than 77 are zero-padded, and sequences longer than 77 are truncated to the first 77 tokens.
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
    if isinstance(texts, str):
        texts = [texts]

    context_length = context_length or self.context_length
    assert context_length, 'Please set a valid context length'

    # Encode each text into token ids: self.sot_token_id (49406) at the start,
    # self.eot_token_id (49407) at the end, and the word ids in between
    all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]

    # Keep at most 77 tokens: pad shorter texts with zeros, truncate longer ones,
    # and make sure every sequence still ends with self.eot_token_id
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = self.eot_token_id
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
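As a quick usage check of this tokenizer, here is a minimal sketch assuming the open_clip_torch package is installed; the example strings are arbitrary:

import open_clip

# Tokenizer whose __call__ is shown above; ViT-B-32 uses the default 77-token context
tokenizer = open_clip.get_tokenizer('ViT-B-32')
tokens = tokenizer(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)    # torch.Size([2, 77])
print(tokens[0])       # starts with 49406 (sot), ends the text with 49407 (eot), zero-padded to 77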
Preprocessing resizes and randomly crops the input image to (224, 224), converts it to RGB, and normalizes it:
{'image': Compose(
RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=bicubic, antialias=warn)
<function _convert_to_rgb at 0x7f8771c14790>
ToTensor()
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
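For reference, an equivalent pipeline can be written directly with torchvision. This is a sketch that mirrors the printed Compose above; the _convert_to_rgb helper is replaced by a plain lambda, and the image path is only a placeholder:

from torchvision import transforms
from PIL import Image

# Hand-written equivalent of the training-time preprocessing printed above
preprocess = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0), ratio=(0.75, 1.3333),
                                 interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Lambda(lambda img: img.convert("RGB")),   # replaces _convert_to_rgb
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

tensor = preprocess(Image.open("example.jpg"))   # placeholder path; result is [3, 224, 224]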
def encode_text(self, text, normalize: bool = False):
    cast_dtype = self.transformer.get_cast_dtype()

    # text: [b, 77] ---> [b, 77, 768], 77 tokens, each embedded into a 768-dim vector
    x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
    # learnable positional_embedding [77, 768], broadcast onto x: [b, 77, 768]
    x = x + self.positional_embedding.to(cast_dtype)
    # NLD -> LND  [b, 77, 768] ---> [77, b, 768]
    x = x.permute(1, 0, 2)
    # transformer encoder
    x = self.transformer(x, attn_mask=self.attn_mask)
    # LND -> NLD  [77, b, 768] ---> [b, 77, 768]
    x = x.permute(1, 0, 2)
    # LayerNorm, [b, 77, 768]
    x = self.ln_final(x)

    # x: feature of the end-of-text token (eot embedding) of each text: [batch, 768]
    # _: features of all tokens: [batch, 77, 768]
    x, _ = text_global_pool(x, text, self.text_pool_type)
    if self.text_projection is not None:
        if isinstance(self.text_projection, nn.Linear):
            x = self.text_projection(x)
        else:
            # [b, 768] @ [768, 768] ---> [b, 768], self.text_projection is a learnable parameter
            x = x @ self.text_projection

    # L2 normalization
    return F.normalize(x, dim=-1) if normalize else x


def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
    if pool_type == 'first':
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == 'last':
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == 'argmax':
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None
        # pooled: feature of the eot token of each text: [batch, 768]
        # tokens: features of all tokens: [b, 77, 768]
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x

    return pooled, tokens
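A note on the 'argmax' pooling above: since the eot token id (49407) is the largest id in the vocabulary, taking the argmax over the raw token ids gives the position of the end-of-text token in each padded sequence. A small illustration (the middle token ids are made up; only 49406/49407 matter here):

import torch

# Toy batch of padded token-id sequences: sot=49406, eot=49407, zero padding
text = torch.tensor([
    [49406, 320, 1125, 539, 49407, 0, 0],
    [49406, 320, 1929, 49407, 0, 0, 0],
])
eot_positions = text.argmax(dim=-1)   # the eot id is the largest value in each row
print(eot_positions)                  # tensor([4, 3])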
def forward(self, x: torch.Tensor):
    # [b, 3, 224, 224] ---> [b, 1024, 16, 16]
    x = self.conv1(x)
    # [b, 1024, 16, 16] ---> [b, 1024, 256]
    x = x.reshape(x.shape[0], x.shape[1], -1)
    # [b, 1024, 256] ---> [b, 256, 1024]
    x = x.permute(0, 2, 1)
    # prepend a class token to the grid tokens, x: [b, 256 + 1, 1024]
    x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
    # add the positional embedding, x: [b, 256 + 1, 1024]
    x = x + self.positional_embedding.to(x.dtype)
    # patch_dropout, x: [b, 256 + 1, 1024]
    x = self.patch_dropout(x)
    # LayerNorm, x: [b, 256 + 1, 1024]
    x = self.ln_pre(x)
    # NLD -> LND  [b, 256 + 1, 1024] ---> [256 + 1, b, 1024]
    x = x.permute(1, 0, 2)
    # transformer encoder
    x = self.transformer(x)
    # LND -> NLD  [256 + 1, b, 1024] ---> [b, 256 + 1, 1024]
    x = x.permute(1, 0, 2)

    if self.attn_pool is not None:
        if self.attn_pool_contrastive is not None:
            # This is untested, WIP pooling that should match paper
            x = self.ln_post(x)  # TBD LN first or separate one after each pool?
            tokens = self.attn_pool(x)
            if self.attn_pool_type == 'parallel':
                pooled = self.attn_pool_contrastive(x)
            else:
                assert self.attn_pool_type == 'cascade'
                pooled = self.attn_pool_contrastive(tokens)
        else:
            # this is the original OpenCLIP CoCa setup, does not match paper
            x = self.attn_pool(x)
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)
    elif self.final_ln_after_pool:
        pooled, tokens = self._global_pool(x)
        pooled = self.ln_post(pooled)
    else:
        # LayerNorm: [b, 256 + 1, 1024] ---> [b, 256 + 1, 1024]
        x = self.ln_post(x)
        # pooled: class token [b, 1024]; tokens: image tokens [b, 256, 1024]
        pooled, tokens = self._global_pool(x)

    # pooled: [b, 1024] @ [1024, 768] ---> [b, 768]
    if self.proj is not None:
        pooled = pooled @ self.proj

    return pooled
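A quick way to sanity-check these shapes is a forward pass through an OpenCLIP model. This sketch assumes open_clip_torch is installed and that the 'openai' pretrained weights can be downloaded; ViT-L-14 is the variant whose image tower matches the 1024-wide, 256-patch dimensions annotated above:

import torch
import open_clip

# ViT-L-14: 14x14 patches over a 224x224 image -> 256 grid tokens of width 1024,
# projected into a 768-dim joint embedding space shared with the text encoder
model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
tokenizer = open_clip.get_tokenizer('ViT-L-14')

with torch.no_grad():
    image_features = model.encode_image(torch.randn(2, 3, 224, 224))   # [2, 768]
    text_features = model.encode_text(tokenizer(["a cat", "a dog"]))   # [2, 768]
print(image_features.shape, text_features.shape)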
def get_logits(self, image_features, text_features, logit_scale):
    # scaled cosine similarity of each image against every text
    logits_per_image = logit_scale * image_features @ text_features.T
    # scaled cosine similarity of each text against every image
    logits_per_text = logit_scale * text_features @ image_features.T
    return logits_per_image, logits_per_text

def forward(self, image_features, text_features, logit_scale, output_dict=False):
    device = image_features.device
    # with N image-text pairs: logits_per_image: [N, N], logits_per_text: [N, N]
    logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
    # with N image-text pairs: labels = [0, 1, 2, ..., N-1]
    labels = self.get_ground_truth(device, logits_per_image.shape[0])

    # total loss = (image-side loss + text-side loss) / 2
    total_loss = (
        F.cross_entropy(logits_per_image, labels) +  # image-side loss
        F.cross_entropy(logits_per_text, labels)     # text-side loss
    ) / 2

    return {"contrastive_loss": total_loss} if output_dict else total_loss
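A minimal usage sketch of this loss on random, L2-normalized features. It assumes open_clip_torch is installed and that ClipLoss is importable from open_clip.loss; the batch size and embedding dimension are arbitrary:

import torch
import torch.nn.functional as F
from open_clip.loss import ClipLoss

n, d = 8, 512                                    # arbitrary batch size and embedding dim
image_features = F.normalize(torch.randn(n, d), dim=-1)
text_features = F.normalize(torch.randn(n, d), dim=-1)
logit_scale = torch.tensor(1 / 0.07)             # exponentiated temperature at CLIP's initialization

loss = ClipLoss()(image_features, text_features, logit_scale)
print(loss)   # a scalar, roughly ln(n) for unrelated random features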
A traditional vision model has to be fine-tuned on a new dataset, whereas CLIP can perform zero-shot image classification directly, i.e., it can classify on a specific downstream task without any training data.
Steps:
# First build a text description for every class
labels = ["dog", "cat", "bird", "person", "mushroom", "cup"]
text_descriptions = [f"A photo of a {label}" for label in labels]
text_tokens = clip.tokenize(text_descriptions).cuda()

# Extract text features
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Load the images
original_images = []
images = []
texts = []
for label in labels:
    image_file = os.path.join("images", label + ".jpg")
    name = os.path.basename(image_file).split('.')[0]

    image = Image.open(image_file).convert("RGB")
    original_images.append(image)
    images.append(preprocess(image))
    texts.append(name)

image_input = torch.tensor(np.stack(images)).cuda()

# Extract image features
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)

# Unscaled cosine similarity
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

# Note that the similarity must be scaled here. Cosine similarity lies in [-1, 1], which is not
# well suited as cross-entropy logits: logits normally have no upper bound, and a larger range
# separates the classes better. That is why training learns a temperature parameter to scale
# the similarities up.
logit_scale = np.exp(model.logit_scale.data.item())
text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
import numpy as np
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("./run.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["plane", "dog", "human", "runner"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_image, logits_text = model(image, text)
    probs = logits_image.softmax(dim=-1).cpu().numpy()

# The text most similar to the image is taken as its class
print("Label probs:", probs)
# Label probs: [[1.508e-03 5.999e-04 1.063e-02 9.873e-01]]
The core of prompt learning is to construct suitable prompts so that a pre-trained model can be applied directly to downstream tasks.
At inference time, using only the class label as the text description does not work well, for two reasons:
Words are ambiguous.
If we directly use the class label as the text description, many texts are just a single word with no surrounding context, which does not describe the image content well.
For example, in object detection there is a class remote (remote control); but if the bare word is fed to the text encoder, the model may well interpret it as "distant".
The same word can also mean different things in different datasets. For example, in the Oxford-IIIT Pets dataset, boxer refers to a dog breed, while in other datasets it refers to a boxing athlete.
Therefore, during pre-training CLIP describes the image content with a full sentence, such as "A photo of a {label}". Here the label can only be a noun, which removes the ambiguity to some extent.
It also keeps inference consistent with pre-training, eliminating the distribution gap. A sketch of prompt ensembling built on this idea follows below.
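Going one step further, the CLIP authors also report gains from ensembling several prompt templates per class: embed each template, average the normalized text features, and use the averaged vector as the class embedding. A minimal sketch of that idea; the templates here are only a small illustrative subset of the paper's list, and the class names are arbitrary:

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# A few illustrative templates; the paper uses a much larger set
templates = ["a photo of a {}.", "a blurry photo of a {}.", "a drawing of a {}."]
classnames = ["dog", "cat", "remote control"]

with torch.no_grad():
    class_embeddings = []
    for name in classnames:
        tokens = clip.tokenize([t.format(name) for t in templates]).to(device)
        emb = model.encode_text(tokens).float()
        emb /= emb.norm(dim=-1, keepdim=True)   # normalize each prompt embedding
        emb = emb.mean(dim=0)                   # average over the templates
        emb /= emb.norm()                       # re-normalize the class embedding
        class_embeddings.append(emb)
    zero_shot_classifier = torch.stack(class_embeddings, dim=1)   # [embed_dim, num_classes]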