当前位置:   article > 正文

计算机视觉中的语义相似性搜索

semantic similarity

caba2a6ef2fb8b623d2bfb8213731c06.jpeg

自人类文明诞生以来,高达90%的数据是在过去两年中产生的!

随着社交媒体和物联网(IoT)等数字技术以及5G等速度更快的无线通信技术的普及,数据创建速度不断提高。然而,创建的大多数新数据都是“非结构化的”,例如文本、图像、音频和视频[来源]。

非结构化数据之所以得名,是因为它不像由行和列组成的表那样具有固有的结构。相反,非结构化数据包含多种可能格式之一的信息。例如,电子商务图像、客户评论、社交媒体帖子、监控视频、语音命令等都是不遵循传统表格数据格式的丰富信息源。

人工智能(AI)和机器学习(ML)技术的最新进展创造了一种通过使用“嵌入”以可伸缩的方式从非结构化数据源中提取有用信息的方法,这种方法将非结构化数据转换为嵌入式数据并将其存储在向量数据库(如Milvus)上,实现了几年前难以想象的几个优秀应用程序。

一些示例应用程序包括视觉图像搜索、语义文本搜索、推荐引擎、打击误报、药物发现等!

在这篇文章中,我们将讨论以下内容。

  1. 什么是嵌入

  2. 使用Kaggle API加载一些数据

  3. 展平原始像素值

  4. 基于分类目标的预训练卷积神经网络

  5. 基于度量学习目标的预训练卷积神经网络

  6. 基于度量学习目标的预训练图文多模态神经网络

  7. 结论

在这篇文章中,我们将使用Kaggle提供的Digikala产品颜色分类数据集来构建一个简单的基于电子商务图像的类似产品搜索服务。该数据集是根据GPL 2许可证授权的。

什么是嵌入

我们的计算机无法像人类那样直接理解图像或文本。然而,计算机很擅长理解数字!

因此,为了帮助我们的计算机理解图像或文本的内容,我们需要将它们转换为数字表示。例如,如果我们考虑图像用例,我们本质上是将图像的上下文和场景“编码”或“嵌入”到向量形式的一系列数字中。

“嵌入”向量是图像数据的数字表示,以便计算机能够理解图像的上下文和场景。

fad1aca16d799465bb36a9469f5a2007.png

几个Python库允许我们从图像生成嵌入。通常,我们可以将这些库分为两大类。

  1. 提供带有预训练模型的现成API的库:对于许多涉及日常对象图像的真实问题,我们可能不需要训练任何模型。相反,我们可以依赖世界各地研究人员开放的许多高质量的预训练模型。研究人员已经训练这些模型来识别和聚类ImageNet数据集中的对象。

  2. 允许我们训练和微调模型的库:顾名思义,对于这些模型,我们可以从头开始提供数据和训练模型,或者专门针对我们的用例微调预训练的模型。如果预训练好的模型还没有为我们的问题域提供良好的嵌入,那么我们只需要沿着这条路走下去。

让我们看看这篇文章中的几个库。

但首先,让我们加载一些图像数据,以定性地评估相似性搜索应用程序中的嵌入。

加载一些数据

我们首先需要加载一些图像数据来测试各种嵌入策略。在这篇文章中,我们使用Kaggle提供的Digikala产品颜色分类数据集。该数据集包含超过6K个电子商务产品的图像,非常适合测试基于电子商务图像的类似产品搜索服务。

步骤1:设置Kaggle环境
  1. 在kaggle上创建帐户

  2. 单击你的个人资料图片,然后从下拉菜单中单击“帐户”。

  3. 向下滚动至“API”部分。

  4. 单击下图所示的“创建新API令牌”按钮,下载一个新的令牌,作为一个JSON文件,其中包含用户名和API密钥。

  5. 如果你使用的是macOS或Linux,请将JSON文件复制到~/.kaggle/目录。在Windows系统上,转到根目录,然后转到.kaggle文件夹,并将下载的文件复制到此文件夹。如果.kaggle目录不存在,请创建它并将JSON文件复制到其中。

77737970569a9888a88ae54df1c3b9dc.png
步骤2:从Kaggle下载数据

