【YOLOv8改进[Head检测头]】YOLOv8换个RT-DETR head助力模型更优秀

由于在速度和精度之间取得了合理的权衡,YOLO系列已成为最流行的实时目标检测框架。然而,研究者观察到NMS会对YOLO的速度和精度产生负面影响。最近,基于Transformer的端到端检测器(DETRs)提供了一种替代方案:它们不再需要非极大值抑制(NMS)等后处理步骤,而这些步骤一直是传统实时检测器的瓶颈。然而,高昂的计算成本限制了DETRs的实用性,使其无法充分发挥免NMS的优势。


RT-DETR-R50/R101在COCO上实现了53.1%/54.3%的AP,在T4 GPU上达到108/74 FPS,在速度和精度方面都优于此前先进的YOLO检测器。此外,RT-DETR-R50的精度比DINO-R50高出2.2% AP,FPS约为其21倍。经过Objects365预训练后,RT-DETR-R50/R101的AP达到55.3%/56.2%。

1 编码器结构




3 CCFF中的融合块

下图为 CCFF中的融合块。


二 RT-DETR检测头的代码


  class RTDETRDecoder(nn.Module):
      """
      Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

      The module projects multi-level backbone features to a common hidden dimension, selects the top-scoring
      encoder positions as initial queries, and refines them through a stack of deformable Transformer decoder
      layers to predict bounding boxes and class scores.

      Attributes:
          export (bool): When True, `forward` returns only the concatenated inference tensor.
      """

      export = False  # export mode

      def __init__(
          self,
          nc=80,
          ch=(512, 1024, 2048),
          hd=256,  # hidden dim
          nq=300,  # num queries
          ndp=4,  # num decoder points
          nh=8,  # num head
          ndl=6,  # num decoder layers
          d_ffn=1024,  # dim of feedforward
          dropout=0.0,
          act=nn.ReLU(),
          eval_idx=-1,
          # Training args
          nd=100,  # num denoising
          label_noise_ratio=0.5,
          box_noise_scale=1.0,
          learnt_init_query=False,
      ):
          """
          Initialize the RTDETRDecoder module with the given parameters.

          Args:
              nc (int): Number of classes. Default is 80.
              ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
              hd (int): Dimension of hidden layers. Default is 256.
              nq (int): Number of query points. Default is 300.
              ndp (int): Number of decoder points. Default is 4.
              nh (int): Number of heads in multi-head attention. Default is 8.
              ndl (int): Number of decoder layers. Default is 6.
              d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
              dropout (float): Dropout rate. Default is 0.
              act (nn.Module): Activation function. Default is nn.ReLU.
              eval_idx (int): Evaluation index. Default is -1.
              nd (int): Number of denoising. Default is 100.
              label_noise_ratio (float): Label noise ratio. Default is 0.5.
              box_noise_scale (float): Box noise scale. Default is 1.0.
              learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
          """
          super().__init__()
          self.hidden_dim = hd
          self.nhead = nh
          self.nl = len(ch)  # num level
          self.nc = nc
          self.num_queries = nq
          self.num_decoder_layers = ndl

          # Backbone feature projection
          self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
          # NOTE: simplified version but it's not consistent with .pt weights.
          # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

          # Transformer module
          decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
          self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

          # Denoising part
          self.denoising_class_embed = nn.Embedding(nc, hd)
          self.num_denoising = nd
          self.label_noise_ratio = label_noise_ratio
          self.box_noise_scale = box_noise_scale

          # Decoder embedding
          self.learnt_init_query = learnt_init_query
          if learnt_init_query:
              self.tgt_embed = nn.Embedding(nq, hd)
          # Always created: `forward` passes it to the decoder regardless of `learnt_init_query`.
          self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

          # Encoder head
          self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
          self.enc_score_head = nn.Linear(hd, nc)
          self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

          # Decoder head (one score head and one bbox head per decoder layer)
          self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
          self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

          self._reset_parameters()

      def forward(self, x, batch=None):
          """
          Run the forward pass, returning bounding box and classification scores for the input.

          Args:
              x (list[torch.Tensor]): Multi-level feature maps, one per entry of `ch`, each shaped (bs, c, h, w).
              batch (dict | None): Ground-truth batch data forwarded to `get_cdn_group` for denoising training.

          Returns:
              In training mode, the tuple `(dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta)`.
              Otherwise `y` when `self.export` is True, else `(y, x)`, where `y` concatenates the decoder boxes
              with the sigmoid of the decoder scores along the last dimension and `x` is the tuple above.
          """
          from ultralytics.models.utils.ops import get_cdn_group

          # Input projection and embedding
          feats, shapes = self._get_encoder_input(x)

          # Prepare denoising training
          dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
              batch,
              self.nc,
              self.num_queries,
              self.denoising_class_embed.weight,
              self.num_denoising,
              self.label_noise_ratio,
              self.box_noise_scale,
              self.training,
          )

          embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

          # Decoder
          dec_bboxes, dec_scores = self.decoder(
              embed,
              refer_bbox,
              feats,
              shapes,
              self.dec_bbox_head,
              self.dec_score_head,
              self.query_pos_head,
              attn_mask=attn_mask,
          )
          x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
          if self.training:
              return x
          # (bs, 300, 4+nc)
          y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
          return y if self.export else (y, x)

      def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
          """
          Generate anchor boxes for the given feature-map shapes and flag the valid ones.

          Args:
              shapes (list): `[h, w]` of each feature level.
              grid_size (float): Base anchor width/height; doubled for every subsequent level.
              dtype (torch.dtype): Data type of the generated tensors.
              device (str | torch.device): Device of the generated tensors.
              eps (float): Margin; an anchor is valid only if all four values lie strictly inside (eps, 1 - eps).

          Returns:
              anchors (torch.Tensor): (1, sum(h*w), 4) anchors in logit space, with invalid anchors set to `inf`.
              valid_mask (torch.Tensor): (1, sum(h*w), 1) boolean mask of valid anchors.
          """
          anchors = []
          for i, (h, w) in enumerate(shapes):
              sy = torch.arange(end=h, dtype=dtype, device=device)
              sx = torch.arange(end=w, dtype=dtype, device=device)
              grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
              grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

              valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
              # Cell centres normalised to (0, 1)
              grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
              wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
              anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

          anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
          valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1
          # Inverse sigmoid, so that anchors can be added to raw bbox-head outputs before `sigmoid()`
          anchors = torch.log(anchors / (1 - anchors))
          anchors = anchors.masked_fill(~valid_mask, float("inf"))
          return anchors, valid_mask

      def _get_encoder_input(self, x):
          """
          Project each input feature map and flatten all levels into one token sequence.

          Args:
              x (list[torch.Tensor]): Feature maps shaped (b, c, h, w), one per projection in `self.input_proj`.

          Returns:
              feats (torch.Tensor): (b, sum(h*w), c) projected features concatenated over levels.
              shapes (list): `[h, w]` of each projected level, in input order.
          """
          # Get projection features
          x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
          # Get encoder inputs
          feats = []
          shapes = []
          for feat in x:
              h, w = feat.shape[2:]
              # [b, c, h, w] -> [b, h*w, c]
              feats.append(feat.flatten(2).permute(0, 2, 1))
              # [nl, 2]
              shapes.append([h, w])

          # [b, h*w, c]
          feats = torch.cat(feats, 1)
          return feats, shapes

      def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
          """
          Build the decoder inputs (query embeddings and reference boxes) from the encoder features.

          Args:
              feats (torch.Tensor): (bs, sum(h*w), hd) flattened features from `_get_encoder_input`.
              shapes (list): `[h, w]` of each feature level.
              dn_embed (torch.Tensor | None): Denoising query embeddings, prepended to the selected queries.
              dn_bbox (torch.Tensor | None): Denoising reference boxes, prepended to the selected boxes.

          Returns:
              embeddings (torch.Tensor): Query embeddings passed to the decoder.
              refer_bbox (torch.Tensor): Un-activated reference boxes passed to the decoder.
              enc_bboxes (torch.Tensor): (bs, num_queries, 4) sigmoid-activated boxes of the selected queries.
              enc_scores (torch.Tensor): (bs, num_queries, nc) class scores of the selected queries.
          """
          bs = feats.shape[0]
          # Prepare input for decoder
          anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
          features = self.enc_output(valid_mask * feats)  # bs, h*w, 256

          enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

          # Query selection
          # (bs * num_queries,)
          topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
          # (bs * num_queries,), created on the same device as `topk_ind` so indexing needs no cross-device copy
          batch_ind = (
              torch.arange(end=bs, dtype=topk_ind.dtype, device=topk_ind.device)
              .unsqueeze(-1)
              .repeat(1, self.num_queries)
              .view(-1)
          )

          # (bs, num_queries, 256)
          top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
          # (bs, num_queries, 4)
          top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

          # Dynamic anchors + static content
          refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

          enc_bboxes = refer_bbox.sigmoid()
          if dn_bbox is not None:
              refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
          enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

          embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
          if self.training:
              # Detach the decoder inputs from the graph during training
              refer_bbox = refer_bbox.detach()
              if not self.learnt_init_query:
                  embeddings = embeddings.detach()
          if dn_embed is not None:
              embeddings = torch.cat([dn_embed, embeddings], 1)

          return embeddings, refer_bbox, enc_bboxes, enc_scores

      # TODO
      def _reset_parameters(self):
          """Initialize or reset the parameters of the score heads, bbox heads, encoder output and projections."""
          # Class and bbox head init
          bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
          # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
          # linear_init(self.enc_score_head)
          constant_(self.enc_score_head.bias, bias_cls)
          constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
          constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
          for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
              # linear_init(cls_)
              constant_(cls_.bias, bias_cls)
              constant_(reg_.layers[-1].weight, 0.0)
              constant_(reg_.layers[-1].bias, 0.0)

          linear_init(self.enc_output[0])
          xavier_uniform_(self.enc_output[0].weight)
          if self.learnt_init_query:
              xavier_uniform_(self.tgt_embed.weight)
          xavier_uniform_(self.query_pos_head.layers[0].weight)
          xavier_uniform_(self.query_pos_head.layers[1].weight)
          for layer in self.input_proj:
              xavier_uniform_(layer[0].weight)

三 YOLOv8换个RT-DETR head


1 总体修改



2 配置文件


  1. # Ultralytics YOLO