赞
踩
Diffusion扩散模型分为两个阶段:前向过程 + 反向过程
前向过程 ——>图片中添加噪声
反向过程——>去除图片中的噪声
在每一轮的训练过程中,包含以下内容:
下图是每个Epoch详细的训练过程:
噪声图Noisy image经过训练后的U-Net网络,会得到预测噪声Predicted Noisy,而:去噪图Denoised image = 噪声图Noisy image - 预测噪声图Predicted Noisy。(计算公式省略了具体的参数,只表述逻辑关系)
U-Net的模型结构就是一个编-解码的过程,下采样Downsample、中间块Middle block、上采样Upsample中都包含了ResNet残差网络
1、主干网络做特征提取;2、加强网络做特征组合;3、预测网络做预测输出;
为改善DM扩散模型的缺点,Stable Diffusion引入图像压缩技术,在低维空间完成扩散过程;并添加CLIP模型,使文本-图像产生关联。
1. 图像压缩:DM扩散模型是直接在原图上进行操作,而Stale Diffusion是在较低维度的潜在空间上应用扩散过程,而不是使用实际像素空间,这样可以大幅减少内存和计算成本;
2. 文本-图像关联:在反向扩散过程中对U-Net的结构做了修改,使其可以添加文本向量Text Embedding,使得在每一轮的去噪过程中,让输出的图像与输入的文字产生关联;
Stable Diffusion在实际应用中的过程:原图——经过编码器E变成低维编码图——DM的前向过程逐步添加噪声,变成噪声图——T轮U-Net网络完成DM的反向过程——经过解码器D变成新图。
CLIP(Contrastive Language-Image Pre-Training) 模型是 OpenAI 在 2021 年初发布的用于匹配图像和文本的预训练神经网络模型,是近年来在多模态研究领域的经典之作。OpenAI 收集了 4 亿对图像 - 文本对(一张图像和它对应的文本描述),分别将图像和文本进行编码,使用 metric learning进行训练。希望通过对比学习,模型能够学习到图像 - 文本对的匹配关系。
CLIP模型共有3个阶段:1阶段用作训练,2、3阶段用作推理。
通过计算目标图像和对应文本描述的余弦相似度从而获取预测值。CLIP第一阶段主要包含以下两个子模型;
这里举例一个包含N个文本-图像对的训练batch,对提取的文本特征和图像特征进行训练的过程:
CLIP的预测推理过程主要有以下两步:
A photo of {object}.
,然后再送入Text Encoder得到对应的文本特征。如果预测类别的数目为N,那么将得到N个文本特征。zero-shot :零样本学习,域外泛化问题。利用训练集数据训练模型,使得模型能够对测试集的对象进行分类,但是训练集类别和测试集类别之间没有交集,期间需要借助类别的描述,来建立训练集和测试集之间的联系,从而使得模型有效。
在计算机视觉中,即便想迁移VGG、MobileNet这种预训练模型,也需要新数据经过预训练、微调等手段,才能学习新数据集所持有的数据特征,而CLIP可以直接实现zero-shot的图像分类,即:不需要训练任何新数据,就能在某个具体下游任务上实现分类,这也是CLIP亮点和强大之处。
我的猜测:CLIP的zero-shot能力是依赖于它预训练的4亿对图像-文本对,样本空间非常大,下游任务的类别也不过是CLIP样本空间的子集,并不是真正的零样本学习,和解决域外泛化问题。和人脸比对的原理相似,依靠大量样本来学习分类对象的特征空间,区别在于人脸比对是image-to-image,CLIP是image-to-text。
向模型提供8个示例图像及其文本描述,并比较相应特征之间的相似性
# images in skimage to use and their textual descriptions
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser(“~/.cache”), transform=preprocess, download=True)
text_descriptions = [f"This is a photo of a {
label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis(“off”)
plt<span class="token punctuation">.</span>subplot<span class="token punctuation">(</span><span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span> <span class="token operator">*</span> i <span class="token operator">+</span> <span class="token number">2</span><span class="token punctuation">)</span>
y <span class="token operator">=</span> np<span class="token punctuation">.</span>arange<span class="token punctuation">(</span>top_probs<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>grid<span class="token punctuation">(</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>barh<span class="token punctuation">(</span>y<span class="token punctuation">,</span> top_probs<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>gca<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>invert_yaxis<span class="token punctuation">(</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>gca<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>set_axisbelow<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>yticks<span class="token punctuation">(</span>y<span class="token punctuation">,</span> <span class="token punctuation">[</span>cifar100<span class="token punctuation">.</span>classes<span class="token punctuation">[</span>index<span class="token punctuation">]</span> <span class="token keyword">for</span> index <span class="token keyword">in</span> top_labels<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>xlabel<span class="token punctuation">(</span><span class="token string">"probability"</span><span class="token punctuation">)</span>
plt.subplots_adjust(wspace=0.5)
plt.show()
、
Dreambooth:会使用正则化。通常只用少量图片做输入微调,就可以做一些其他扩散模型不能或者不擅长的事情——具备个性化结果的能力,既包括文本到图像模型生成的结果,也包括用户输入的任何图片;
text-inversion:通过控制文本到图像的管道,标记特定的单词,在文本提示中使用,以实现对生成图像的细粒度控制;
LoRA:大型语言模型的低阶自适应,简化过程降低硬件需求,详情请学习LoRA模型原理;
Hypernetwork:这是连接到Stable Diffusion模型上的一个小型神经网络,是噪声预测器U-Net的交叉互视(cross-attention)模块;
四个主流模型的区别:
GAN生成对抗模型、VAE变微分自动编码器、流模型、DM扩散模型
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。