我们将使用Anaconda来管理此项目的虚拟环境。你可以从这里安装Anaconda。下载并安装Anaconda后,我们可以建立一个名为semantic_similarity的新环境,安装必要的库,如kaggle和pandas,并从kaggle下载整个数据集。如果不想使用Anaconda,还可以使用python的venv为该项目创建和管理虚拟环境。

  1. # Create a directory for notebooks and another to download data 
  2. mkdir -p semantic_similarity/notebooks semantic_similarity/data/cv 
  3. # CD into the data directory
  4. cd semantic_similarity/data/cv
  5. # Create and activate a conda environment
  6. conda create -n semantic_similarity python=3.8
  7. conda activate semantic_similarity
  8. ## Create Virtual Environment using venv if not using conda
  9. # python -m venv semantic_similarity
  10. # source semantic_similarity/bin/activate
  11. # Pip install the necessary libraries
  12. pip install jupyterlab kaggle pandas matplotlib scikit-learn tqdm ipywidgets
  13. # Download data using the kaggle API
  14. kaggle datasets download -d masouduut94/digikala-color-classification
  15. # Unzip the data into a fashion/ directory
  16. unzip digikala-color-classification.zip -d ./fashion
  17. # Delete the Zip file
  18. rm digikala-color-classification.zip

该数据包含超过6K张各种电子商务产品的图像。我们可以在下图中看到数据集中的一些示例图像。正如你所注意到的,该数据集包括各种时尚产品,如男装、女装、包、珠宝、手表等。

2e9b6ed5039315f88d3dd673920a7c1a.png
步骤3:将所有图像从每个文件夹移动到父文件夹

让我们在semantic_similarity/notebooks目录中创建一个新的jupyter笔记本,以测试各种嵌入策略。首先,让我们导入必要的库。

  1. from matplotlib import pyplot as plt
  2. import numpy as np
  3. import os
  4. import pandas as pd
  5. from PIL import Image
  6. from random import randint
  7. import shutil
  8. from sklearn.metrics.pairwise import cosine_similarity
  9. import sys 
  10. from tqdm import tqdm
  11. tqdm.pandas()

下载的图像位于多个子文件夹中。让我们将它们全部移动到主父目录中,以便轻松获取它们的所有路径。

  1. def move_to_root_folder(root_path, cur_path):
  2.     
  3.     # Code from https://stackoverflow.com/questions/8428954/move-child-folder-contents-to-parent-folder-in-python
  4.     for filename in os.listdir(cur_path):
  5.         
  6.         if os.path.isfile(os.path.join(cur_path, filename)):
  7.             shutil.move(os.path.join(cur_path, filename), os.path.join(root_path, filename))
  8.             
  9.         elif os.path.isdir(os.path.join(cur_path, filename)):
  10.             move_to_root_folder(root_path, os.path.join(cur_path, filename))
  11.             
  12.         else:
  13.             sys.exit("Should never reach here.")
  14.     # remove empty folders
  15.     if cur_path != root_path:
  16.         os.rmdir(cur_path)
  17.         
  18. move_to_root_folder(root_path='../data/cv/fashion', cur_path='../data/cv/fashion')
步骤4:将图像路径加载到pandas数据帧中

接下来,让我们将所有图像文件路径的列表加载到pandas数据帧中。

  1. # Path to all the downloaded images
  2. img_path = '../data/cv/fashion'
  3. # Find list of all files in the path
  4. images = [
  5.     f'../data/cv/fashion/{f}' 
  6.     for f in os.listdir(img_path) 
  7.     if os.path.isfile(os.path.join(img_path, f))
  8. ]
  9. # Load the file names to a dataframe
  10. image_df = pd.DataFrame(images, columns=['img_path'])
  11. print(image_df.shape)
  12. image_df.head()
bdb7c580b6999e9ff0f95eeabc137942.png

生成嵌入的策略

计算机视觉技术的最新进展开辟了许多将图像转换为数字嵌入的方法。让我们看看其中的一些。

  1. 展平原始像素值

  2. 基于分类目标的卷积神经网络预训练

  3. 基于度量学习目标的卷积神经网络预训练

  4. 基于度量学习目标的图文多模态神经网络预训练

展平原始像素值

彩色图像由三维像素阵列组成。第一个维度是图像的高度,第二个维度是图像的宽度,最后的第三个维度是颜色通道,统称为RGB,包含红色、绿色和蓝色,如下图所示。每个像素的值是0到255之间的整数,255是可能的最高强度。

因此,(0,0,0)的RGB值是完全黑暗或纯黑色像素,并且(255,255,255)是完全饱和的纯白色像素。我们图像中可见的所有其他颜色都是由这三个基本RGB值的各种组合组成的。

