当前位置:   article > 正文

open_clip仓库成分与模型文件model.py 介绍_open_clip库

open_clip库

起因:

 在DA-CLIP的开源库的DA-CLIP.md中自述该项目基于CLIP 和open_clip,在之前的退化类型检测中 我一度以为仓库只是使用了CLIP 的源码, 然而当发现缺少da-clip的模型名称时,我发现DA-CLIP使用的完全是open_clip的代码版本,专门配置了da-clip.json在open_clip的model_configs

This repository is based on the OpenAI's CLIP (Contrastive Language-Image Pre-training) and open_clip.

We extend the CLIP to a degradation-aware version (DA-CLIP) which predicts both degradation embedding and clean content embedding from corrupted images.

Then we can use the embeddings to improve image restoration performance and assist unified image restoration.

Moreover, we use the pretrained ViT CLIP model (ViT-B/32 on LAION-2B) and add an additional controller to control the image encoder.

该库基于OpenAI的CLIP(对比语言图像预训练)和open_clip。
我们将CLIP扩展到退化感知版本(DA-CLIP),该版本预测退化嵌入和从损坏图像中嵌入干净内容。

然后,我们可以利用该嵌入来提高图像恢复性能,帮助统一图像恢复。

此外,我们使用预训练的ViT CLIP模型(LAION-2B上的ViT- b /32),并添加一个额外的控制器来控制图像编码器。

 纵观DA-CLIP代码,关于CLIP模块,基本是从open_clip库上进行扩展。对于clip代码使用已有博主进行说明

CLIP模型原理与代码实现详解-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/weixin_38252409/article/details/133828294

 背景:

然而CSDN上对open_clip项目的介绍寥寥,以下项目虽有涉猎但都未对其仓库源码进行解析。

ImageNet零样本准确率首次超过80%!OpenCLIP:性能最强的开源CLIP模型-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/amusi1994/article/details/129036171?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171120819416800227464964%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=171120819416800227464964&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_click~default-1-129036171-null-null.142%5Ev99%5Epc_search_result_base6&utm_term=OpenCLIP&spm=1018.2226.3001.4449

项目解读:open_clip

结构

  • docs/ - 包含项目文档和使用说明。
  • src/ - 包含项目的主要源代码。
  • tests/ - 包含项目测试代码。
  • setup.py - 用于安装和项目的 Python 包设置文件。

src下的文件夹,包含open_clip和training,

由于DA-CLIP与open_clip数据集、模型结构等不同,其中DA-CLIP 主要使用了open_clip文件夹

该文件夹下:

  • model.py - 定义了 CLIP 模型的结构,包括图像编码器和文本编码器。这是模型训练和推理的核心部分。
  • config.py - 包含了模型配置的类和函数,用于设置模型的不同参数,如学习率、批次大小等。
  • tokenizer.py - 包含了文本处理的分词器,用于将文本转换为模型可以理解的格式。
  • trainer.py - 包含了模型训练的主要逻辑,如设置优化器、损失函数和训练循环。
  • evaluator.py - 包含了评估模型性能的代码,通常用于在验证集或测试集上计算指标。
  • utils.py - 包含了一些实用工具函数,如数据处理、日志记录等。
  • transform.py - 包含了数据预处理和增强的转换函数,用于准备输入数据。
  • model_configs:包含各类模型配置,如coca、convnext、EVA、Vit、RN50等json文件

open_clip 目录下,每个文件都有其独特的价值和作用,但如果要挑选一个文件进行深入解读,model.py 可能是最值得关注的。这是因为 model.py 定义了 CLIP 模型的核心结构,它直接关联到模型的性能和功能。以下是对 model.py 文件的详细解读:

model.py 的作用

model.py 文件定义了 CLIP 模型的架构,包括图像编码器和文本编码器的设计。CLIP 模型是一个双流(two-stream)模型,它分别处理图像和文本输入,并通过对比学习(contrastive learning)的方式,学习图像和文本之间的对应关系。

主要组件

  1. 图像编码器(Image Encoder): 通常使用预训练的卷积神经网络(如 ResNet 或 Vision Transformer)作为基础,用于提取图像特征。

  2. 文本编码器(Text Encoder): 通常基于 Transformer 架构,用于处理文本输入并提取文本特征。文本编码器可能使用 BERT 或类似的 Transformer 变体。

  3. 投影头(Projection Heads): 用于将图像和文本编码器的输出映射到共同的特征空间,以便进行相似性比较。

  4. 损失函数(Loss Function): CLIP 模型使用对比损失函数来训练,这要求模型能够区分匹配的图像-文本对和不匹配的对。

