赞
踩
总体而言代码使用了多层方法的调用
clip_interrogator在open_clip的最外层又定义了一层
Interrogator():__init__()加载BLIP模型load_clip_model()加载CLIP模型
open_clip.create_model_and_transforms,open_clip.create_model_from_pretrained
create_model()
这个比较多样。以openai模型为例,open_clip\openai.py。还有CoCa、CustomTextCLIP、CLIP等直接创建类实例,而非第五层才创建,该途径需要根据create_model()输入的custom_text布尔值设置。具体见create_model()
load_openai_model()
以openai模型为例
build_model_from_openai_state_dict()
CLIP类、CustomTextCLIP类 open_clip\model.py
CoCa类 open_clip\coca_model.py
DA-CLIP接受一个CLIP实例作为参数初始化
def __init__(self, clip_model: CLIP):
【3万字代码解读】DA-CLIP/open_clip模型创建、模型配置读取、预训练权重地址读取http://t.csdnimg.cn/HTz2m
旨在实现统一的视觉-语言理解和生成。该库提供了预训练和微调后的模型检查点,支持
图像-文本检索、图像标题生成、视觉问答和NLVR2等多种任务。
DA-CLIP关于使用BLIP生成数据集的代码注释http://t.csdnimg.cn/rW7E1一个BLIP在colab上的运行demo,和对generate_caption.py代码讲解,注意本文才涉及带代码运行和windows运行相关问题。
是一个结合了 OpenAI 的 CLIP 和 Salesforce 的 BLIP 技术的 prompt 工程工具,它专门设计用于优化文本提示(prompts),以便与给定图像相匹配。该工具的技术实现基于这两个先进的多模态模型,通过分析图像内容和相关的文本描述,生成高质量的文本提示。
功能方面,CLIP-Interrogator 允许用户通过自然语言与 AI 进行交互,提出关于图像的问题,并获取相应的文本描述。这些生成的文本提示可以用于文本到图像的模型(如 Stable Diffusion)来创造新的艺术作品或进行图像生成的实验。此外,CLIP-Interrogator 支持在不同的 CLIP 模型之间进行选择,并且可以配置以适应不同的硬件条件,如低 VRAM 设备。
CLIP-Interrogator 可以作为一个库来使用,允许开发者在自己的脚本中调用其功能,从而实现图像内容的自动化分析和描述生成。此外,它还提供了与自己定义的术语列表进行排名对比的功能,这使得用户可以根据自己的特定需求定制化模型的输出。
GitHub - pharmapsychotic/clip-interrogator: Image to prompt with BLIP and CLIP
建议提前阅读该仓库说明,,对模型使用和参数有相关介绍
CLIP查询器使用OpenCLIP,它支持许多不同的预训练CLIP模型。对于Stable Diffusion1的最佳提示。X使用viti - l -14/openai为clip_model_name。Stable Diffusion2.0使用viti - h -14/ laon2b_s32b_b79k
Config对象允许您配置CLIP询问者的处理。
clip_model_name:使用哪个OpenCLIP预训练的CLIP模型
Cache_path:保存预先计算的文本嵌入的路径
download_cache:当为True时,将从huggingface下载预先计算的嵌入
chunk_size: CLIP的批处理大小,对于较小的VRAM使用更小的
quiet: True时不显示进度条或文本输出
Create dataset:
To generate clean captions with BLIP, we use the clip-interrogator tool. Install it with
pip install clip-interrogator==0.6.0
and run:python ../scripts/generate_captions.pyThen you will get
daclip_train.csv
anddaclip_val.csv
under thedatasets/universal
directory.
使用pip自带的channel直接下载网速太慢又容易中断。
我直接修改配置,下次就不用指定下载源了
在命令行输入notepad %APPDATA%\pip\pip.ini
修改pip.ini文件的
index-url改为清华源,地址如下
https://pypi.tuna.tsinghua.edu.cn/simple/
光速下完
index-url
是默认的 channel。
extra-index-url
是额外的 channels
trusted-host
是你信任的 hosts,
不挂梯子连不了hugging face,,挂了代理又运行不了。
先看报错: Can't load tokenizer for 'bert-base-uncased'.点击查看报错点源码
文件路径为C:\Users\86136\anaconda3\envs\DA-CLIP\Lib\site-packages\blip\models\blip.py
- def init_tokenizer():
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- tokenizer.add_special_tokens({'bos_token':'[DEC]'})
- tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
- tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
- return tokenizer
明确目标先离线下载bert-base-uncased,参考博文
http://t.csdnimg.cn/BDNALhttp://如何下载和在本地使用Bert预训练模型推荐谷歌浏览器,下载更快。博主下载了pytorch版本和其他相关配置文件。
地址放在本项目该代码目录下的新建文件夹bert-base-uncased中
修改该方法下代码读取的目录,,修改为你的地址
tokenizer = BertTokenizer.from_pretrained('C:\\Users\\86136\\Desktop\\daclip-uir-main\\scripts\\bert-base-uncased')
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
该代码创建一个Interrogator类实例ci 。Interrogator类中加载了BLIP模型和CLIP模型,相应代码下面会有。随后使用generate_caption.py中创建的generate_captions()方法
generate_captions(dataroot, ci, 'val')
该该方法中 创建caption的语句是Interrogator类中的generate_caption方法。
所以我们有必要查看clip_interrogator.py的相关代码
caption = ci.generate_caption(image)
在generate_caption.py的ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
查看Config方法,鼠标双击Config后等待一会出现注解
点击标蓝的文字出现代码文件路径
打开
C:\Users\86136\anaconda3\envs\DA-CLIP\________这部分是你的环境路径
Lib\site-packages\clip_interrogator\clip_interrogator.py
该文件开头定义了BLIP_MODELS的下载地址,,我运行时自动下载很快就没有修改这部分
有需要可以自行下载再修改地址,,不过下载后本地地址要写对不然报错RuntimeError,
加载地址相关代码可见
3.3.4
"large"1.75Ghttps://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
- BLIP_MODELS = {
- 'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
- 'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
- }
查看Config类
blip_model_type: str = 'large' # choose between 'base' or 'large' #尽管后面关于blip_model_type的定义全都是base.但在最外层用的是large 因为Interrogator使用的是Config.blip_model_type
在Interrogator类的__init__中,如下
- if config.blip_model is None:
- if not config.quiet:
- print("Loading BLIP model...")
- blip_path = os.path.dirname(inspect.getfile(blip_decoder))
- configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
- med_config = os.path.join(configs_path, 'med_config.json')
- blip_model = blip_decoder(
- pretrained=BLIP_MODELS[config.blip_model_type],
- image_size=config.blip_image_eval_size,
- vit=config.blip_model_type,
- med_config=med_config
- ) #创建模型
- blip_model.eval() # 将模型设置为评估模式
- if not self.config.blip_offload:
- blip_model = blip_model.to(config.device)
- self.blip_model = blip_model
可以看到该blip模型的配置文件路径为:你的虚拟环境包下的blip\configs\bert_config.json
C:\Users\86136\anaconda3\envs\DA-CLIP\Lib\site-packages\blip\configs\bert_config.json
该blip_decoder仍是嵌套了一层模型加载
- def blip_decoder(pretrained='',**kwargs):
- model = BLIP_Decoder(**kwargs)
- if pretrained:
- model,msg = load_checkpoint(model,pretrained)
- assert(len(msg.missing_keys)==0)
- return model
在BLIP_Decoder中读取了配置文件
med_config = BertConfig.from_json_file(med_config)
关于模型权重文件更多细节需要查看loac_checkpoint()
- def load_checkpoint(model,url_or_filename):
- # print("url_or_filename",url_or_filename)
- # https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
- if is_url(url_or_filename):
- cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
- checkpoint = torch.load(cached_file, map_location='cpu')
- elif os.path.isfile(url_or_filename):
- checkpoint = torch.load(url_or_filename, map_location='cpu')
- else:
- raise RuntimeError('checkpoint url or path is invalid')
-
- state_dict = checkpoint['model']
-
- state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
- if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
- state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
- model.visual_encoder_m)
- for key in model.state_dict().keys():
- if key in state_dict.keys():
- if state_dict[key].shape!=model.state_dict()[key].shape:
- del state_dict[key]
-
- msg = model.load_state_dict(state_dict,strict=False)
- print('load checkpoint from %s'%url_or_filename)
- return model,msg
这段Python代码定义了一个名为 load_checkpoint
的函数,它的目的是加载一个预训练的模型检查点(checkpoint)到一个给定的模型中。这个函数接受两个参数:model
(要加载检查点的模型实例)和 url_or_filename
(包含检查点文件的URL或文件路径)。
由于download_cached_file()代码包含存放地址检测代码,已经下载过的不会再次下载
函数的主要步骤如下:
首先,使用
is_url
函数检查提供的url_or_filename
是否是一个有效的URL。如果是,那么download_cached_file
函数将被用来下载并缓存该文件。torch.load
函数随后用于加载下载的检查点文件,map_location='cpu'
参数确保加载过程在CPU上执行,这在没有GPU可用的情况下很有用。如果
url_or_filename
不是一个URL,那么代码将检查它是否指向一个存在的文件。如果文件存在,同样使用torch.load
来加载检查点。如果
url_or_filename
既不是有效的URL也不是存在的文件,函数将引发一个RuntimeError
,指出检查点的URL或路径无效。加载检查点后,函数将处理检查点中的
visual_encoder.pos_embed
键,使用interpolate_pos_embed
函数对其进行插值。这可能是为了匹配模型的当前位置嵌入(position embeddings)的尺寸。接下来,函数检查
visual_encoder_m.pos_embed
键是否存在于模型的状态字典中,如果存在,也对其进行插值处理。然后,函数遍历模型的状态字典和检查点的状态字典中的所有键,如果键在两个字典中都存在但形状不匹配,那么该键将从检查点的状态字典中删除。
最后,使用
model.load_state_dict
方法将处理后的检查点加载到模型中,strict=False
参数允许忽略不匹配的键。函数打印一条消息,指示已从何处加载检查点,并返回模型和加载操作的结果msg
。
需要使用本地BLIP模型权重地址的同学可以考虑这一行。
- elif os.path.isfile(url_or_filename):
- checkpoint = torch.load(url_or_filename, map_location='cpu')
想知道自动下载的保存地址查看download_cached_file()函数
from timm.models.hub import download_cached_file
博主最终查到如下
C:\Users\86136\.cache\torch\hub\checkpoints\model_large_caption.pth
红色为你的缓存路径
- def load_clip_model(self):
- start_time = time.time()
- config = self.config
-
- if config.clip_model is None:#没有传入模型就
- if not config.quiet:
- print("Loading CLIP model...")
-
- clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
- self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
- clip_model_name,
- pretrained=clip_model_pretrained_name,
- precision='fp16' if config.device == 'cuda' else 'fp32',
- device=config.device,
- jit=False,
- cache_dir=config.clip_model_path
- )
- print("create clip_model over")
- self.clip_model.eval()
- else:
- self.clip_model = config.clip_model
- self.clip_preprocess = config.clip_preprocess
- self.tokenize = open_clip.get_tokenizer(clip_model_name)
该代码主要关注点open_clip.create_model_and_transforms,不记得创建模型函数调用过程可以回到开头。注意该函数参数我们可以根据config.clip_model_path设置本地预训练权重地址
cache_dir=config.clip_model_path
ViT-L-4/openai模型权重下载地址在HF,连不上。
根据上篇文章的考证,可以在open_clip\pretrained.py找到所有模型相关下载地址和函数
https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt
挂梯子下载到本地
添加地址参数。注意不包含权重文件名,否则报错。暂时还没找相关代码
- ci = Interrogator(Config(
- clip_model_name="ViT-L-14/openai",
- clip_model_path="E:\\download\\ViT-L-14" # 你的模型权重文件的本地路径
- ))
-
- sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble',
- 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount',
- 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
- trending_list = [site for site in sites]
- trending_list.extend(["trending on "+site for site in sites])
- trending_list.extend(["featured on "+site for site in sites])
- trending_list.extend([site+" contest winner" for site in sites])
-
- raw_artists = _load_list(config.data_path, 'artists.txt')
- artists = [f"by {a}" for a in raw_artists]
- artists.extend([f"inspired by {a}" for a in raw_artists])
- # print(config.data_path)
- # C:\Users\86136\anaconda3\envs\DA-CLIP\lib\site-packages\clip_interrogator\data
- self._prepare_clip()
- self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
- self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
- self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
- self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
- self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
- self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
这段代码的主要目的是构建和准备一系列的标签(labels)和搜索提示(prompts),这些标签和提示将用于与CLIP模型交互,以便生成或搜索特定类型的艺术作品。代码中创建了几个不同的标签列表,包括基于流行网站和艺术家的标签,以及一些特定的艺术风格、媒介、运动和负面标签。
具体步骤如下:
定义了一个名为
sites
的列表,包含了一系列艺术和设计相关的网站名称,如 'Artstation', 'Behance' 等。使用列表推导式创建了
trending_list
,这个列表包含了基于sites
中的网站名称生成的各种搜索提示,如 "trending on Artstation", "featured on Behance" 等。通过调用
_load_list
函数,从配置路径中加载了一个名为 'artists.txt' 的文件,该文件包含了艺术家的名字。基于加载的艺术家名单
raw_artists
,创建了artists
列表,其中包含了由艺术家名字构成的标签,如 "by {artist_name}" 和 "inspired by {artist_name}"。接下来,代码创建了五个
LabelTable
对象,每个对象都与特定的标签列表和CLIP模型相关联。这些对象分别是:
self.artists
:与艺术家相关的标签。self.flavors
:与艺术风格相关的标签。self.mediums
:与艺术媒介相关的标签。self.movements
:与艺术运动相关的标签。self.trendings
:与流行趋势相关的标签。self.negative
:与负面标签相关的列表
提供对应标签下载地址
https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensorshttps://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors
https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensorshttps://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors
https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors
https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors
放到你的项目该文件夹下,cache是运行时会生成的,如果没有,新建即可
现在只剩数据集问题了
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。