当前位置:   article > 正文

【CV学习笔记】之tensorrt篇之Vision transformer_vit_base_patch16_224_in21k.pth

vit_base_patch16_224_in21k.pth

1、摘要

本次学习内容主要学习了vision transformer的网络结构,并在cpp和python中实现了后处理代码(其实没啥后处理的,取最大值即可),同时加强了对transformer原理的理解,主要是为了学习detr等模型做铺垫。

原vit学习地址:Vision Transformer详解

个人学习代码仓库地址: https://github.com/Rex-LK/tensorrt_learning
欢迎正在学习或者想学的CV的同学进群一起讨论与学习,v:Rex1586662742,q群:468713665

2、vit简介

Transformer 原本是在nlp领域提出的模型,后来发现在图像处理领域也有很好的效果,并且transformer日益成为图像处理中的主干网络,甚至在不久的将来,有替换掉cnn的可能,于是,学习transformer是十分必要的。

transformer的主要结构都是围绕了qkv三个向量来运作的,其中

Q为query,可认为是查询向量,当前向量与其他向量计算注意力

K为key,可认为是被查向量,其他向量与当前向量计算注意力

V为注意力计算的结果

其主要过程可利用一个公式表示

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T / d k ) V Attention(Q,K,V) = softmax(QK^T/\sqrt{d_k})V Attention(Q,K,V)=softmax(QKT/dk )V

其中 d k d_k dk为向量长度,用于减小不同输入向量之间的差距。

vit是基于transformer的基础上,将图像均分为多个patch,然后再将每个patch作为一个token输入到模型中,计算每个patch之间的注意力,引用原论文中的一张图片。

在这里插入图片描述

上图很好的反应了vit的结构,其中每个patch的位置信息是独一无二的,如果不加上Position Embedding的话,打乱patch的顺序是不改变注意力的结果的,这显然不适用于图像处理,transformer中使用了固定的位置编码,而在vit中使用了可以训练的位置编码。

在经过一些列的Transformer Encode块之后堆叠之后,输入维度为[197,768],输出维度也为[197,768],然后采用MLP直接进行分类,下图是原文作者绘制的网络结构图,简直是太清晰了,不愧是B站导师,由此,vit模型的网络结构图已经学习的差不多了,更加详细的代码和讲解可以参考讲解视频,现在就是要将这个模型在tensorrt中跑起来。

在这里插入图片描述

3、tensorrt加速

3.1、pytorch2onnx

原模型可以在https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth进行下载,也可以到百度网盘上进行下载,下载完成后,进入到demo/vit/文件夹中,然后原项目中的数据集进行训练,可以得到自己训练的分类模型,然后导出onnx并使用onnx_simplify

python export_onnx.py
python onnx_simplify.py
  • 1
  • 2

利用netron查看生成的vit.onnx模型,这里主要查看最下面output是否有问题。导出onnx后,可以利用predict.py和infer-onnxruntime.py进行推理来验证导出onnx的正确性

在torch下模型的输出为

tensor([0.8261, 0.0730, 0.0128, 0.0589, 0.0292])
  • 1

onnxruntime下的输出为

[0.8260881  0.0730136  0.01283037 0.05891288 0.02915508]
  • 1
3.2、python -tensorrt加速

由于分类模型没有后处理,直接获得模型输出中的最大值即可。

3.3、cpp-tensorrt加速

其构造engine的模型与其他模型一致,后处理只需要一行即可搞定

int predict_label = std::max_element(prob, prob + num_classes) - prob;
  • 1

4、总结

本次学习笔记学习了transformer的基本结构,并引申到了图像领域,对vision transformer的基本网络结构有了大致的了解,本次学习内容是为了为之后的detr的学习做铺垫。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/287846
推荐阅读
相关标签
  

闽ICP备14008679号