关键概念

  • 对比学习(Contrastive Learning): 一种自监督学习方法,模型通过比较正样本对和负样本对来学习特征表示。

  • 多模态学习(Multimodal Learning): 涉及处理和理解多种类型数据(如图像和文本)的机器学习方法。

  • 零样本学习(Zero-Shot Learning): 模型能够在没有见过特定类别的样本的情况下进行分类或识别。

代码结构

model.py 文件包含以下部分:

  • 类定义(Class Definitions): 定义了图像编码器、文本编码器和整个 CLIP 模型的结构。

  • 前向传播(Forward Pass): 描述了数据如何通过模型,以及如何计算图像和文本的特征表示。

  • 初始化方法(Initialization Methods): 描述了模型权重的初始化过程,这对于训练的稳定性和收敛速度至关重要。

  • 损失计算(Loss Computation): 实现了对比损失函数,用于训练过程中的优化。

通过对 model.py 文件的深入理解,可以更好地把握 CLIP 模型的工作原理,以及如何修改和扩展模型以适应不同的应用场景。这个文件是进行模型训练和评估的基础,对于任何想要深入了解或贡献于 OpenCLIP 项目的人来说都是必读的。

该文件的CLIP类如下

  1. class CLIP(nn.Module):
  2. output_dict: torch.jit.Final[bool]
  3. def __init__(
  4. self,
  5. embed_dim: int,
  6. vision_cfg: CLIPVisionCfg,
  7. text_cfg: CLIPTextCfg,
  8. quick_gelu: bool = False,
  9. init_logit_scale: float = np.log(1 / 0.07),
  10. init_logit_bias: Optional[float] = None,
  11. cast_dtype: Optional[torch.dtype] = None,
  12. output_dict: bool = False,
  13. ):
  14. super().__init__()
  15. self.output_dict = output_dict
  16. self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
  17. text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
  18. self.transformer = text.transformer
  19. self.context_length = text.context_length
  20. self.vocab_size = text.vocab_size
  21. self.token_embedding = text.token_embedding
  22. self.positional_embedding = text.positional_embedding
  23. self.ln_final = text.ln_final
  24. self.text_projection = text.text_projection
  25. self.text_pool_type = text.pool_type
  26. self.register_buffer('attn_mask', text.attn_mask, persistent=False)
  27. self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
  28. if init_logit_bias is not None:
  29. self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
  30. else:
  31. self.logit_bias = None
  32. def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
  33. # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
  34. self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
  35. @torch.jit.ignore
  36. def set_grad_checkpointing(self, enable=True):
  37. self.visual.set_grad_checkpointing(enable)
  38. self.transformer.grad_checkpointing = enable
  39. def encode_image(self, image, normalize: bool = False):
  40. features = self.visual(image)
  41. return F.normalize(features, dim=-1) if normalize else features
  42. def encode_text(self, text, normalize: bool = False):
  43. cast_dtype = self.transformer.get_cast_dtype()
  44. x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
  45. x = x + self.positional_embedding.to(cast_dtype)
  46. x = x.permute(1, 0, 2) # NLD -> LND
  47. x = self.transformer(x, attn_mask=self.attn_mask)
  48. x = x.permute(1, 0, 2) # LND -> NLD
  49. x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
  50. x, _ = text_global_pool(x, text, self.text_pool_type)
  51. if self.text_projection is not None:
  52. if isinstance(self.text_projection, nn.Linear):
  53. x = self.text_projection(x)
  54. else:
  55. x = x @ self.text_projection
  56. return F.normalize(x, dim=-1) if normalize else x
  57. def get_logits(self, image, text):
  58. image_features = self.encode_image(image, normalize=True)
  59. text_features = self.encode_text(text, normalize=True)
  60. image_logits = self.logit_scale.exp() * image_features @ text_features.T
  61. if self.logit_bias is not None:
  62. image_logits += self.logit_bias
  63. text_logits = image_logits.T
  64. return image_logits, text_logits
  65. def forward(
  66. self,
  67. image: Optional[torch.Tensor] = None,
  68. text: Optional[torch.Tensor] = None,
  69. ):
  70. image_features = self.encode_image(image, normalize=True) if image is not None else None
  71. text_features = self.encode_text(text, normalize=True) if text is not None else None
  72. if self.output_dict:
  73. out_dict = {
  74. "image_features": image_features,
  75. "text_features": text_features,
  76. "logit_scale": self.logit_scale.exp()
  77. }
  78. if self.logit_bias is not None:
  79. out_dict['logit_bias'] = self.logit_bias
  80. return out_dict
  81. if self.logit_bias is not None:
  82. return image_features, text_features, self.logit_scale.exp(), self.logit_bias
  83. return image_features, text_features, self.logit_scale.exp()

这个类定义了 CLIP 模型的结构和行为。