RapidTables网站上的RGB颜色代码图表允许你选择任何颜色来查看其RGB值,可以点击以下链接进行尝试:

https://www.rapidtables.com/web/color/RGB_Color.html

a351f7cc9f81b6cd67ef0e083638330d.png

如果图像是三维数组格式的一系列数字,则使用重塑操作将其转换为一维向量非常简单,如下图所示。我们还可以通过将每个像素的值除以255来规范化。我们将在代码中执行此操作。

44012515b38580896b3708cd20eddb0b.png
  1. def flatten_pixels(img_path):
  2.     # Load the image onto python
  3.     img = Image.open(img_path).convert('RGB')
  4.     
  5.     # Reshape the image to 1D and normalize the values
  6.     flattened_pixels = np.array(img).reshape(-1)/255.
  7.     
  8.     return flattened_pixels
  9. # Apply the transformation to the dataframe
  10. # Warning! Running only on a subset 1K rows of the data,
  11. # Your computer might crash if you run on the entire dataset! 
  12. # Better don’t run it. We have much better ways to generate embeddings!
  13. pixels_df = image_df.sample(1_000).reset_index(drop=True).copy()
  14. pixels_df['flattened_pixels'] = pixels_df['img_path'].progress_apply(flatten_pixels)
此方法的问题

虽然这种方法易于理解和实现,但这种将图像“嵌入”到向量中的方法存在一些严重的缺点。

  1. 巨大的向量:我们从Kaggle下载的图像非常小[224 x 224 x 3],对应于[Height x Width x Channels],将此3D阵列转换为1D向量将得到大小为150528的向量!对于如此小的图像,这是一个巨大的向量!在整个数据集上生成此向量时,我的计算机崩溃了好几次。最后我只在一个较小的子集(1K行)上运行它来说明这个方法。

  2. 具有大量白色的稀疏向量:在视觉上检查时装数据集中的图像时,我们会注意到图片中有很大的白色区域。因此,这个150528元素向量的许多元素是值255(对于白色),并且没有添加任何与图像中的对象相关的信息。换句话说,这种“嵌入”方案不能有效地对图像对象进行编码,而是包含大量无用的空白。

  3. 缺乏局部结构:最后,直接展平图像会丢失图片的所有局部结构。例如,我们通过眼睛、耳朵、鼻子和嘴巴的相对位置来识别人脸图像。这些是各种各样的“特征”级别的信息,如果我们一次只看一行像素,就会完全忽略这些信息。这种损失的影响是,一张倒置的脸与一张右侧朝上的脸有着非常不同的嵌入,即使这两张脸都是同一张人脸的照片!

随着基于卷积神经网络CNN和Transformer结构的新型神经网络结构的出现,我们基本上克服了这些问题。这篇文章的其余部分将深入探讨如何使用这些神经网络将我们的图像转换为嵌入。

基于分类目标的卷积神经网络预训练

也许最著名的计算机视觉任务之一就是将图像分为不同的类别。通常,对于这项任务,我们将使用CNN模型(如ResNet)作为编码器,将图像转换为向量,然后将该向量通过多层感知器(MLP)模型来确定图像的类别,如下图所示。

研究人员将使用交叉熵损失对CNN+MLP模型进行训练,以准确分类图像类别。

e77de1c090e03313c2a94b545096000f.png

这种方法提供了最先进的精确度,甚至超过了大多数人的能力。在训练这样一个模型后,我们可以去掉MLP层,直接将CNN编码器的输出作为每个输入图像的嵌入向量。

事实上,我们不需要为许多现实世界的问题从头开始训练我们自己的CNN模型。相反,我们直接下载并使用已经训练过的模型来识别日常对象,例如ImageNet数据集中的类别。

Towhee是一个python库,它可以使用这些预训练好的模型快速生成嵌入。让我们看看如何做到这一点。

Towhee管道

Towhee是一个python库,提供了非常易于使用的嵌入生成管道。我们可以使用towhee将图像转换为嵌入,代码不到五行!首先,让我们在终端窗口中使用pip安装towhee。

  1. # Activate the conda environment if not already done so
  2. # conda activate semantic_similarity
  3. pip install towhee torch torchvision

接下来,在Jupyter笔记本单元中,让我们导入库并实例化一个管道对象。

  1. from towhee import pipeline
  2. embedding_pipeline = pipeline('image-embedding')

