当前位置:   article > 正文

TensorRT加速Deformable Detr实践_deformable detr onnx

deformable detr onnx

TensorRT加速Deformable Detr实践

自TensorRT 8.4.1.5发布以来,惊喜的发现TensorRT官方实现了可变形transformer的插件。
在这里插入图片描述
这让TensorRT便捷实现加速Deformable Detr乃至今年(2022年)最新的DETR类sota模型DINO、Mask DINO成为了可能。查了一下当前网络上并没有关于Deformable Detr 的TensorRT加速的实现方法,可能大佬们都觉的太简单没有必要吧,于是就自己写了一版方便大家使用。源码地址放在了github上: https://github.com/talebolano/Tensorrt-Deformable-Detr

我使用的Deformable-Detr pytorch模型来自于mmdetection库,没有使用官方的原版。自己代码主要贡献了MultiScaleDeformableAttention层的onnx导出,通过实现一个伪MultiScaleDeformableAttention层进行symbolic的注册:

class Etmpy_MultiScaleDeformableAttnFunction(torch.autograd.Function):
    @staticmethod
    def symbolic(g,value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):

        return g.op('com.microsoft::MultiscaleDeformableAttnPlugin_TRT',value, value_spatial_shapes, value_level_start_index,
                    sampling_locations, attention_weights)
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        '''
        no real mean,just for inference
        '''
        bs, _, mum_heads, embed_dims_num_heads = value.shape
        bs ,num_queries, _, _, _, _ = sampling_locations.shape
        return value.new_zeros(bs, num_queries, mum_heads, embed_dims_num_heads)

    @staticmethod
    def backward(ctx, grad_output):
        pass   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

注册后的MultiScaleDeformableAttention层可实现onnx导出,如下图所示:
在这里插入图片描述
之后的转TensorRT就直接利用官方插件即可,没有任何困难。对于低于8.4.1.5的TensorRT版本,也可以选择把官方的插件自己编译到旧版本上。TensorRT加速后的Deformable-Detr模型的速度和效果如下图和下表所示:

GPUModelModeInference time
3090deformable_detr_twostage_refine_r50_16x2_50e_cocofp3235ms
3090deformable_detr_twostage_refine_r50_16x2_50e_cocofp1617ms

在这里插入图片描述
如果感兴趣就帮我加一颗星吧。

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/617806
推荐阅读
相关标签
  

闽ICP备14008679号