
[YOLOv8 Improvement: Detection Head] Swapping in an RT-DETR Head to Make YOLOv8 Better


I. RT-DETR

Official paper: https://arxiv.org/pdf/2304.08069.pdf

Thanks to its reasonable trade-off between speed and accuracy, the YOLO series has become the most popular framework for real-time object detection. However, non-maximum suppression (NMS) has been observed to hurt both the speed and the accuracy of YOLO models. Recently, end-to-end Transformer-based detectors (DETRs) have offered an alternative by eliminating post-processing steps such as NMS, which have long been a bottleneck of traditional real-time detectors. Their high computational cost, however, limits their practicality and keeps them from fully exploiting the advantage of dropping NMS.

The paper proposes the Real-Time DEtection TRansformer (RT-DETR), to the authors' knowledge the first real-time end-to-end object detector to resolve this dilemma. RT-DETR is built on top of an advanced DETR in two steps: first improve speed while maintaining accuracy, then improve accuracy while maintaining speed. Concretely, an efficient hybrid encoder is designed that processes multi-scale features quickly by decoupling intra-scale interaction from cross-scale fusion, which improves speed. Uncertainty-minimal query selection is then proposed to feed the decoder high-quality initial queries, which improves accuracy. In addition, RT-DETR supports flexible speed tuning: the number of decoder layers can be adjusted to fit different scenarios without retraining.

RT-DETR-R50/R101 achieves 53.1% / 54.3% AP on COCO at 108 / 74 FPS on a T4 GPU, outperforming previous state-of-the-art YOLOs in both speed and accuracy. RT-DETR-R50 also beats DINO-R50 by 2.2% AP while running about 21× faster. Pretrained on Objects365, RT-DETR-R50/R101 reaches 55.3% / 56.2% AP.

Official code: DETRs Beat YOLOs on Real-time Object Detection

In summary, the RT-DETR model rests on two key innovations:

Efficient hybrid encoder: processes multi-scale features by decoupling intra-scale interaction from cross-scale fusion. This design significantly reduces the computational burden while maintaining high performance, making real-time object detection feasible.

Uncertainty-minimal query selection: provides the decoder with high-quality initial object queries, which improves accuracy.
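In the ultralytics implementation shown in Section II below, query selection at inference time reduces to a plain top-k over per-class confidence: the nq encoder tokens whose best class score is highest become the initial object queries. A minimal standalone sketch of that step (all tensor sizes are illustrative):

import torch

bs, hw, nc, nq = 2, 8400, 80, 300                # batch, encoder tokens, classes, queries (illustrative)
enc_scores = torch.randn(bs, hw, nc)             # per-token class logits from the encoder score head
topk_ind = enc_scores.max(-1).values.topk(nq, dim=1).indices  # (bs, nq): tokens kept as initial queries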

1 Encoder Structure

The figure below shows the encoder structure of each variant. SSE denotes the single-scale Transformer encoder, MSE the multi-scale Transformer encoder, and CSF cross-scale fusion. AIFI and CCFF are the two modules of the hybrid encoder designed in the paper.

2 RT-DETR

The figure below gives an overview of RT-DETR. The features from the last three backbone stages are fed into the encoder. The efficient hybrid encoder turns the multi-scale features into a sequence of image features via attention-based intra-scale feature interaction (AIFI) and CNN-based cross-scale feature fusion (CCFF). Uncertainty-minimal query selection then picks a fixed number of encoder features to serve as the decoder's initial object queries. Finally, the decoder, equipped with auxiliary prediction heads, iteratively refines the object queries to produce classes and boxes.
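AIFI is the part that keeps the hybrid encoder cheap: single-scale self-attention is applied only to the last (P5, stride-32) feature map. The module ships with ultralytics; a minimal sketch, with illustrative channel and spatial sizes:

import torch
from ultralytics.nn.modules import AIFI  # available in ultralytics 8.1.x

# Self-attention over the flattened P5 map with 2D sinusoidal position embeddings;
# the output has the same shape as the input.
p5 = torch.randn(1, 256, 20, 20)
print(AIFI(256, 1024, 8)(p5).shape)  # torch.Size([1, 256, 20, 20])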

3 The Fusion Block in CCFF

The figure below shows the fusion block used in CCFF.
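A rough approximation of what the figure depicts, as an illustrative sketch rather than the official implementation (the class name, block count, and the plain conv blocks standing in for RepVGG-style blocks are all assumptions): two same-resolution features are concatenated, projected, refined by a small conv stack, and merged with an element-wise shortcut.

import torch
import torch.nn as nn

class FusionBlock(nn.Module):
    """Illustrative CCFF-style fusion: concat -> 1x1 projections -> conv stack + shortcut add."""

    def __init__(self, c, n=3):
        super().__init__()
        self.proj1 = nn.Conv2d(2 * c, c, 1)  # projection feeding the refinement path
        self.proj2 = nn.Conv2d(2 * c, c, 1)  # projection feeding the shortcut path
        self.blocks = nn.Sequential(*(nn.Sequential(nn.Conv2d(c, c, 3, padding=1), nn.SiLU()) for _ in range(n)))

    def forward(self, a, b):
        x = torch.cat([a, b], dim=1)                       # fuse the two inputs channel-wise
        return self.blocks(self.proj1(x)) + self.proj2(x)  # refined path + element-wise shortcut

f = FusionBlock(256)
print(f(torch.randn(1, 256, 40, 40), torch.randn(1, 256, 40, 40)).shape)  # torch.Size([1, 256, 40, 40])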

Experimental Results

II. The Code of the RT-DETR Detection Head

The code of the RT-DETR detection head (the RTDETRDecoder class as integrated in ultralytics) is shown below:

# RTDETRDecoder as integrated in ultralytics (v8.1.x, ultralytics/nn/modules/head.py).
# Imports added here so the listing is self-contained.
import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_

from ultralytics.nn.modules.transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from ultralytics.nn.modules.utils import bias_init_with_prob, linear_init
from ultralytics.utils.tal import TORCH_1_10


class RTDETRDecoder(nn.Module):
    """
    Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

    This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
    and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
    Transformer decoder layers to output the final predictions.
    """

    export = False  # export mode

    def __init__(
        self,
        nc=80,
        ch=(512, 1024, 2048),
        hd=256,  # hidden dim
        nq=300,  # num queries
        ndp=4,  # num decoder points
        nh=8,  # num head
        ndl=6,  # num decoder layers
        d_ffn=1024,  # dim of feedforward
        dropout=0.0,
        act=nn.ReLU(),
        eval_idx=-1,
        # Training args
        nd=100,  # num denoising
        label_noise_ratio=0.5,
        box_noise_scale=1.0,
        learnt_init_query=False,
    ):
        """
        Initializes the RTDETRDecoder module with the given parameters.

        Args:
            nc (int): Number of classes. Default is 80.
            ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
            hd (int): Dimension of hidden layers. Default is 256.
            nq (int): Number of query points. Default is 300.
            ndp (int): Number of decoder points. Default is 4.
            nh (int): Number of heads in multi-head attention. Default is 8.
            ndl (int): Number of decoder layers. Default is 6.
            d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
            dropout (float): Dropout rate. Default is 0.
            act (nn.Module): Activation function. Default is nn.ReLU.
            eval_idx (int): Evaluation index. Default is -1.
            nd (int): Number of denoising. Default is 100.
            label_noise_ratio (float): Label noise ratio. Default is 0.5.
            box_noise_scale (float): Box noise scale. Default is 1.0.
            learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
        """
        super().__init__()
        self.hidden_dim = hd
        self.nhead = nh
        self.nl = len(ch)  # num level
        self.nc = nc
        self.num_queries = nq
        self.num_decoder_layers = ndl

        # Backbone feature projection
        self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
        # NOTE: simplified version but it's not consistent with .pt weights.
        # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

        # Transformer module
        decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
        self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

        # Denoising part
        self.denoising_class_embed = nn.Embedding(nc, hd)
        self.num_denoising = nd
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale

        # Decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(nq, hd)
        self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

        # Encoder head
        self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
        self.enc_score_head = nn.Linear(hd, nc)
        self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

        # Decoder head
        self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
        self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

        self._reset_parameters()

    def forward(self, x, batch=None):
        """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
            batch,
            self.nc,
            self.num_queries,
            self.denoising_class_embed.weight,
            self.num_denoising,
            self.label_noise_ratio,
            self.box_noise_scale,
            self.training,
        )

        embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(
            embed,
            refer_bbox,
            feats,
            shapes,
            self.dec_bbox_head,
            self.dec_score_head,
            self.query_pos_head,
            attn_mask=attn_mask,
        )
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)

    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
        """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
        valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float("inf"))
        return anchors, valid_mask

    def _get_encoder_input(self, x):
        """Processes and returns encoder inputs by getting projection features from input and concatenating them."""
        # Get projection features
        x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
        # Get encoder inputs
        feats = []
        shapes = []
        for feat in x:
            h, w = feat.shape[2:]
            # [b, c, h, w] -> [b, h*w, c]
            feats.append(feat.flatten(2).permute(0, 2, 1))
            # [nl, 2]
            shapes.append([h, w])

        # [b, h*w, c]
        feats = torch.cat(feats, 1)
        return feats, shapes

    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
        """Generates and prepares the input required for the decoder from the provided features and shapes."""
        bs = feats.shape[0]
        # Prepare input for decoder
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256

        enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

        # Query selection
        # (bs, num_queries)
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # Dynamic anchors + static content
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        if dn_bbox is not None:
            refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()
        if dn_embed is not None:
            embeddings = torch.cat([dn_embed, embeddings], 1)

        return embeddings, refer_bbox, enc_bboxes, enc_scores

    # TODO
    def _reset_parameters(self):
        """Initializes or resets the parameters of the model's various components with predefined weights and biases."""
        # Class and bbox head init
        bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
        # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
        # linear_init(self.enc_score_head)
        constant_(self.enc_score_head.bias, bias_cls)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.0)

        for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
            # linear_init(cls_)
            constant_(cls_.bias, bias_cls)
            constant_(reg_.layers[-1].weight, 0.0)
            constant_(reg_.layers[-1].bias, 0.0)

        linear_init(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)
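To sanity-check the head in isolation, it can be instantiated directly from ultralytics and run on dummy feature maps. A minimal sketch (the channel counts and spatial sizes are illustrative, matching strides 8/16/32 for a 640×640 input):

import torch
from ultralytics.nn.modules import RTDETRDecoder  # the class listed above

head = RTDETRDecoder(nc=80, ch=(128, 256, 512)).eval()
feats = [torch.randn(2, c, s, s) for c, s in zip((128, 256, 512), (80, 40, 20))]
with torch.no_grad():
    y, _ = head(feats)  # in eval mode the head returns (decoded preds, raw outputs)
print(y.shape)          # torch.Size([2, 300, 84]): 300 queries x (4 box coords + 80 class scores)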

III. Swapping the RT-DETR Head into YOLOv8

The ultralytics version used here is 8.1.47.

1 Overall Modification

The RT-DETR detection head is already integrated into the ultralytics (YOLOv8) project, so it can be used directly without any code changes.

Note: after switching to the RT-DETR detection head, you need to increase the number of training epochs, since DETR-style heads converge more slowly than the default Detect head. A minimal training sketch follows.
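In this sketch the data yaml and epoch count are illustrative placeholders; loading the config through the RTDETR model class is the safer route, since the RTDETRDecoder head needs the DETR-style loss and trainer rather than the standard YOLO detection loss:

from ultralytics import RTDETR

# Build from the custom config (defined in the next subsection) and train.
# epochs=300 is an illustrative, larger-than-default budget.
model = RTDETR("yolov8_RT-DETR.yaml")
model.train(data="coco128.yaml", epochs=300, imgsz=640)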

2 Configuration File

The content of yolov8_RT-DETR.yaml is shown below:

# Ultralytics YOLO
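# The remainder of the file is a minimal sketch, assuming the stock yolov8.yaml
# layout from ultralytics 8.1.x with the final Detect layer swapped for
# RTDETRDecoder; verify the layer indices against your local copy before training.

# Parameters
nc: 80  # number of classes
scales:  # model compound scaling constants
  n: [0.33, 0.25, 1024]
  s: [0.33, 0.50, 1024]
  m: [0.67, 0.75, 768]
  l: [1.00, 1.00, 512]
  x: [1.00, 1.25, 512]

# YOLOv8 backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]   # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]    # 9

# YOLOv8 head with the Detect layer replaced by RTDETRDecoder
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]]   # cat backbone P4
  - [-1, 3, C2f, [512]]         # 12

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]]   # cat backbone P3
  - [-1, 3, C2f, [256]]         # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]         # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]   # cat head P5
  - [-1, 3, C2f, [1024]]        # 21 (P5/32-large)

  - [[15, 18, 21], 1, RTDETRDecoder, [nc]]  # RT-DETR decoder head on P3, P4, P5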