接下来,让我们在一行代码中使用管道将图像转换为嵌入!嵌入管道的输出有一些额外的维度,我们可以使用np.squeeze去除这些维度。

  1. image_df['towhee_img_embedding'] = image_df['img_path'].progress_apply(lambda x: np.squeeze(embedding_pipeline(x)))
  2. image_df.head()
c67c4e4ce2cdeda1f813e80a75645ddf.png

在继续之前,让我们创建一个helper函数,该函数接受嵌入列的名称、用作查询图像的数据帧索引以及要搜索的类似图像的k个数。

该函数计算查询图像的嵌入与数据帧中所有其他图像的嵌入之间的余弦相似度,以找到前k个最相似的图像并显示它们。

  1. def plot_similar(df, embedding_col, query_index, k_neighbors=5):
  2.     '''Helper function to take a dataframe index as input query and display the k nearest neighbors
  3.     '''
  4.     
  5.     # Calculate pairwise cosine similarities between query and all rows
  6.     similarities = cosine_similarity([df[embedding_col][query_index]], df[embedding_col].values.tolist())[0]
  7.     
  8.     # Find nearest neighbor indices
  9.     k = k_neighbors+1
  10.     
  11.     nearest_indices = np.argpartition(similarities, -k)[-k:]
  12.     
  13.     nearest_indices = nearest_indices[nearest_indices != query_index]
  14.     
  15.     # Plot input image
  16.     img = Image.open(df['img_path'][query_index]).convert('RGB')
  17.     
  18.     plt.imshow(img)
  19.     
  20.     plt.title(f'Query Product.\nIndex: {query_index}')
  21.     
  22.     # Plot nearest neighbors images
  23.     fig = plt.figure(figsize=(20,4))
  24.     
  25.     plt.suptitle('Similar Products')
  26.     
  27.     for idx, neighbor in enumerate(nearest_indices):
  28.         
  29.         plt.subplot(1len(nearest_indices), idx+1)
  30.         
  31.         img = Image.open(df['img_path'][neighbor]).convert('RGB')
  32.         
  33.         plt.imshow(img)
  34.         
  35.         plt.title(f'Cosine Sim: {similarities[neighbor]:.3f}')
  36.         
  37.     plt.tight_layout()

我们现在可以通过查询数据帧中的随机图像并使用上述辅助函数显示k个最相似的图像来测试towhee嵌入的质量。

如下图所示,towhee嵌入非常准确,我们每次查询都会从包含多个不同产品(如连衣裙、手表、包和配件)的整套图像中找到类似的图片!

考虑到我们仅用三行代码就生成了这些嵌入,这更令人印象深刻!

  1. plot_similar(df=image_df,
  2.              embedding_col='towhee_img_embedding'
  3.              query_index=randint(0len(image_df)), # Query a random image
  4.              k_neighbors=5)
995d5edbff7b3ab2550a8e5b3fafd5de.png

从结果中,我们可以得出结论,towhee是快速生成相似性搜索应用程序嵌入的良好起点。

然而,我们没有明确地训练这些模型,以确保相似的图像具有彼此相同的嵌入。因此,在相似性搜索的上下文中,来自此类模型的嵌入对于所有用例可能都不是最准确的。

你现在可能会问的一个自然问题是,“是否有一种方法可以训练模型,使相似的图像具有彼此相似的嵌入?”谢天谢地,有!

基于度量学习目标的卷积神经网络预训练

进入度量学习,这是生成嵌入的最有希望的方法之一,特别是对于相似性搜索应用程序。在度量学习的最基本层面上,

  1. 我们使用神经网络(如CNN或Transformer网络)将图像转换为嵌入。

  2. 我们构造这些嵌入,以便语义相似的图像彼此靠近,而不同的图像则相距更远。

acdfdd9ef5f9e30764f885576c12356e.png

训练度量学习模型需要在数据处理方式和模型训练方式方面进行创新。

  1. 数据:在度量学习中,对于每个称为“锚点”图像的源图像,我们需要至少一个称为“正样本”的类似图像我们还需要第三个图像,称为“负样本”,以改进嵌入表示。在最近针对每个源图像的度量学习方法中,我们使用各种数据增强综合生成“锚点”和“正样本”图像,如下图所示。

  2. 模型:度量学习模型大多具有暹罗网络体系结构。锚点图像、正图像和负图像依次通过相同的模型生成嵌入,然后使用特殊的损失函数进行比较。其中一个损失函数称为对比损失,该模型的目标是将锚点图像和正面图像的嵌入移动得更近,使它们之间的距离接近0。相反,该模型旨在将锚点和负样本移动得更远,以便它们之间的距离更大。

