Following this person's video (Bilibili: 飞飞), I made the replacement in YOLOv8 and got it running pretty quickly.
The catch is that on my problem, Snake Conv actually hurt accuracy rather than helping. I noticed it tends to confuse objects whose features look similar but that should not count as the same class. My problem is fairly tricky: the data are imbalanced and scarce (deliberately so, since the whole point is to handle imbalance and small data), so I need a method that does not care much about data volume. Emmm... no matter, maybe it will be useful for your problem.
On top of that, the weights need to be converted to TensorRT. A few spots in 飞飞's code have to be changed before the TensorRT export goes through:
The zero, max_y and max_x below must all be torch tensors living on the CUDA device, so add .to(device) wherever needed.
    def _bilinear_interpolate_3D(self, input_feature, y, x):
        device = input_feature.device
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()

        zero = torch.zeros([]).int().to(device)
        max_y = torch.tensor(self.width - 1)
        max_x = torch.tensor(self.height - 1)

        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y.to(device))
        y1 = torch.clamp(y1, zero, max_y.to(device))
        x0 = torch.clamp(x0, zero, max_x.to(device))
        x1 = torch.clamp(x1, zero, max_x.to(device))

        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)

        dimension = self.height * self.width
        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()

        repeat = torch.ones([self.num_points * self.width * self.height]).unsqueeze(0)
        repeat = repeat.float()

        base = torch.matmul(base, repeat)
        base = base.reshape([-1])
        base = base.to(device)

        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height

        # top rectangle of the neighbourhood volume
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1

        # bottom rectangle of the neighbourhood volume
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1

        # get 8 grid values
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)

        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y.to(device) + 1)
        y1 = torch.clamp(y1, zero, max_y.to(device) + 1)
        x0 = torch.clamp(x0, zero, max_x.to(device) + 1)
        x1 = torch.clamp(x1, zero, max_x.to(device) + 1)

        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()

        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 +
                   value_a1 * vol_a1 + value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature
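For the export step itself, here is a minimal sketch of the standard Ultralytics export path I mean; the weight path, image size and FP16 flag are placeholders, not the exact values used in the video:

    from ultralytics import YOLO

    # Load the trained model with the DySnakeConv modules.
    # "runs/detect/train/weights/best.pt" is a placeholder path.
    model = YOLO("runs/detect/train/weights/best.pt")

    # Export to a TensorRT engine. This needs a CUDA GPU plus the
    # tensorrt Python package; half=True produces an FP16 engine.
    model.export(format="engine", imgsz=640, half=True, device=0)

With the .to(device) fixes above in place, the tracing during export no longer trips over CPU/GPU tensor mismatches.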
Anyway, the resulting structure does look pretty wild.
I then tried swapping out every layer that could be swapped, including some random combinations, but none of them ended up matching the original C2f. One of the configurations I tried is the one below (with a training sketch after it).
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]    # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]   # 1-P2/4
  - [-1, 3, C2f_DySnakeConv, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]   # 3-P3/8
  - [-1, 6, C2f_DySnakeConv, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]   # 5-P4/16
  - [-1, 6, C2f_DySnakeConv, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f_DySnakeConv, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]     # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]    # cat backbone P4
  - [-1, 3, C2f_DySnakeConv, [512]]   # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]    # cat backbone P3
  - [-1, 3, C2f_DySnakeConv, [256]]   # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]   # cat head P4
  - [-1, 3, C2f_DySnakeConv, [512]]   # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]    # cat head P5
  - [-1, 3, C2f_DySnakeConv, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]   # Detect(P3, P4, P5)
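For reference, this is roughly how such a YAML gets plugged into training, as a sketch only: it assumes C2f_DySnakeConv is already registered inside the ultralytics package (as set up in 飞飞's video), and the file names and hyperparameters are placeholders.

    from ultralytics import YOLO

    # Build the model from the modified YAML; optionally load the official
    # yolov8n.pt so the layers that still match keep pretrained weights.
    model = YOLO("yolov8n-C2f_DySnakeConv.yaml").load("yolov8n.pt")

    # Train on a custom dataset; data.yaml, epochs and imgsz are placeholders.
    model.train(data="data.yaml", epochs=100, imgsz=640)

Swapping which C2f blocks become C2f_DySnakeConv is then just a matter of editing the YAML and rebuilding the model.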
Next I'll go back through the Snake Conv paper and try to guess, or at least understand, why it does not work that well on my problem (beyond the data-volume issue).