初始化

  1. def __init__(
  2. self,
  3. embed_dim: int,
  4. vision_cfg: CLIPVisionCfg,
  5. text_cfg: CLIPTextCfg,
  6. quick_gelu: bool = False,
  7. cast_dtype: Optional[torch.dtype] = None,
  8. output_dict: bool = False,
  9. ):
  10. super().__init__()
  11. self.output_dict = output_dict
  12. self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
  13. text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
  14. self.transformer = text.transformer
  15. self.context_length = text.context_length
  16. self.vocab_size = text.vocab_size
  17. self.token_embedding = text.token_embedding
  18. self.positional_embedding = text.positional_embedding
  19. self.ln_final = text.ln_final
  20. self.text_projection = text.text_projection
  21. self.register_buffer('attn_mask', text.attn_mask, persistent=False)
  22. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
参数解释
  • self: 指向类的实例。
  • embed_dim: 嵌入维度,这是模型中嵌入层的维度。
  • vision_cfg: 图像配置对象,它包含了构建图像处理部分(视觉塔)所需的配置信息。
  • text_cfg: 文本配置对象,它包含了构建文本处理部分(文本塔)所需的配置信息。
  • quick_gelu: 布尔值,指示是否使用快速的GELU(Gaussian Error Linear Unit)激活函数。
  • cast_dtype: 可选参数,指定数据类型,用于将模型参数转换为指定的数据类型。
  • output_dict: 布尔值,指示模型输出是否应该是一个字典。
方法体解释
  • super().__init__(): 调用父类的构造函数。
  • self.output_dict: 存储传入的output_dict参数,这可能影响模型输出的格式。
  • self.visual: 通过调用一个内部函数_build_vision_tower来构建视觉塔,并存储结果。
  • text: 通过调用一个内部函数_build_text_tower来构建文本塔,并存储结果。
  • self.transformer: 从文本塔中提取变换器(transformer)模块。
  • self.context_length: 存储文本塔的上下文长度。
  • self.vocab_size: 存储文本塔的词汇表大小。
  • self.token_embedding: 存储文本塔的词嵌入层。
  • self.positional_embedding: 存储文本塔的位置嵌入层。
  • self.ln_final: 存储文本塔的最终层归一化(Layer Normalization)。
  • self.text_projection: 存储文本塔的文本投影层。
  • self.register_buffer('attn_mask', text.attn_mask, persistent=False): 注册一个缓冲区,用于存储文本塔的注意力掩码(attention mask),这个掩码在自注意力机制中用于指示哪些位置应该被模型关注。
  • self.logit_scale: 创建一个可学习的参数,用于缩放模型的输出(logits),初始化为一个全1的向量,乘以一个基于经验的对数缩放因子。

锁定参数

下面两个方法的目的是为了在训练过程中冻结图像塔和文本塔的参数,这与 LiT 方法的核心思想相符,即在对比学习(contrastive learning)过程中,只更新文本模型的参数,而保持图像模型的参数不变。这样做可以利用预训练图像模型的强大特征提取能力,同时通过文本模型来适应新任务,实现零样本(zero-shot)迁移学习。 

  1. def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
  2. # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
  3. self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
  4. def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
  5. for param in self.transformer.parameters():
  6. param.requires_grad = False
  7. self.token_embedding.requires_grad = False
  8. self.positional_embedding.requires_grad = False
  9. self.text_projection.requires_grad = False

这段代码定义了两个方法,lock_image_towerlock_text_tower,它们似乎是用于控制神经网络模型中的图像塔(image tower)和文本塔(text tower)的参数更新机制。CLIP类实现了类似于 LiT(Locked-image Tuning)的机制,这是一种在图像和文本模型对齐时锁定预训练图像模型参数的技术,如在论文 "LiT: Zero-Shot Transfer with Locked-image text Tuning" 中所描述的。

lock_image_tower 方法
  • unlocked_groups 参数:指定哪些层组应该保持未锁定(即可训练的)。默认值为 0,意味着所有层组都被锁定。
  • freeze_bn_stats 参数:决定是否冻结批量归一化(Batch Normalization, BN)层的统计数据。如果设置为 True,则BN层的运行时统计数据(均值和方差)不会在训练过程中更新。

在这个方法中,调用了 self.visual.lock(),这是一个自定义的方法,用于锁定图像塔中的参数。这个方法可能会根据 unlocked_groups 参数来决定哪些层或层组应该保持可训练状态。

lock_text_tower 方法
  • unlocked_layers 参数:指定文本塔中应该保持未锁定的层的数量。默认值为 0,意味着所有层都被锁定。
  • freeze_layer_norm 参数:决定是否冻结层归一化(Layer Normalization, LN)的参数。如果设置为 True,则LN层的参数不会在训练过程中更新。

