设 $\mathcal{I}$ 是由文本引导扩散模型 [38] 使用文本提示 $\mathcal{P}$ 和随机种子 $s$ 生成的图像。我们的目标是仅在编辑后的提示 $\mathcal{P}^*$ 的引导下编辑输入图像,得到编辑后的图像 $\mathcal{I}^*$。例如,考虑由提示“我的新自行车”生成的图像,并假设用户想要编辑自行车的颜色、材质,甚至用踏板车替换它,同时保留原始图像的外观和结构。对用户而言,一种直观的交互方式是直接修改文本提示:进一步描述自行车的外观,或将其替换为另一个词。与以往的工作不同,我们希望避免依赖任何用户定义的掩码来辅助或指明编辑应发生的位置。一种简单但不成功的尝试是固定内部随机性,并使用编辑后的文本提示重新生成。不幸的是,如图2所示,这会生成结构和构图完全不同的图像。我们的关键观察是,生成图像的结构和外观不仅取决于随机种子,还取决于扩散过程中像素与文本嵌入之间的相互作用。通过修改交叉注意力层中发生的像素与文本之间的交互,我们实现了 Prompt-to-Prompt 图像编辑能力。更具体地说,注入输入图像 $\mathcal{I}$ 的交叉注意力图使我们能够保留原始的构图和结构。在第3.1节中,我们回顾交叉注意力的使用方式;在第3.2节中,我们描述如何利用交叉注意力进行编辑。有关扩散模型的更多背景,请参阅附录A。
from typing import Union, Tuple, List, Callable, Dict, Optional
import torch
import torch.nn.functional as nnf
from diffusers import DiffusionPipeline
import numpy as np
from IPython.display import display
from PIL import Image
import abc
import ptp_utils
import seq_aligner
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model_id = "CompVis/ldm-text2im-large-256"
# Number of text-token slots per cross-attention map. LocalBlend sizes its
# per-token weights with this, so it must equal the padded prompt length the
# pipeline feeds to the text encoder (assumed to be 77 here -- verify if the
# model or tokenizer changes).
MAX_NUM_WORDS = 77
# load model and scheduler
ldm = DiffusionPipeline.from_pretrained(model_id).to(device)
tokenizer = ldm.tokenizer
{'cross_attention_dim'} was not found in config. Values will be initialized to default values.
{'set_alpha_to_one'} was not found in config. Values will be initialized to default values.
Prompt-to-Prompt Attention Controllers
Our main logic is implemented in the forward call of an AttentionControl object. The forward is called in each attention layer of the diffusion model, and it can modify the input attention weights attn.
The arguments is_cross and place_in_unet (one of "down", "mid", "up"), together with AttentionControl.cur_step, let us track the exact attention layer and timestep during diffusion inference.
class LocalBlend:
    """Restrict an edit to the image region attended by selected prompt words.

    The stored cross-attention maps of the chosen words are pooled into a
    spatial mask. Outside that mask every latent in the batch is reset to the
    first latent (``x_t[:1]``), so only the attended region is allowed to differ.
    """

    def __call__(self, x_t, attention_store, step):
        """Blend the latents ``x_t`` towards ``x_t[:1]`` outside the word mask.

        Args:
            x_t: batch of latents; the mask is interpolated to ``x_t.shape[2:]``
                and broadcast against it, so it is expected to be 4-D.
            attention_store: dict of lists of stored attention maps; the
                ``"down_cross"`` and ``"up_cross"`` entries are read.
            step: current diffusion step (unused; kept for the callback interface).

        Returns:
            The blended latents, same shape as ``x_t``.
        """
        k = 1  # dilation radius, in attention-map pixels
        num_prompts = self.alpha_layers.shape[0]
        num_words = self.alpha_layers.shape[-1]  # same word axis used to build alpha_layers
        # These layers are reshaped as 16x16 maps -- presumably the 16x16-resolution
        # cross-attention layers of this UNet (verify if the model changes).
        maps = attention_store["down_cross"][:2] + attention_store["up_cross"][3:6]
        maps = [item.reshape(num_prompts, -1, 1, 16, 16, num_words) for item in maps]
        maps = torch.cat(maps, dim=1)
        # Keep only the selected words' maps, then average over the concatenated axis 1.
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        # Dilate, then upsample the *dilated* mask (not the raw maps) to latent resolution.
        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = nnf.interpolate(mask, size=x_t.shape[2:])
        # Normalise each map to a peak of 1 before thresholding.
        mask = mask / mask.amax(dim=(2, 3), keepdim=True)
        mask = mask.gt(self.threshold)
        # Union every prompt's mask with the first prompt's mask.
        mask = (mask[:1] | mask).float()
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t

    def __init__(self, prompts: List[str], words: List[Union[str, List[str]]], threshold: float = .3):
        """Build the per-prompt word-selection weights.

        Args:
            prompts: the prompts in the batch, in order.
            words: for each prompt, a single word or a list of words whose
                attention defines the blend region.
            threshold: cut-off applied to the peak-normalised mask.
        """
        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                # Presumably the token indices of ``word`` within ``prompt``
                # (see ptp_utils.get_word_inds).
                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
                alpha_layers[i, :, :, :, :, ind] = 1
        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold
class AttentionControl(abc.ABC):
    """Base hook invoked by every attention layer of the diffusion model.

    Subclasses implement :meth:`forward` to inspect or rewrite attention
    weights. Each call advances ``cur_att_layer``; after ``num_att_layers``
    calls the layer counter wraps, ``cur_step`` is incremented and
    :meth:`between_steps` runs.
    """

    def step_callback(self, x_t):
        """Hook on the latents ``x_t``; the base implementation returns them unchanged."""
        return x_t

    def between_steps(self):
        """Hook run each time ``cur_step`` advances; a no-op in the base class."""
        return

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Return the (possibly modified) attention weights ``attn``.

        ``is_cross`` tells cross- from self-attention; ``place_in_unet`` is one
        of ``"down"``, ``"mid"`` or ``"up"``.
        """
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Run :meth:`forward` on the second half of ``attn`` and advance the counters."""
        h = attn.shape[0]
        # Only the second half of the leading axis goes through ``forward`` --
        # presumably the text-conditioned half under classifier-free guidance;
        # confirm against the pipeline's batching.
        attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            # Every attention layer has fired once: advance to the next step
            # and give subclasses their per-step hook.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Rewind the step and layer counters to zero."""
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        # Must be set by whoever registers the controller; while it stays -1
        # the layer counter never matches it and ``cur_step`` never advances.
        self.num_att_layers = -1
        self.cur_att_layer = 0
class EmptyControl(AttentionControl):
    """No-op controller: every attention map passes through unmodified."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Identity hook; ``is_cross`` and ``place_in_unet`` are ignored."""
        passthrough = attn
        return passthrough
class AttentionStore(AttentionControl):
    """Controller that records attention maps and sums them over diffusion steps.

    ``forward`` appends each small attention map to ``step_store`` under a
    ``"<place>_<cross|self>"`` key. ``between_steps`` adds the finished step's
    maps into the running sums in ``attention_store``. ``get_average_attention``
    divides those sums by ``cur_step``.
    """

    @staticmethod
    def get_empty_store():
        """Return a fresh store: one empty list per (UNet location, attention type) key."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step and return it unchanged."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # avoid memory overhead: only keep maps whose axis 1 has at most 16 ** 2
        # entries (presumably 16x16-pixel maps or coarser -- confirm for the UNet).
        if attn.shape[1] <= 16 ** 2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Fold the maps recorded during the step just finished into the running sums."""
        if not self.attention_store:
            # First completed step: its maps become the initial sums.
            self.attention_store = self.step_store
        else:
            for key, running in self.attention_store.items():
                step_maps = self.step_store[key]
                for i in range(len(running)):
                    running[i] += step_maps[i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps, per key, divided by ``cur_step``.

        Intended to be called after at least one step has completed; with
        ``cur_step == 0`` the division is by zero.
        """
        return {key: [item / self.cur_step for item in maps]
                for key, maps in self.attention_store.items()}

    def reset(self):
        """Reset the counters and discard all recorded attention."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()  # maps recorded during the current step
        self.attention_store = {}  # running sums over completed steps
class AttentionControlEdit(AttentionStore, abc.ABC):
    """Base controller that injects the first prompt's attention into the others.

    Attention for the whole batch is recorded via :class:`AttentionStore`.
    Prompt 0 is the base; the remaining prompts are edited:
      * on cross-attention layers, at every step, the edited maps are mixed
        with :meth:`replace_cross_attention` output using ``cross_replace_alpha``;
      * on self-attention layers, only while ``cur_step`` lies in the
        ``[start, end)`` window ``num_self_replace``, small maps are replaced
        by the base prompt's maps.
    """

    def step_callback(self, x_t):
        """Apply the optional local blend to the latents after each step."""
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store, self.cur_step)
        return x_t

    def replace_self_attention(self, attn_base, att_replace):
        """Broadcast the base self-attention over the edited prompts for small maps."""
        if att_replace.shape[2] <= 16 ** 2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        # Larger maps keep the edited prompts' own attention.
        return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        """Return the edited prompts' cross-attention derived from ``attn_base``."""
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn``, then inject the base prompt's attention into the edited prompts."""
        super().forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // self.batch_size
            # Split the leading (batch_size * h) axis so prompt 0 can be addressed directly.
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                # Per-step mixing weights (presumably per word, from
                # ptp_utils.get_time_words_attention_alpha -- confirm shape).
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn[1:] = (self.replace_cross_attention(attn_base, attn_replace) * alpha_words
                            + (1 - alpha_words) * attn_replace)
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend]):
        """Configure the edit.

        Args:
            prompts: the batch of prompts; prompt 0 is the base.
            num_steps: number of diffusion steps.
            cross_replace_steps: schedule passed to
                ``ptp_utils.get_time_words_attention_alpha``.
            self_replace_steps: a number ``x`` (meaning the window ``(0, x)``)
                or a ``(start, end)`` pair, both as fractions of ``num_steps``.
            local_blend: optional callable applied in :meth:`step_callback`.
        """
        super().__init__()
        self.batch_size = len(prompts)
        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
        if isinstance(self_replace_steps, (int, float)):
            self_replace_steps = 0, self_replace_steps
        # Convert the fractional window into [start, end) step indices.
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend
我们使用 Imagen 文本引导的图像合成模型作为主干。由于构图和几何形状主要在 $64 \times 64$ 分辨率下确定,我们只对文本到图像的扩散模型进行修改,而超分辨率过程保持不变。回想一下,每个扩散步骤 $t$ 都使用 U 形网络,根据带噪图像 $z_t$ 和文本嵌入 $\psi(\mathcal{P})$ 预测噪声 $\epsilon$。在最后一步,该过程得到生成图像 $\mathcal{I} = z_0$。最重要的是,两种模态之间的相互作用发生在噪声预测过程中:视觉特征与文本特征的嵌入通过交叉注意力层进行融合,这些层为每个文本 token 生成一张空间注意力图。