61b9086a738f7182aa1f0ce3c827a050.png

在用这种方法训练模型后,我们可以通过数学计算嵌入向量之间的距离来发现任意两幅图像之间的相似性,这些距离可以使用余弦距离等度量。正如这篇中型博客文章所示,存在几种距离度量,余弦距离常用于比较嵌入。

SimCLR:简单对比学习

SimCLR代表视觉表征对比学习的简单框架。它是使用一种称为对比学习的度量学习方法生成图像嵌入的常用方法之一。在对比学习中,对比损失函数比较两个嵌入是相似的(0)还是不同的(1)。

SimCLR的优点在于它是一种简单的自监督算法(图像类不需要任何标签!)这实现了与一些受监督方法相当的性能,如下图所示!

0443885e6b99c39992d92d36ba6a1025.png 815817ab0bdd36f2cb5307bfa3455315.png

SimCLR的基本思想如下。

  1. 给定一个图像,创建同一图像的两个增强版本。这些增强可以是裁剪和调整大小、颜色失真、旋转、添加噪声等。上图显示了一些增强的示例。

  2. 批处理中所有图像的增强版本通过CNN编码器,该编码器将图像转换为嵌入。然后,这些CNN嵌入通过一个简单的多层感知器(MLP)将其转换为另一个空间,该感知器只有一个隐藏层。

  3. 最后,使用余弦距离比较MLP输出处的嵌入。该模型期望来自同一图像的增强的余弦距离为0,而来自不同图像的增强的余弦距离为1。然后,损失函数更新CNN和MLP的参数,以便嵌入更接近我们的期望。

  4. 一旦训练完成,我们就不再需要MLP,直接使用CNN编码器的输出作为嵌入。

下图从概念上解释了整个过程。有关更多详细信息,请查看这篇谷歌博客文章。

https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html

73ac9ad6768f1cb3d5bcf6d7f56fb42a.png

作者进行了几次实验,确定随机裁剪和颜色失真是增强的最佳组合,如下所示。

9a3305d4f263bf47f6e0c52c73b0e57f.png

与towhee一样,我们使用其他研究人员在ImageNet上预先训练的模型直接提取SimCLR嵌入。然而,在撰写本文时,为了获得SimCLR预训练过的嵌入,我们需要使用Pytorch Lightning Bolts库编写几行代码。我从官方lightning文档中改编了以下内容。首先,在终端窗口中使用pip安装必要的库。

  1. # Activate the conda environment if not already done so
  2. # conda activate semantic_similarity
  3. pip install lightning-bolts

接下来,在Jupyter笔记本单元中,我们导入必要的库,并根据你的计算机是否有GPU将设备设置为cuda或cpu。

  1. from pl_bolts.models.self_supervised import SimCLR
  2. import torch
  3. from torch.utils.data import Dataset, DataLoader
  4. from torchvision import io, transforms 
  5. # Use GPU if it is available
  6. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

接下来,让我们加载在ImageNet上预先训练的SimCLR模型,并将其设置为评估模式,因为我们只想从模型中获得嵌入,而不想再训练它。

  1. # load resnet50 pre-trained using SimCLR on imagenet
  2. weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
  3. simclr = SimCLR.load_from_checkpoint(weight_path, strict=False, batch_size=32)
  4. # Send the SimCLR encoder to the device and set it to eval
  5. simclr_resnet50 = simclr.encoder.to(device)
  6. simclr_resnet50.eval();

