赞
踩
问题:
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
RuntimeError: The size of tensor a (16) must match the size of tensor b (32) at non-singleton dimension 2
报错张量不匹配问题可能出现在上述地方也可能是其他地方,但原因都是因为输入要求的数据是16位浮点数,但出现的是32位浮点型。导致当我们把数据放在gpu上训练时要求的数据是16位,而dcnv3内的数据是32位,会放在cpu上进行训练,结果会导致一个错误是数据
张量不匹配,另一个是训练时会出现两个device,一个是gpu,另一个是cpu。
解决办法:
将ops--_dcnv3文件夹中functions文件夹中的dcnv3_func.py中函数
_get_reference_points 与 _generate_dilation_grids中的 dtype=torch.float32 替换为 dtype=torch.float16,让代码在进行dcnv3相关的数据计算时的数据类型为16位浮点型而不是32位浮点型,让它在gpu上计算而不是cpu上计算,这样就能让 ref 与 grid 参数在gpu上进行计算。
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): _, H_, W_, _ = spatial_shapes H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 ref_y, ref_x = torch.meshgrid( torch.linspace( # pad_h + 0.5, # H_ - pad_h - 0.5, (dilation_h * (kernel_h - 1)) // 2 + 0.5, (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, H_out, dtype=torch.float16, ##torch.float32 device=device), torch.linspace( # pad_w + 0.5, # W_ - pad_w - 0.5, (dilation_w * (kernel_w - 1)) // 2 + 0.5, (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, W_out, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / H_ ref_x = ref_x.reshape(-1)[None] / W_ ref = torch.stack((ref_x, ref_y), -1).reshape( 1, H_out, W_out, 1, 2) return ref def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): _, H_, W_, _ = spatial_shapes points_list = [] x, y = torch.meshgrid( torch.linspace( -((dilation_w * (kernel_w - 1)) // 2), -((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w, kernel_w, dtype=torch.float16, #torch.float32 device=device), torch.linspace( -((dilation_h * (kernel_h - 1)) // 2), -((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h, kernel_h, dtype=torch.float16, #torch.float32 device=device)) points_list.extend([x / W_, y / H_]) grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ repeat(1, group, 1).permute(1, 0, 2) grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) return grid
上述方法不保证一定修改正确,,可能修改了别的什么我也不知道,但是能解决该问题 。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。