在这个方法中,遍历了 self.transformer 中的所有参数,并将它们的 requires_grad 属性设置为 False,这意味着在训练过程中这些参数不会更新。此外,还冻结了文本塔中的词嵌入(token_embedding)、位置嵌入(positional_embedding)和文本投影(text_projection)的参数。

set_grad_checkpointing 方法设置梯度检查点

  1. def set_grad_checkpointing(self, enable=True):
  2. self.visual.set_grad_checkpointing(enable)
  3. self.transformer.grad_checkpointing = enable
  • enable: 布尔值,指示是否启用梯度检查点。当启用时,可以减少模型训练过程中的内存消耗,但可能会增加计算成本。

这个方法在模型的两个主要组件,visual(视觉塔)和transformer(文本塔)上设置梯度检查点。梯度检查点是一种内存优化技术,它允许模型在前向传播过程中保存中间激活的梯度,从而在反向传播时减少内存使用。

encode_image 方法

  1. def encode_image(self, image, normalize: bool = False):
  2. features = self.visual(image)
  3. return F.normalize(features, dim=-1) if normalize else features
  • image: 输入的图像数据。
  • normalize: 布尔值,指示是否对编码后的特征进行归一化。

这个方法使用模型的视觉塔来编码输入的图像数据。如果normalize参数为True,则使用F.normalize函数对特征进行归一化处理,否则直接返回特征。

encode_text 方法

  1. def encode_text(self, text, normalize: bool = False):
  2. cast_dtype = self.transformer.get_cast_dtype()
  3. x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
  4. x = x + self.positional_embedding.to(cast_dtype)
  5. x = x.permute(1, 0, 2) # NLD -> LND
  6. x = self.transformer(x, attn_mask=self.attn_mask)
  7. x = x.permute(1, 0, 2) # LND -> NLD
  8. x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
  9. # take features from the eot embedding (eot_token is the highest number in each sequence)
  10. x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
  11. return F.normalize(x, dim=-1) if normalize else x
  • text: 输入的文本数据。
  • normalize: 布尔值,指示是否对编码后的文本特征进行归一化。

这个方法首先将文本数据通过词嵌入层和位置嵌入层来获取嵌入表示,然后将这些嵌入表示转换为适合transformer的格式。接下来,使用transformer处理这些嵌入,并通过层归一化(self.ln_final)进行归一化。最后,通过文本投影层(self.text_projection)将嵌入映射到最终的特征空间,并进行归一化处理(如果normalizeTrue)。

前向传播函数 

  1. def forward(
  2. self,
  3. image: Optional[torch.Tensor] = None,
  4. text: Optional[torch.Tensor] = None,
  5. ):
  6. image_features = self.encode_image(image, normalize=True) if image is not None else None
  7. text_features = self.encode_text(text, normalize=True) if text is not None else None
  8. if self.output_dict:
  9. return {
  10. "image_features": image_features,
  11. "text_features": text_features,
  12. "logit_scale": self.logit_scale.exp()
  13. }
  14. return image_features, text_features, self.logit_scale.exp()

段代码定义了一个名为 forward 的方法,它是神经网络模型中的前向传播函数。该方法接收图像和文本作为输入,并输出它们的特征表示以及用于缩放 logits 的比例因子。这个方法是模型核心功能的一部分,负责将输入数据转换为模型可以处理的嵌入向量。

参数解释
  • self: 指向类的实例。
  • image: 输入的图像数据,类型为 torch.Tensor。如果为 None,则表示没有图像输入。
  • text: 输入的文本数据,类型为 torch.Tensor。如果为 None,则表示没有文本输入。
方法体解释
  • image_features: 如果提供了图像输入,使用 self.encode_image 方法对图像进行编码,并在返回前进行归一化处理。如果没有图像输入,则设置为 None
  • text_features: 如果提供了文本输入,使用 self.encode_text 方法对文本进行编码,并在返回前进行归一化处理。如果没有文本输入,则设置为 None
  • if self.output_dict: 判断是否以字典格式输出。如果 output_dict 属性为 True,则将图像特征、文本特征和 logits 缩放因子封装成一个字典返回。否则,将这三个值作为独立的返回值。
返回值
  • 如果 self.output_dict 为 True,则返回一个包含图像特征、文本特征和 logits 缩放因子的字典。
  • 如果 self.output_dict 为 False,则返回一个包含图像特征、文本特征和 logits 缩放因子的元组。
总结

forward 方法是模型的入口点,它根据输入的图像和文本数据,通过模型的编码器生成对应的特征表示。这些特征表示可以用于后续的多模态任务,例如图像-文本匹配、联合嵌入学习或零样本分类等。此外,该方法还提供了 logits 缩放因子,这在某些情况下(如对比学习或分类任务)可能是必需的。通过灵活的输出格式,该方法可以适应不同的使用场景和后处理需求。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/548662
推荐阅读
相关标签
  

闽ICP备14008679号