以下两个步骤针对Pytorch;用于实现模型的基础库。我们首先创建一个数据集,该数据集可以接受我们的数据帧作为输入,从img_path列读取图像,应用一些转换,最后创建一批图像,我们可以将其输入到模型中。

  1. # Create a dataset for Pytorch
  2. class FashionImageDataset(Dataset):
  3.     def __init__(self, df, transform=None):
  4.         
  5.         self.df = df
  6.         
  7.         self.transform = transform
  8.     def __len__(self):
  9.         
  10.         return len(self.df)
  11.     def __getitem__(self, idx):
  12.         
  13.         # Load the Image
  14.         
  15.         img_path = self.df.loc[idx, 'img_path']
  16.         
  17.         image = io.read_image(img_path, mode=io.image.ImageReadMode.RGB)/255.
  18.         # Apply Transformations
  19.         
  20.         if self.transform:
  21.             
  22.             image = self.transform(image)
  23.             
  24.         return image
  25. # Transforms
  26. ## Normalize transform to ensure the images have similar intensity distributions as ImageNet
  27. ## Resize transform to ensure all images in a batch have the same size 
  28. transformations = transforms.Compose([
  29.     transforms.Normalize((0.4850.4560.406), (0.2290.2240.225)),
  30.     transforms.Resize(size=(6464))
  31. ])
  32.     
  33. # Create the DataLoader to load images in batches
  34. emb_dataset = FashionImageDataset(df=image_df, transform=transformations)
  35. emb_dataloader = DataLoader(emb_dataset, batch_size=32)

最后,我们可以在dataloader中迭代批处理,为所有图像生成嵌入,并将它们作为列存储回dataframe中。

  1. # Create embeddings
  2. embeddings = []
  3. for batch in tqdm(emb_dataloader):
  4.     
  5.     batch = batch.to(device)
  6.     
  7.     embeddings += simclr_resnet50(batch)[0].tolist()
  8. # Assign embeddings to a column in the dataframe
  9. image_df['simclr_img_embeddings'] = embeddings
  10. image_df.head()
93c9fb2ed975699126bd184baa0505e3.png

现在是有趣的部分!我们可以使用相同的helper函数测试SimCLR嵌入的质量。

我们从数据帧中查询随机图像,并显示k个最相似的图像。如下所示,SimCLR嵌入也非常适合为我们运行的每个查询找到类似的图像!

  1. plot_similar(df=image_df,
  2.              embedding_col='simclr_img_embeddings'
  3.              query_index=randint(0len(image_df)), 
  4.              k_neighbors=5)
490f540df0eb3f4b661a3b9e957c9885.png

用度量学习目标预训练的图像-文本多模态神经网络

最后,将图像和文本嵌入到统一嵌入空间的模型有了巨大的改进,这开辟了几个优秀的应用程序,如 Image-To-Text 和 Text-To - 图像相似度搜索。这种范式中最流行的模型之一是CLIP(对比语言图像预训练)。

CLIP是一种基于度量学习框架的神经网络。CLIP使用图像作为锚点,相应的文本描述作为正样本来构建图像-文本对。我们可以在多个应用程序中使用CLIP,包括文本到图像、图像到文本、图像到图像和文本到文本的相似性搜索。

图像通过ResNet或ViT编码器生成图像嵌入。文本描述通过基于Transformer的编码器提供,以生成文本嵌入。CLIP联合训练图像和文本编码器,以便在一批N个图像-文本对中,第i个图像的嵌入与第i个文本的嵌入具有最高的点积,如下图所示。

584e2342b06d6f812832cac9e7290aa4.png

训练完成后,我们可以通过将两者转换为各自的嵌入并使用点积或余弦距离进行比较,找到与查询图像最相似的文本行,如下图所示。

相反,我们也可以以相同的方式搜索给定查询文本的最相似图像。让我们看看如何在示例问题上实现这一点。

332982e57a82157035ce77f9f73e55fc.png
Sentence Transformers

使用优秀的句子Transformer库,在我们的数据上生成CLIP嵌入非常简单。然而,由于操作系统对一次可以打开的文件数量的限制,在处理成千上万的图像时,我们需要编写几行样板代码。首先,在终端窗口中使用pip安装必要的库。

  1. # Activate the conda environment if not already done so
  2. # conda activate semantic_similarity
  3. pip install sentence_transformers ftfy

接下来,在Jupyter笔记本单元中,让我们导入库并实例化片段模型。

  1. from sentence_transformers import SentenceTransformer
  2. model = SentenceTransformer('clip-ViT-B-32')

