Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection
For a detailed reading of the Grounding DINO paper, please see the earlier article:
Detecting arbitrary objects in images from text prompts (Grounding DINO) paper reading: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection
Combining the Grounding DINO paper with the Grounding DINO code is the best way to really understand Grounding DINO's design ideas and the overall data flow.
This article focuses on how to use the Grounding DINO code and on understanding the model from the code side: many readers finish the paper with only a partial grasp of how the model is built and how the tensor dimensions flow through it. This article therefore explains the construction of every component of the model in detail. It is quite long and includes detailed walkthroughs of the source code for the cross-attention mechanism, the multi-scale deformable attention mechanism, and more.
Because of its length (46,000+ characters!), there may be mistakes; please bear with me.
1. Clone the GroundingDINO repository from GitHub
git clone https://github.com/IDEA-Research/GroundingDINO.git
If git runs into a network timeout, you can download the repository from GitHub manually instead.
2. Change the current directory to the GroundingDINO folder
cd GroundingDINO/
3. Install the required dependencies in the current directory
pip install -e .
4. Download the pretrained groundingdino-swin-tiny model
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
Note: to run on a GPU, you need to set export CUDA_HOME=/path/to/cuda-xxx in advance; otherwise running the model fails with NameError: name '_C' is not defined.
If you forget to set it, set it and then reinstall GroundingDINO, i.e. run pip install -e . again.
1. Create a new Python file named grounding_dino_demo.py in the GroundingDINO directory with the following code:
from groundingdino.util.inference import load_model, load_image, predict, annotate, Model
import cv2
CONFIG_PATH = "groundingdino/config/GroundingDINO_SwinT_OGC.py"  # config file shipped with the source code
CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"  # the downloaded checkpoint file
DEVICE = "cpu"  # choose cpu or cuda
IMAGE_PATH = "../assets/demo4.jpg"  # path of the image to read, set by the user
TEXT_PROMPT = "Two dogs with a stick."  # the text prompt given by the user
BOX_TRESHOLD = 0.35  # bounding-box threshold defined by the source code
TEXT_TRESHOLD = 0.25  # threshold used on the text side to extract the key attributes, defined by the source code
image_source, image = load_image(IMAGE_PATH)
model = load_model(CONFIG_PATH, CHECKPOINT_PATH)
boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD,
    device=DEVICE,
)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
cv2.imwrite("../result_image/annotated_image.jpg", annotated_frame)
Note:
Because BERT is used as the text encoder, when there is no local copy of the BERT weights the program automatically requests bert-base-uncased from Hugging Face and downloads the weights. If there are network problems (e.g. errors on port 443), you can download the files manually and put them in a bert-base-uncased folder under the GroundingDINO directory (files needed: config.json / pytorch_model.bin / tokenizer.json / tokenizer_config.json / vocab.txt).
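For reference, here is a minimal sketch of the offline path (an assumption to verify against your transformers/GroundingDINO versions): once the five files are in ./bert-base-uncased, Hugging Face can load them from the local folder, and the text_encoder_type entry in GroundingDINO_SwinT_OGC.py can be pointed at the same folder so the model loads BERT without network access.
from transformers import AutoTokenizer, BertModel

local_bert = "./bert-base-uncased"                     # folder containing the five files listed above
tokenizer = AutoTokenizer.from_pretrained(local_bert)  # reads tokenizer.json / vocab.txt locally
bert = BertModel.from_pretrained(local_bert)           # reads config.json / pytorch_model.bin locally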
(Figure: my example input image.)
(Figure: Grounding DINO's output, showing the boxes, the scores, and the matched text entities.)
Grounding DINO is a dual-encoder, single-decoder architecture. It contains an image backbone for image feature extraction, a text backbone for text feature extraction, a feature enhancer for fusing image and text features, a language-guided query selection module for query initialization, and a cross-modality decoder for box refinement (five main modules in total).
Input: (image, text) pairs. For each (image, text) pair, the image features and text features are first extracted separately with an image backbone (Swin Transformer) and a text backbone (BERT).
Output: the output queries of the last decoder layer are used to predict object boxes and to extract the corresponding phrases together with their scores.
Note: in a VLP model, the neck module is the part mainly responsible for fusing image and text features, and the head module is the part mainly responsible for downstream inference and prediction based on the pretrained features.
image_source, image = load_image(IMAGE_PATH)
: this function takes an image path as input and returns the original image array together with the preprocessed image tensor.
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
The internal operations of the next function are as follows:
caption = preprocess_caption(caption=caption)
: this function takes the caption (text prompt) as input and returns the preprocessed prompt, i.e. the input caption string is lower-cased and guaranteed to end with a period.
def preprocess_caption(caption: str) -> str:
    result = caption.lower().strip()
    if result.endswith("."):
        return result
    return result + "."
The internal operations of the next function are as follows:
boxes, logits, phrases = predict(model=model,image=image,caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD,text_threshold=TEXT_TRESHOLD,device=DEVICE)
This executes the model inference.
def predict(
    model,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    caption = preprocess_caption(caption=caption)
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    ...
    ...
Note: image[None] adds a leading batch dimension to image, turning the (C, H, W) tensor into shape (1, C, H, W); image.unsqueeze(0) achieves exactly the same effect.
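A quick shape check (toy tensor standing in for the transformed image):
import torch

image = torch.randn(3, 800, 1200)   # (C, H, W) as produced by load_image's transform
print(image[None].shape)            # torch.Size([1, 3, 800, 1200])
print(image.unsqueeze(0).shape)     # torch.Size([1, 3, 800, 1200]) -- identical result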
After preprocessing, the model pipeline begins. The first module converts the text into a text embedding with a text encoder; Grounding DINO uses BERT (bert-base-uncased) as its text encoder. We take
captions = "Two dogs. with a stick."
as the running example.
The detailed flow is as follows:
Input: [captions]  Output: [tokenized]
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(samples.device)
Input: [tokenized]  Output: [text_self_attention_masks, position_ids, cate_to_token_mask_list]
(text_self_attention_masks,position_ids, cate_to_token_mask_list,)= generate_masks_with_special_tokens_and_transfer_map(tokenized, self.specical_tokens, self.tokenizer)
To remove the drawbacks of sentence-level features (which lose the fine-grained information inside a sentence) and word-level features (which introduce unnecessary dependencies between categories), GroundingDINO uses sub-sentence-level text features: attention masks are introduced to block attention between unrelated category names. This eliminates the influence between different category names while keeping the per-word features for fine-grained understanding. To do this, the caption's self-attention mask must be rebuilt and a category-to-token map must be generated.
Categories: a single input sentence may contain several entities to detect, and each entity is one category. For example, in "Two dogs. with a stick." GroundingDINO uses special tokens (. ?) as separators, so this sentence has two categories, "Two dogs" and "with a stick"; the introduced self-attention mask keeps the two categories from interfering with each other during encoding. Note: GroundingDINO recommends separating different category names with ".".
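To make this concrete, here is a toy illustration (hand-built tensors, not the library call) of the block-diagonal attention mask and restarted position ids produced for the example caption "two dogs. with a stick.", assuming it tokenizes roughly as shown in the comments:
import torch

tokens = ["[CLS]", "two", "dogs", ".", "with", "a", "stick", ".", "[SEP]"]
blocks = [[1, 2, 3], [4, 5, 6, 7]]      # token indices of the two categories (each includes its trailing ".")
n = len(tokens)
attn = torch.eye(n, dtype=torch.bool)   # [CLS]/[SEP] and every token attend to themselves
position_ids = torch.zeros(n, dtype=torch.long)
for blk in blocks:
    idx = torch.tensor(blk)
    attn[idx.unsqueeze(1), idx.unsqueeze(0)] = True   # full attention only inside one category
    position_ids[idx] = torch.arange(len(blk))        # positions restart at 0 in every category
print(attn.int())
print(position_ids)                     # -> 0, 0, 1, 2, 0, 1, 2, 3, 0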
The specific details (refer to the code below):
bs, num_token = input_ids.shape
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
    special_tokens_mask |= input_ids == special_token
# idxs: (num_special_tokens, 2) -- (row, col) positions of the special tokens
idxs = torch.nonzero(special_tokens_mask)
attention_mask = (torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1))
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
cate_to_token_mask_list = [[] for _ in range(bs)]
previous_col = 0
for i in range(idxs.shape[0]):
    row, col = idxs[i]
    if (col == 0) or (col == num_token - 1):
        attention_mask[row, col, col] = True
        position_ids[row, col] = 0
    else:
        attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
        position_ids[row, previous_col + 1 : col + 1] = torch.arange(0, col - previous_col, device=input_ids.device)
        c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
        c2t_maski[previous_col + 1 : col] = True
        cate_to_token_mask_list[row].append(c2t_maski)
    previous_col = col
cate_to_token_mask_list = [torch.stack(cate_to_token_mask_listi, dim=0) for cate_to_token_mask_listi in cate_to_token_mask_list]
return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
Input: [input_ids, text_self_attention_masks, position_ids]  Output: [bert_output]
bert_output = self.bert(**tokenized_for_encoder)
Input: [bert_output["last_hidden_state"]]  Output: [text_dict]
The 768-dimensional BERT output is mapped to 256 dimensions in order to align with the embedding dimension on the image side. The final result is stored in the text_dict dictionary.
encoded_text = self.feat_map(bert_output["last_hidden_state"])
text_dict = {
"encoded_text": encoded_text,
"text_token_mask": text_token_mask,
"position_ids": position_ids,
"text_self_attention_masks": text_self_attention_masks}
After obtaining the text embeddings, Grounding DINO uses a Swin Transformer as the image backbone to extract image embeddings from the input image.
The detailed flow is as follows:
Input: [samples]  Output: [samples]
The purpose of this step is to convert the input into a nested tensor for the subsequent processing: nested tensors make it easy to batch inputs of different sizes while preserving the original shape information.
if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples)
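Conceptually, nested_tensor_from_tensor_list pads a list of differently sized images to a common shape and records which positions are padding; a toy sketch of that idea (hypothetical sizes, not the repo class):
import torch

imgs = [torch.randn(3, 480, 640), torch.randn(3, 600, 500)]   # two images of different sizes
H = max(img.shape[1] for img in imgs)
W = max(img.shape[2] for img in imgs)
batch = torch.zeros(len(imgs), 3, H, W)
mask = torch.ones(len(imgs), H, W, dtype=torch.bool)          # True = padded position
for i, img in enumerate(imgs):
    _, h, w = img.shape
    batch[i, :, :h, :w] = img
    mask[i, :h, :w] = False                                    # real pixels are unmasked
print(batch.shape, mask.shape)   # torch.Size([2, 3, 600, 640]) torch.Size([2, 600, 640])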
Input: [samples]  Output: [features, poss]
features, poss = self.backbone(samples)
Details of self.backbone:
class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            pos.append(self[1](x).to(x.tensors.dtype))
        return out, pos
The Joiner class inherits from nn.Sequential. Its job is to combine the backbone and the position_embedding into a single module and to define the forward pass. In the forward method, self[0](tensor_list) runs the forward pass on the input tensor_list; here self[0] is the first sub-module of Joiner, i.e. the backbone (Swin Transformer) model. xs is a dict holding the feature maps produced by the backbone. Details of the Swin Transformer inside the backbone (self[0](tensor_list)):
1. Check whether W is divisible by self.patch_size[1]; if not, pad x horizontally. The same check is done for H. The padding guarantees that the height and width of the input image are divisible by patch_size (patch_size = (4, 4)).
class PatchEmbed(nn.Module):
    ...
    def forward(self, x):
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x
x = self.patch_embed(x)
2. Based on the feature-map sizes Wh and Ww, the position encoding is interpolated to match the feature-map size. If absolute position encoding is enabled (self.ape is True; it is False in this code), the feature map and the position encoding are added together, then flattened and transposed. Finally, a dropout over positions (pos_drop) yields the feature map x.
Wh, Ww = x.size(2), x.size(3)
if self.ape:
    absolute_pos_embed = F.interpolate(
        self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
    x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)
else:
    x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
3. A loop over the Swin Transformer stages (num_layers = 4) processes the feature map x and collects outputs at the stages given by out_indices. This part of the code also reshapes the collected feature maps into the agreed layout and stores the results in the outs list.
outs = []
for i in range(self.num_layers):
    layer = self.layers[i]
    x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
    if i in self.out_indices:
        norm_layer = getattr(self, f"norm{i}")
        x_out = norm_layer(x_out)
        out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
        outs.append(out)
Details of the Swin Transformer blocks:
A Swin Transformer stage is a BasicLayer. BasicLayer partitions the input features into windows and extracts features with self-attention; the attention mask it builds guarantees that each window only attends to positions inside the same window (see the code below).
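Before the code, a quick numeric example of the window padding it performs (toy numbers): a 50×50 feature map with window_size = 7 is padded up to 56×56 so that it splits into an integer number of 7×7 windows.
import numpy as np

H = W = 50
window_size = 7
Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size
print(Hp, Wp, (Hp // window_size) * (Wp // window_size))   # 56 56 64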
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None),)
w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None),)
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
    blk.H, blk.W = H, W
    if self.use_checkpoint:
        x = checkpoint.checkpoint(blk, x, attn_mask)
    else:
        x = blk(x, attn_mask)
if self.downsample is not None:
    x_down = self.downsample(x, H, W)
    Wh, Ww = (H + 1) // 2, (W + 1) // 2
    return x, H, W, x_down, Wh, Ww
else:
    return x, H, W, x, H, W
4. Finally, a loop over the outs list stores each feature map together with its interpolated mask in the outs_dict dictionary, wrapped in a NestedTensor. outs_dict is returned as the output of the whole forward pass.
outs_dict = {}
for idx, out_i in enumerate(outs):
    m = tensor_list.mask
    assert m is not None
    mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
    outs_dict[idx] = NestedTensor(out_i, mask)
Details of the position embedding (self[1](x)):
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
# import ipdb; ipdb.set_trace()
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
return pos
Input: [features, poss]  Output: [srcs, poss, masks]
After obtaining features and poss, the dimensions of features must be converted so that they can be fused with the text embeddings. As seen in 2.5.2, the channel counts of features are not unified, while the encoded_text (text embedding) from 2.4.4 has dimension 256, so the multi-scale feature maps in features have to be transformed.
The concrete operation is as follows:
srcs = []
masks = []
for l, feat in enumerate(features):
    src, mask = feat.decompose()
    srcs.append(self.input_proj[l](src))
    masks.append(mask)
    assert mask is not None
The projection consists of two modules wrapped in an nn.Sequential: a 1×1 convolution and a GroupNorm.
This operation converts the channel dimension of the feature maps and normalizes the channels, preparing the features for the subsequent feature processing and recognition tasks.
self.input_proj = nn.ModuleList([nn.Sequential(nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),nn.GroupNorm(32, hidden_dim), )] )
if self.num_feature_levels > len(srcs):
    _len_srcs = len(srcs)
    for l in range(_len_srcs, self.num_feature_levels):
        if l == _len_srcs:
            src = self.input_proj[l](features[-1].tensors)
        else:
            src = self.input_proj[l](srcs[-1])
        m = samples.mask
        mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
        pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
        srcs.append(src)
        masks.append(mask)
        poss.append(pos_l)
After extracting the vanilla image features (srcs, masks, poss) and the text features (text_dict, with dimensions already aligned), we feed them into the feature enhancer for cross-modality feature fusion. The feature enhancer contains multiple feature-enhancer layers: deformable self-attention is used to enhance the image features, vanilla self-attention is used to enhance the text features, and an image-to-text cross-attention plus a text-to-image cross-attention are added to perform the fusion.
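A simplified, runnable sketch of the bi-directional cross-attention inside one fusion layer; plain nn.MultiheadAttention stands in for the repo's BiMultiHeadAttention, and the shapes are made up for illustration:
import torch
import torch.nn as nn

d_model, n_heads = 256, 8
img2txt = nn.MultiheadAttention(d_model, n_heads, batch_first=True)  # image queries attend to text
txt2img = nn.MultiheadAttention(d_model, n_heads, batch_first=True)  # text queries attend to image

img = torch.randn(1, 1000, d_model)   # flattened multi-scale image tokens (toy length)
txt = torch.randn(1, 6, d_model)      # sub-sentence text tokens (toy length)

delta_img, _ = img2txt(query=img, key=txt, value=txt)   # text-enhanced image update
delta_txt, _ = txt2img(query=txt, key=img, value=img)   # image-enhanced text update
img, txt = img + delta_img, txt + delta_txt             # residual updates, as in the fusion layer
In the actual fusion layer these updates are additionally wrapped in LayerNorm, gamma scaling and drop-path, and six such layers are stacked together with the deformable image self-attention and the text self-attention described below.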
input_query_bbox = input_query_label = attn_mask = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict)
Input: [srcs, masks, poss]  Output: [src_flatten, mask_flatten, lvl_pos_embed_flatten, spatial_shapes, level_start_index, valid_ratios]
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
    bs, c, h, w = src.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    src = src.flatten(2).transpose(1, 2)
    mask = mask.flatten(1)  # bs, hw
    pos_embed = pos_embed.flatten(2).transpose(1, 2)
    if self.num_feature_levels > 1 and self.level_embed is not None:
        lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
    else:
        lvl_pos_embed = pos_embed
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    src_flatten.append(src)
    mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
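A small numeric example (hypothetical level sizes) of how spatial_shapes and level_start_index fit together after the multi-scale features are flattened and concatenated:
import torch

spatial_shapes = torch.tensor([[100, 167], [50, 84], [25, 42], [13, 21]])   # four toy feature levels
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
print(spatial_shapes.prod(1))    # 16700, 4200, 1050, 273 tokens per level
print(level_start_index)         # 0, 16700, 20900, 21950 -> where each level starts in the flat sequence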
Input: [spatial_shapes, valid_ratios]  Output: [reference_points]
Computing the reference points: the function get_reference_points takes three arguments, spatial_shapes, valid_ratios and device. It first creates an empty list reference_points_list to store the reference points of each level. It then loops over spatial_shapes, where each element is the spatial shape of one level, and in each iteration it builds a grid of normalized cell-center coordinates for that level.
After the loop, torch.cat concatenates all entries of reference_points_list, and multiplying by valid_ratios yields a tensor reference_points of shape (batch, total number of tokens, num_levels, 2), where num_levels is the number of levels; the multiplication rescales the reference points to the valid (non-padded) region of each level.
memory, memory_text = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
position_ids=text_dict["position_ids"],
text_self_attention_masks=text_dict["text_self_attention_masks"],
)
def get_reference_points(spatial_shapes, valid_ratios, device):
    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
        )
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points

if self.num_layers > 0:
    reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
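For intuition, a toy computation mirroring the loop above: for a single 2×3 level with valid_ratios equal to 1, the reference points are simply the normalized cell centers.
import torch

H_, W_ = 2, 3
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H_ - 0.5, H_),
    torch.linspace(0.5, W_ - 0.5, W_),
)
ref = torch.stack(((ref_x / W_).reshape(-1), (ref_y / H_).reshape(-1)), -1)
print(ref)   # (x, y) centers: (0.1667, 0.25), (0.5, 0.25), (0.8333, 0.25), (0.1667, 0.75), (0.5, 0.75), (0.8333, 0.75)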
Input: [position_ids]  Output: [pos_text]
The function get_sine_pos_embed generates sine position embeddings. It takes four arguments: pos_tensor, num_pos_feats, temperature and exchange_xy.
if position_ids is not None:
    pos_text = get_sine_pos_embed(position_ids[..., None], num_pos_feats=256, exchange_xy=False)

def get_sine_pos_embed(
    pos_tensor: torch.Tensor, num_pos_feats: int = 128, temperature: int = 10000, exchange_xy: bool = True,
):
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    def sine_func(x: torch.Tensor):
        sin_x = x * scale / dim_t
        sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
        return sin_x

    pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
    if exchange_xy:
        pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
    pos_res = torch.cat(pos_res, dim=-1)
    return pos_res
Input: [output, memory_text, key_padding_mask, text_attention_mask]  Output: [output, memory_text]
GroundingDINO uses a 6-layer cross-attention (fusion) module; the execution is as follows:
#Bi-Direction MHA (text->image, image->text)
output, memory_text = self.fusion_layers[layer_id](
v=output,
l=memory_text,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
    v = self.layer_norm_v(v)
    l = self.layer_norm_l(l)
    delta_v, delta_l = self.attn(
        v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
    )
    # v, l = v + delta_v, l + delta_l
    v = v + self.drop_path(self.gamma_v * delta_v)
    l = l + self.drop_path(self.gamma_l * delta_l)
    return v, l
The most important part is the cross-attention itself (self.attn); its internal flow is as follows:
bsz, tgt_len, _ = v.size()
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}")
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(attn_weights, min=-50000)
if self.clamp_max_for_overflow:
attn_weights = torch.clamp( attn_weights, max=50000)
attn_weights_T = attn_weights.transpose(1, 2)
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(attn_weights_l, min=-50000)
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp( attn_weights_l, max=50000)
# mask vison for language
if attention_mask_v is not None:
attention_mask_v = (attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1))
attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
attn_weights_l = attn_weights_l.softmax(dim=-1)
# mask language for vision
if attention_mask_l is not None:
attention_mask_l = (attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1))
attn_weights.masked_fill_(attention_mask_l, float("-inf"))
attn_weights_v = attn_weights.softmax(dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}")
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
raise ValueError(f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}")
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l
Input: [memory_text, text_self_attention_masks, text_attention_mask, pos_text]  Output: [memory_text]
After the cross-attention, self-attention is applied on the text side to further enhance the text features:
if self.text_layers:  # run self-attention over the text
memory_text = self.text_layers[layer_id](
src=memory_text.transpose(0, 1),
src_mask=~text_self_attention_masks, # note we use ~ for mask here
src_key_padding_mask=text_attention_mask,
pos=(pos_text.transpose(0, 1) if pos_text is not None else None)).transpose(0, 1)
The concrete operations are as follows:
if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
    src_mask = src_mask.repeat(self.nhead, 1, 1)
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
Input: [output, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask]  Output: [output]
The deformable self-attention that enhances the image features is executed as follows:
output = layer(
src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,)
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
    src2 = self.self_attn(
        query=self.with_pos_embed(src, pos),
        reference_points=reference_points,
        value=src,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        key_padding_mask=key_padding_mask,
    )
    src = src + self.dropout1(src2)
    src = self.norm1(src)
    src = self.forward_ffn(src)
    return src
The concrete operations inside the deformable self-attention:
if value is None:
value = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs,num_query,self.num_heads,self.num_levels,self.num_points,)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (reference_points[:, :, None, :, None, :]+ sampling_offsets / offset_normalizer[None, None, None, :, None, :])
elif reference_points.shape[-1] == 4:
sampling_locations = (reference_points[:, :, None, :, None, :2]+ sampling_offsets/ self.num_points* reference_points[:, :, None, :, None, 2:]* 0.5)
else:
raise ValueError("Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]))
if torch.cuda.is_available() and value.is_cuda:
halffloat = False
if value.dtype == torch.float16:
halffloat = True
value = value.float()
sampling_locations = sampling_locations.float()
attention_weights = attention_weights.float()
output = MultiScaleDeformableAttnFunction.apply(value,spatial_shapes,level_start_index,sampling_locations,attention_weights,self.im2col_step,)
if halffloat:
output = output.half()
else:
output = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
return output
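To see what the sampling step actually does, here is a minimal single-level, single-head illustration in plain PyTorch (toy tensors, not the repo's optimized MultiScaleDeformableAttnFunction kernel): offsets are predicted around a reference point, features are bilinearly sampled at those locations, and the samples are combined with predicted attention weights.
import torch
import torch.nn.functional as F

B, H, W, C = 1, 8, 8, 256
num_points = 4
value = torch.randn(B, C, H, W)                              # one feature level
ref = torch.rand(B, 1, 2)                                    # one query's reference point in [0, 1], as (x, y)
offsets = torch.randn(B, 1, num_points, 2) * 0.05            # predicted offsets around the reference
weights = torch.softmax(torch.randn(B, 1, num_points), -1)   # predicted attention weights over the points

loc = (ref[:, :, None, :] + offsets).clamp(0, 1)             # sampling locations in [0, 1]
grid = 2 * loc - 1                                           # grid_sample expects coordinates in [-1, 1]
sampled = F.grid_sample(value, grid, align_corners=False)    # (B, C, 1, num_points) bilinear samples
out = (sampled * weights[:, None, :, :]).sum(-1)             # weighted sum over the sampled points -> (B, C, 1)
The repo version does the same thing for all queries, heads, levels and points at once, and dispatches to the CUDA kernel when available, as shown above.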
Input: [memory, mask_flatten, spatial_shapes, text_dict]  Output: [refpoint_embed, tgt, init_box_proposal]
The language-guided query selection module selects the features most relevant to the input text as decoder queries. num_query, the number of decoder queries, is set to 900 in the implementation, and mixed query selection is used to initialize the decoder queries. Each decoder query contains two parts: a content part and a positional part.
(This part will be expanded later.)
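A runnable toy sketch of the selection idea (hypothetical tensor sizes; in the repo the scores come from enc_out_class_embed, here a plain dot product stands in for it):
import torch

num_queries = 900
image_tokens = torch.randn(1, 10000, 256)              # flattened multi-scale image features (memory)
text_tokens = torch.randn(1, 6, 256)                   # enhanced text features
sim = image_tokens @ text_tokens.transpose(1, 2)       # (1, 10000, 6) image-text similarity logits
scores = sim.max(-1).values                            # best-matching text token per image position
topk_idx = scores.topk(num_queries, dim=1).indices     # positions of the 900 selected queries
content = torch.gather(image_tokens, 1, topk_idx.unsqueeze(-1).expand(-1, -1, 256))  # (1, 900, 256)
In the actual code below, the selected coordinates initialize the positional queries (refpoint_embed), while with embed_init_tgt the content queries are learnable embeddings instead of the gathered features (mixed query selection).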
if self.two_stage_type == "standard":
output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
output_memory = self.enc_output_norm(self.enc_output(output_memory))
if text_dict is not None:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
else:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
topk_logits = enc_outputs_class_unselected.max(-1)[0]
enc_outputs_coord_unselected = (self.enc_out_bbox_embed(output_memory) + output_proposals)
topk = self.num_queries
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
# gather boxes
refpoint_embed_undetach = torch.gather( enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather(output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()
# gather tgt
tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
if self.embed_init_tgt:
tgt_ = (self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1))
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == "no":
tgt_ = (self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1))
refpoint_embed_ = (self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1))
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(self.num_queries, 1)
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
Input: […]  Output: [hs, references]
The cross-modality decoder combines image and text features: in every cross-modality decoder layer, each cross-modal query goes through a self-attention layer, an image cross-attention layer for combining image features, a text cross-attention layer for combining text features, and an FFN. The detailed flow:
def forward(
    self,
    tgt,
    memory,
    tgt_mask: Optional[Tensor] = None,
    memory_mask: Optional[Tensor] = None,
    tgt_key_padding_mask: Optional[Tensor] = None,
    memory_key_padding_mask: Optional[Tensor] = None,
    pos: Optional[Tensor] = None,
    refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
    level_start_index: Optional[Tensor] = None,  # num_levels
    spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
    valid_ratios: Optional[Tensor] = None,
    memory_text: Optional[Tensor] = None,
    text_attention_mask: Optional[Tensor] = None,
):
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (reference_points[:, :, None]* torch.cat([valid_ratios, valid_ratios], -1)[None, :])
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :])
# conditional query
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
query_pos = pos_scale * raw_query_pos
# main process
output = layer(tgt=output,tgt_query_pos=query_pos,tgt_query_sine_embed=query_sine_embed,tgt_key_padding_mask=tgt_key_padding_mask,tgt_reference_points=reference_points_input,memory_text=memory_text,text_attention_mask=text_attention_mask,memory=memory,memory_key_padding_mask=memory_key_padding_mask,
memory_level_start_index=level_start_index,memory_spatial_shapes=spatial_shapes,memory_pos=pos,self_attn_mask=tgt_mask,cross_attn_mask=memory_mask,)
if output.isnan().any() | output.isinf().any():
print(f"output layer_id {layer_id} is nan")
try:
num_nan = output.isnan().sum().item()
num_inf = output.isinf().sum().item()
print(f"num_nan {num_nan}, num_inf {num_inf}")
except Exception as e:
print(e)
# iter update
if self.bbox_embed is not None:
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](output)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_points = new_reference_points.detach()
ref_points.append(new_reference_points)
intermediate.append(self.norm(output))
return [[itm_out.transpose(0, 1) for itm_out in intermediate],[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],]
The concrete operations inside each decoder layer:
assert cross_attn_mask is None
# self attention
if self.self_attn is not None:
# import ipdb; ipdb.set_trace()
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
if self.use_text_cross_attention:
tgt2 = self.ca_text(self.with_pos_embed(tgt, tgt_query_pos),memory_text.transpose(0, 1),memory_text.transpose(0, 1),
key_padding_mask=text_attention_mask,)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
tgt2 = self.cross_attn(query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
reference_points=tgt_reference_points.transpose(0, 1).contiguous(),value=memory.transpose(0, 1),
spatial_shapes=memory_spatial_shapes,level_start_index=memory_level_start_index,
key_padding_mask=memory_key_padding_mask,).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt = self.forward_ffn(tgt)
return tgt
Finally, the anchor update of the detection model is carried out: the box coordinates and the class information are predicted from each decoder layer and the results are stored in the output dictionary. The concrete operations are as follows:
# deformable-detr-like anchor update
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
    layer_delta_unsig = layer_bbox_embed(layer_hs)
    layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
    layer_outputs_unsig = layer_outputs_unsig.sigmoid()
    outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# output
outputs_class = torch.stack(
    [layer_cls_embed(layer_hs, text_dict) for layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
return out
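A tiny worked example of the sigmoid-space box refinement performed above (toy numbers, with a local copy of inverse_sigmoid for illustration):
import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

ref = torch.tensor([[0.40, 0.50, 0.20, 0.30]])      # previous reference box (cx, cy, w, h), all in [0, 1]
delta = torch.tensor([[0.10, -0.05, 0.00, 0.02]])   # offset predicted by bbox_embed in logit space
new_ref = (inverse_sigmoid(ref) + delta).sigmoid()  # refined box, still in [0, 1]
print(new_ref)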
Input: […]  Output: [boxes, logits, phrases]
The predictions of the detection model are post-processed and returned, as shown in the code below:
1. Process the model outputs.
2. Tokenize the caption with the model's tokenizer.
3. Generate the phrase information from the predictions and the tokenized caption.
4. Return the results.
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256)
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
phrases = [get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') for logit in logits]
return boxes, logits.max(dim=1)[0], phrases
Through the pipeline above, the model finally returns the bounding boxes boxes, the maximum confidence of each row of logits, and the phrases generated according to the thresholds. The whole process relies on cross-attention, self-attention, multi-head attention and multi-scale deformable attention, and many small details combine to complete the entire text-to-image object detection.
Some small details will be filled in later!