[Computer Vision] Using a Python notebook to download and run CLIP models, compute image-text similarity, and perform zero-shot image classification

Project repository:

https://github.com/biluko/Paper_Codes_for_fun/tree/master/CLIP

1. The CLIP model

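CLIP (Contrastive Language-Image Pre-training) is a deep learning model developed by OpenAI that learns joint representations of images and text. Its goal is to embed images and text into a shared vector space in which matching images and texts lie close together, while unrelated images and texts lie far apart.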

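What distinguishes CLIP is that it learns general-purpose features through contrastive learning between images and text. During training, for a batch of N image-text pairs, CLIP maximizes the cosine similarity of the N correct pairs and minimizes it for the N² − N incorrect pairings, using a symmetric cross-entropy loss over the similarity matrix. Features learned this way transfer to many downstream tasks, such as zero-shot image classification and image-text retrieval, and CLIP is also used as a component (for example as a text encoder or as guidance) inside image-generation systems.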
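A minimal pure-Python sketch of this symmetric contrastive objective, assuming a toy batch in which image i and text i form the matching pair. The helper names are hypothetical and this illustrates the idea only; it is not OpenAI's training code. The temperature of 0.07 is the initial value reported in the CLIP paper, whereas the real model learns it during training.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def mean_cross_entropy(logits):
    """Mean cross-entropy where the correct class for row i is column i (the diagonal)."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def clip_style_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric contrastive loss over a batch of N matching (image, text) pairs."""
    images = [l2_normalize(v) for v in image_embs]
    texts = [l2_normalize(v) for v in text_embs]
    # logits[i][j]: scaled cosine similarity between image i and text j
    logits = [[sum(a * b for a, b in zip(img, txt)) / temperature for txt in texts]
              for img in images]
    # Transpose so that each row scores one text against all images
    logits_t = [list(col) for col in zip(*logits)]
    # Average the image->text and text->image cross-entropies
    return (mean_cross_entropy(logits) + mean_cross_entropy(logits_t)) / 2


matched = clip_style_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = clip_style_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(f"matched pairs: {matched:.4f}, swapped pairs: {swapped:.4f}")
```

Matching pairs give a loss close to zero, while swapping the captions turns every diagonal entry into a mismatch and the loss rises to roughly 1/0.07 ≈ 14.3.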

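CLIP has a wide range of applications. Because images and text are mapped into the same vector space, CLIP supports cross-modal retrieval and matching: for example, you can encode an image and a text query describing its content, compute their similarity in the shared space, and retrieve the most similar images or texts. This opens up applications such as image search, product recommendation, and visual question answering.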

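A key advantage of CLIP is that it does not rely on manually labeled class annotations: it was trained on about 400 million (image, text) pairs collected from the internet, with the natural-language text acting as the supervision signal. This gives it good generalization across domains, although the released models were trained primarily on English text. Note also that CLIP itself is not a generative model: it encodes and scores image-text pairs, and producing images or captions requires combining it with a separate generative model.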

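In short, CLIP is a powerful deep learning model that embeds images and text into a shared vector space and enables cross-modal retrieval and matching. It has broad application prospects in image and text processing, multimodal applications, and transfer learning.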

2. Setup

Install CLIP and its dependencies, and switch the notebook runtime to a GPU:

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

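import numpy as np
import torch
from packaging import version  # pkg_resources is deprecated; import packaging directly

print("Torch version:", torch.__version__)
# The CLIP README asks for PyTorch 1.7.1 or later
if version.parse(torch.__version__) < version.parse("1.7.1"):
    print("PyTorch version 1.7.1 or higher is recommended")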

3. Loading the model

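List the available pretrained models; each name identifies the image-encoder architecture (ResNet or Vision Transformer variants):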

import clip
clip.available_models()

Load the model together with its image preprocessing pipeline:

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

4. Inspecting the image preprocessor

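preprocess resizes the shorter side of the image to 224 pixels (bicubic interpolation), center-crops it to 224 × 224, converts it to RGB and then to a tensor, and normalizes each channel with CLIP's mean and standard deviation: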
preprocess
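To make these steps concrete, here is a pure-Python sketch of the size arithmetic and the per-channel normalization, assuming torchvision's conventions for Resize with an integer size (shorter side scaled to 224, longer side truncated to an integer) and for CenterCrop (offsets rounded). The helper names are hypothetical, and the sketch only computes sizes and values; it does not process real pixels. The mean and standard deviation are the values in CLIP's Normalize transform.

```python
# CLIP's per-channel normalization constants (RGB order)
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def resize_shorter_side(width, height, size=224):
    """New (width, height) after scaling the shorter side to `size`, keeping aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, size=224):
    """(left, top, right, bottom) of a size x size crop centred in a width x height image."""
    left = int(round((width - size) / 2.0))
    top = int(round((height - size) / 2.0))
    return left, top, left + size, top + size


def normalize_pixel(rgb):
    """Map 0-255 RGB values to the normalized floats the image encoder expects."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, CLIP_MEAN, CLIP_STD))


w, h = resize_shorter_side(640, 480)
print((w, h), center_crop_box(w, h))
print(normalize_pixel((0, 0, 0)))
```

For a 640 × 480 photo, the image is first resized to 298 × 224, and the central 224 × 224 window starting at x = 37 is kept.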

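5. Tokenizing text

clip.tokenize turns a string (or a list of strings) into a tensor of token IDs with one row per string, each row zero-padded to the model's context length of 77 tokens: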

clip.tokenize("Hello World!")
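The byte-pair encoding itself is done by CLIP's tokenizer; the sketch below only shows the wrapping and padding that turn BPE IDs into a fixed-size row. It assumes 49406 and 49407 as the start and end token IDs and 77 as the context length, the values used by the released models, and its truncate flag is meant to mirror the option of the same name in clip.tokenize as we understand it. pad_to_context is a hypothetical helper, and the BPE IDs in the example are made up.

```python
SOT_TOKEN = 49406  # <|startoftext|> in CLIP's BPE vocabulary
EOT_TOKEN = 49407  # <|endoftext|>
CONTEXT_LENGTH = 77


def pad_to_context(bpe_ids, context_length=CONTEXT_LENGTH, truncate=False):
    """Wrap BPE ids with start/end tokens and zero-pad to a fixed length."""
    tokens = [SOT_TOKEN] + list(bpe_ids) + [EOT_TOKEN]
    if len(tokens) > context_length:
        if not truncate:
            raise ValueError("input is too long for the context length")
        tokens = tokens[:context_length]
        tokens[-1] = EOT_TOKEN  # keep an end token as the last entry
    return tokens + [0] * (context_length - len(tokens))


row = pad_to_context([10, 20, 30])  # hypothetical BPE ids
print(len(row), row[:6])
```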

6. Loading sample images and their descriptions, and visualizing them

import os 
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

original_images=[]
images=[]
texts=[]
plt.figure(figsize=(16,5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
  
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()

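8. Computing cosine similarity

Encode the images and their descriptions, then L2-normalize both feature matrices so that the dot product of a text feature and an image feature equals their cosine similarity:

image_input = torch.stack(images).cuda()  # [num_images, 3, 224, 224]
# Prefix each description with "This is " to form the text prompt
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)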

similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = len(descriptions)
plt.figure(figsize = (20, 14))
plt.imshow(similarity, vmin = 0.1, vmax = 0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize = 18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent = (i - 0.5, i + 0.5, -1.6, -0.6), origin = "lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha = "center", va = "center", size = 12)

for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size = 20)
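Because both feature matrices are L2-normalized first, the product text_features @ image_features.T holds cosine similarities, which lie in [-1, 1] and do not depend on vector length. A pure-Python sketch with made-up 3-dimensional vectors (ViT-B/32 features are 512-dimensional); the helper names are hypothetical:

```python
import math


def normalize_rows(matrix):
    """L2-normalize each row vector."""
    out = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        out.append([v / norm for v in row])
    return out


def similarity_matrix(text_feats, image_feats):
    """similarity[t][i] = cosine similarity between text t and image i."""
    texts = normalize_rows(text_feats)
    images = normalize_rows(image_feats)
    return [[sum(a * b for a, b in zip(t, img)) for img in images] for t in texts]


# Toy "features": scaling a vector does not change its cosine similarity
text_feats = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
image_feats = [[0.0, 5.0, 0.0], [3.0, 0.0, 0.0], [1.0, 1.0, 0.0]]
sim = similarity_matrix(text_feats, image_feats)
# For each text, the index of the image with the highest similarity
best_image = [max(range(len(row)), key=row.__getitem__) for row in sim]
print(sim)
print(best_image)
```

Taking the argmax of each row gives, for every description, the image it matches best, which is the text-to-image retrieval described in Section 1.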

9. Zero-shot image classification

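The CIFAR-100 dataset is used here only for its 100 class names: each of the images loaded above is scored against a text prompt for every class, and a softmax over the scaled similarities turns the scores into class probabilities: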

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform = preprocess, download = True)

Wrap each class name in a prompt template before tokenizing:

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

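Encode the prompts, compute a probability distribution over the 100 classes for each image, and keep the five most likely classes: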

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim = -1, keepdim = True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
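Cosine similarities between CLIP features usually fall in a narrow range (the similarity plot above uses vmin=0.1 and vmax=0.3), so they are multiplied by 100 before the softmax; without this scaling the probabilities would be almost uniform. As we understand it, 100 matches the learned logit scale of the released models (model.logit_scale.exp()). A pure-Python sketch with made-up similarity values and a hypothetical helper name:

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_top_k(cosine_sims, labels, k=5, logit_scale=100.0):
    """Turn one image's cosine similarities to each label prompt into top-k (label, prob)."""
    probs = softmax([logit_scale * s for s in cosine_sims])
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


labels = ["apple", "bicycle", "cat"]  # hypothetical candidate labels
sims = [0.21, 0.19, 0.30]             # hypothetical cosine similarities
print(zero_shot_top_k(sims, labels, k=2))
print(zero_shot_top_k(sims, labels, k=2, logit_scale=1.0))
```

With the factor of 100 the best label takes almost all of the probability mass; with a scale of 1 the three labels end up close to one third each.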

Visualize the results:

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    # Left panel: the input image
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    # Right panel: the top-5 CIFAR-100 classes and their probabilities
    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()

10. A function for classifying a single image

Given an image path and a list of candidate labels, classify the image and plot the most likely labels:

def show_result(image, probs, labels, label_name):
    plt.figure()
    # Left: the input image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.axis("off")

    # Right: horizontal bar chart of the top-k label probabilities
    plt.subplot(1, 2, 2)
    y = np.arange(probs.shape[-1])
    plt.grid()
    plt.barh(y, probs[0])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [label_name[index] for index in labels[0].numpy()])
    plt.xlabel("probability")

    plt.subplots_adjust(wspace=0.5)
    plt.show()

def clip_classifier(image_path, choice_label, top_k=5):
    # top_k must not exceed the number of candidate labels
    if top_k > len(choice_label):
        raise ValueError("top_k is larger than the number of candidate labels")

    # Load the image
    image = Image.open(image_path).convert("RGB")

    # Build the inputs: one prompt per label, plus the preprocessed image
    text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
    text_tokens = clip.tokenize(text_descriptions).cuda()
    image_input = preprocess(image).unsqueeze(0).cuda()  # add a batch dimension: [1, 3, 224, 224]

    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
        text_features = model.encode_text(text_tokens).float()

    # L2-normalize so that the dot product is the cosine similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Similarity scores -> probabilities over the candidate labels
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(top_k, dim=-1)
    show_result(image, top_probs, top_labels, choice_label)

11. Testing the function

clip_classifier('R.jpg',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])

clip_classifier('Holmes.jpg',['Holmes','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])

12. A function for classifying multiple images

def clip_classifier_m(image_dir, choice_label, top_k=5):
    # image_dir must be a directory of images
    if not os.path.isdir(image_dir):
        raise ValueError(image_dir + " should be a directory of images")

    # top_k must not exceed the number of candidate labels
    if top_k > len(choice_label):
        raise ValueError("top_k is larger than the number of candidate labels")

    # Load every .png/.jpg image in the directory (sorted for a stable order)
    original_images = []
    images = []
    for filename in sorted(os.listdir(image_dir)):
        if not filename.lower().endswith((".png", ".jpg")):
            continue
        image = Image.open(os.path.join(image_dir, filename)).convert("RGB")
        original_images.append(image)
        images.append(preprocess(image))

    if not images:
        raise ValueError("no .png or .jpg images found in " + image_dir)

    # Build the inputs: one prompt per label, plus a batch of preprocessed images
    text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
    text_tokens = clip.tokenize(text_descriptions).cuda()
    image_input = torch.stack(images).cuda()  # [num_images, 3, 224, 224]

    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
        text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Similarity scores -> per-image probabilities over the candidate labels
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(top_k, dim=-1)
    show_result_m(original_images, top_probs, top_labels, choice_label)

def show_result_m(images, probs, labels, label_name):
    length = len(images)
    # Two panels per image in a four-column grid means two images per row;
    # round up so that an odd number of images still fits
    num_row = (length + 1) // 2

    plt.figure(figsize=(16, 16))

    for i, image in enumerate(images):
        # Left panel: the input image
        plt.subplot(num_row, 4, 2 * i + 1)
        plt.imshow(image)
        plt.axis("off")

        # Right panel: the top-k labels and their probabilities
        plt.subplot(num_row, 4, 2 * i + 2)
        y = np.arange(probs.shape[-1])
        plt.grid()
        plt.barh(y, probs[i])
        plt.gca().invert_yaxis()
        plt.gca().set_axisbelow(True)
        plt.yticks(y, [label_name[index] for index in labels[i].numpy()])
        plt.xlabel("probability")

    plt.subplots_adjust(wspace=1)
    plt.show()

clip_classifier_m('img',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','Holmes'])
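show_result_m places two panels per image (the image and its bar chart) in a four-column grid, so each row holds two images and the row count has to be rounded up when the number of images is odd; otherwise plt.subplot receives an index larger than rows × columns and raises an error. A small pure-Python sketch of that index arithmetic, using a hypothetical helper name:

```python
def grid_positions(num_images, columns=4):
    """Rows needed and 1-based subplot indices (image, bar chart) for each image."""
    panels_per_image = 2
    images_per_row = columns // panels_per_image
    num_rows = -(-num_images // images_per_row)  # ceiling division
    positions = [(2 * i + 1, 2 * i + 2) for i in range(num_images)]
    return num_rows, positions


rows, positions = grid_positions(5)
print(rows, positions[-1])
```

With five images, floor division would give only two rows (eight slots) while the last bar chart needs slot 10; rounding up gives three rows and twelve slots.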

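Disclaimer: this article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author, and the site assumes no legal responsibility for it. If you find infringing content, please contact the site. When reposting, please cite the source: https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/439265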