当前位置:   article > 正文

YOLOv8添加DCNv3可变形卷积报错张量不匹配_dcnv3代码

dcnv3代码

 加入dcnv3代码后报错:

问题:

sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
RuntimeError: The size of tensor a (16) must match the size of tensor b (32) at non-singleton dimension 2

       报错张量不匹配问题可能出现在上述地方也可能是其他地方,但原因都是因为输入要求的数据是16位浮点数,但出现的是32位浮点型。导致当我们把数据放在gpu上训练时要求的数据是16位,而dcnv3内的数据是32位,会放在cpu上进行训练,结果会导致一个错误是数据

张量不匹配,另一个是训练时会出现两个device,一个是gpu,另一个是cpu。

 解决办法:

将ops--_dcnv3文件夹中functions文件夹中的dcnv3_func.py中函数

_get_reference_points 与  _generate_dilation_grids中的 dtype=torch.float32 替换为 dtype=torch.float16,让代码在进行dcnv3相关的数据计算时的数据类型为16位浮点型而不是32位浮点型,让它在gpu上计算而不是cpu上计算,这样就能让 ref 与 grid 参数在gpu上进行计算。
  1. def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
  2. _, H_, W_, _ = spatial_shapes
  3. H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
  4. W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
  5. ref_y, ref_x = torch.meshgrid(
  6. torch.linspace(
  7. # pad_h + 0.5,
  8. # H_ - pad_h - 0.5,
  9. (dilation_h * (kernel_h - 1)) // 2 + 0.5,
  10. (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
  11. H_out,
  12. dtype=torch.float16, ##torch.float32
  13. device=device),
  14. torch.linspace(
  15. # pad_w + 0.5,
  16. # W_ - pad_w - 0.5,
  17. (dilation_w * (kernel_w - 1)) // 2 + 0.5,
  18. (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
  19. W_out,
  20. dtype=torch.float32,
  21. device=device))
  22. ref_y = ref_y.reshape(-1)[None] / H_
  23. ref_x = ref_x.reshape(-1)[None] / W_
  24. ref = torch.stack((ref_x, ref_y), -1).reshape(
  25. 1, H_out, W_out, 1, 2)
  26. return ref
  27. def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
  28. _, H_, W_, _ = spatial_shapes
  29. points_list = []
  30. x, y = torch.meshgrid(
  31. torch.linspace(
  32. -((dilation_w * (kernel_w - 1)) // 2),
  33. -((dilation_w * (kernel_w - 1)) // 2) +
  34. (kernel_w - 1) * dilation_w, kernel_w,
  35. dtype=torch.float16, #torch.float32
  36. device=device),
  37. torch.linspace(
  38. -((dilation_h * (kernel_h - 1)) // 2),
  39. -((dilation_h * (kernel_h - 1)) // 2) +
  40. (kernel_h - 1) * dilation_h, kernel_h,
  41. dtype=torch.float16, #torch.float32
  42. device=device))
  43. points_list.extend([x / W_, y / H_])
  44. grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
  45. repeat(1, group, 1).permute(1, 0, 2)
  46. grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
  47. return grid

上述方法不保证一定修改正确,,可能修改了别的什么我也不知道,但是能解决该问题 。

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号