接下来,我们需要迭代10K个图像,以绕过操作系统对一次可以打开的文件数量的限制。我们在每次迭代期间加载所有图像并生成CLIP嵌入。

  1. # Initialize an empty list to collect embeddings
  2. clip_embeddings = []
  3. # Generate embeddings for 10_000 images on each iteration
  4. step = 10_000
  5. for idx in range(0len(image_df), step):
  6.     # Load the `step` number of images
  7.     images = [
  8.         Image.open(img_path).convert('RGB'
  9.         for img_path in image_df['img_path'].iloc[idx:idx+step]
  10.     ]
  11.     
  12.     # Generate CLIP embeddings for the loaded images
  13.     clip_embeddings.extend(model.encode(images, show_progress_bar=True).tolist())
  14. # Assign the embeddings back to the dataframe
  15. image_df['clip_img_embedding'] = clip_embeddings
  16. image_df.head()
1ac75c1218be15fa5eda0f5c6df3a6f0.png

现在我们有了所有图像的CLIP嵌入,我们可以使用相同的辅助函数来测试嵌入的质量。

我们从数据帧中查询随机图像,并显示k个最相似的图像。如下图所示,CLIP嵌入也非常准确,可以为我们运行的每个查询找到相似的图像!

  1. plot_similar(df=image_df,
  2.              embedding_col='clip_img_embedding'
  3.              query_index=randint(0len(image_df)), 
  4.              k_neighbors=5)
a6b138295b688eb49dcf595a194f41f6.png

虽然我们必须编写一些额外的代码来生成CLIP嵌入,但它提供的一个显著优势是文本到图像搜索。换句话说,我们可以搜索与给定文本描述匹配的所有图像。让我们看看下面的内容。

由于我们已经将图像转换为CLIP嵌入,现在只需要将文本查询转换为CLIP嵌入。然后,我们可以利用文本嵌入和数据帧中所有图像嵌入之间的余弦相似度来搜索相似的产品。我们将编写一个简单的助手函数来为我们完成这一切,如下所示。最后,我们将绘制所有类似的k个产品图像。

  1. def text_image_search(text_query, df, img_emb_col, k=5):
  2.     '''Helper function to take a text query as input and display the k nearest neighbor images
  3.     '''
  4.     
  5.     # Calculate the text embeddings
  6.     text_emb = model.encode(text_query).tolist()
  7.     
  8.     # Calculate the pairwise cosine similarities between text query and images from all rows
  9.     similarities = cosine_similarity([text_emb], df[img_emb_col].values.tolist())[0]
  10.     
  11.     # Find nearest neighbors
  12.     nearest_indices = np.argpartition(similarities, -k)[-k:]
  13.     
  14.     # Print Query Text
  15.     print(f'Query Text: {text_query}')
  16.     
  17.     # Plot nearest neighbors images
  18.     fig = plt.figure(figsize=(20,4))
  19.     
  20.     plt.suptitle('Similar Products')
  21.     
  22.     for idx, neighbor in enumerate(nearest_indices):
  23.         
  24.         plt.subplot(1len(nearest_indices), idx+1)
  25.         
  26.         img = Image.open(df['img_path'][neighbor]).convert('RGB')
  27.         
  28.         plt.imshow(img)
  29.         
  30.         plt.title(f'Cosine Sim: {similarities[neighbor]:.3f}')
  31.         
  32.     plt.tight_layout()

现在,我们可以使用helper函数测试示例文本查询。如下图所示,如果我们的测试查询是“一件女装的照片”,那么最相似的产品都是女装!尽管每个产品的标题没有明确指定“连衣裙”一词,但CLIP模型能够仅从文本和图像嵌入推断出这些图像与查询“一张女装照片”最相关。

继续尝试其他查询!

  1. text_query = "a photo of a women's dress"
  2. text_image_search(text_query, 
  3.                   df=image_df, 
  4.                   img_emb_col='clip_img_embedding'
  5.                   k=5)
16cd60cc2517a1a62defbd905d0efae8.png

结论

深度学习研究和开源代码库的最新技术为从图像和文本数据生成高质量嵌入提供了许多简单的方法。这些现成的嵌入是为许多实际问题构建原型的绝佳起点!下面的流程图有助于选择要使用的初始嵌入。但是,在将任何单个查询部署到生产环境之前,请不断评估一些复杂示例查询上嵌入模型的准确性!

说到生产,我们在这里使用的数据集是一个只有6K个图像的玩具数据集。在现实世界的应用程序中,例如电子商务商店,你将有数亿个产品图像需要在几秒钟内嵌入、存储和执行近邻搜索!问题的规模需要使用强大的向量搜索数据库,如Milvus!

6c681feb6c2f7450c5e67a4b52ec79d1.png

感谢阅读!

☆ END ☆

如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。

扫描二维码添加小编↓

1f9bdda42f088831d9016ab00bbd95a5.jpeg

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/468535
推荐阅读
相关标签
  

闽ICP备14008679号