当前位置:   article > 正文

ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例_yolov8 java

yolov8 java

ONNX(Open Neural Network Exchange)是一个开源项目,旨在建立一个开放的标准,使深度学习模型可以在不同的软件平台和工具之间轻松移动和重用

ONNX模型可以用于各种应用场景,例如机器翻译、图像识别、语音识别、自然语言处理等。

由于ONNX模型的互操作性,开发人员可以使用不同的框架来训练,模型可以更容易地在不同的框架之间转换,例如从PyTorch转换到TensorFlow,或从TensorFlow转换到MXNet等。然后将其部署到不同的环境中,例如云端、边缘设备或移动设备等

ONNX还提供了一组工具和库,帮助开发人员更容易地创建、训练和部署深度学习模型。

ONNX模型是由多个节点(node)组成的图(graph),每个节点代表一个操作或一个张量(tensor)。ONNX模型还包含了一些元数据,例如模型的版本、输入和输出张量的名称等。

onnx官网

ONNX | Home

pytorch官方使用onnx模型格式举例

(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.2.0+cu121 documentation

TensorFlow官方使用onnx模型格式举例

https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb

Netron可视化模型结构工具

Netron

你可通过该工具看到onnx具体的模型结构,点击每层都能看到其对应的内容信息

onnxRuntime  | 提供各种编程语言推导onnx格式模型的接口

ONNX Runtime | Home

比如我需要在java环境下调用一个onnx模型,我可以先导入onnxRuntime的依赖,对数据预处理后,调用onnx格式模型正向传播导出数据,然后将数据处理成我要的数据。 

onnxRuntime也提供了其他编程语言的接口,如C++、C#、JavaScript、python等等。

实际案例举例

python部分

python下利用ultralytics从网上下载并导出yolov8的onnx格式模型,用java调用onnxruntim接口,正向传播推导模型数据。

pip install ultralytics
  1. from ultralytics import YOLO
  2. # 加载模型
  3. model = YOLO('yolov8n.pt') # 加载官方模型
  4. #加载自定义训练的模型
  5. #model = YOLO('F:\\File\\AI\\Object\\yolov8_test\\runs\\detect\\train\\weights\\best.pt')
  6. # 导出模型
  7. model.export(format='onnx')

java部分

前提安装java的opencv(Get Started - OpenCV),我这安装的是opencv480

maven依赖

  1. <dependencies>
  2. <dependency>
  3. <groupId>com.microsoft.onnxruntime</groupId>
  4. <artifactId>onnxruntime</artifactId>
  5. <version>1.12.0</version>
  6. </dependency>
  7. <!-- 加载lib目录下的opencv包 -->
  8. <dependency>
  9. <groupId>org.opencv</groupId>
  10. <artifactId>opencv</artifactId>
  11. <version>4.8.0</version>
  12. <scope>system</scope>
  13. <!--通过路径加载OpenCV480的jar包-->
  14. <systemPath>${basedir}/lib/opencv-480.jar</systemPath>
  15. </dependency>
  16. <dependency>
  17. <groupId>com.alibaba</groupId>
  18. <artifactId>fastjson</artifactId>
  19. <version>2.0.32</version>
  20. </dependency>
  21. </dependencies>

java完整代码

  1. package com.sky;
  2. //天宇 2023/12/21 20:23:13
  3. import ai.onnxruntime.*;
  4. import com.alibaba.fastjson.JSONObject;
  5. import org.opencv.core.*;
  6. import org.opencv.core.Point;
  7. import org.opencv.highgui.HighGui;
  8. import org.opencv.imgcodecs.Imgcodecs;
  9. import org.opencv.imgproc.Imgproc;
  10. import java.nio.FloatBuffer;
  11. import java.text.DecimalFormat;
  12. import java.util.*;
  13. import java.util.List;
  14. /**
  15. * onnx学习笔记 GTianyu
  16. */
  17. public class onnxLoadTest01 {
  18. public static OrtEnvironment env;
  19. public static OrtSession session;
  20. public static JSONObject names;
  21. public static long count;
  22. public static long channels;
  23. public static long netHeight;
  24. public static long netWidth;
  25. public static float srcw;
  26. public static float srch;
  27. public static float confThreshold = 0.25f;
  28. public static float nmsThreshold = 0.5f;
  29. static Mat src;
  30. public static void load(String path) {
  31. String weight = path;
  32. try{
  33. env = OrtEnvironment.getEnvironment();
  34. session = env.createSession(weight, new OrtSession.SessionOptions());
  35. OnnxModelMetadata metadata = session.getMetadata();
  36. Map<String, NodeInfo> infoMap = session.getInputInfo();
  37. TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
  38. String nameClass = metadata.getCustomMetadata().get("names");
  39. System.out.println("getProducerName="+metadata.getProducerName());
  40. System.out.println("getGraphName="+metadata.getGraphName());
  41. System.out.println("getDescription="+metadata.getDescription());
  42. System.out.println("getDomain="+metadata.getDomain());
  43. System.out.println("getVersion="+metadata.getVersion());
  44. System.out.println("getCustomMetadata="+metadata.getCustomMetadata());
  45. System.out.println("getInputInfo="+infoMap);
  46. System.out.println("nodeInfo="+nodeInfo);
  47. System.out.println(nameClass);
  48. names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
  49. count = nodeInfo.getShape()[0];//1 模型每次处理一张图片
  50. channels = nodeInfo.getShape()[1];//3 模型通道数
  51. netHeight = nodeInfo.getShape()[2];//640 模型高
  52. netWidth = nodeInfo.getShape()[3];//640 模型宽
  53. System.out.println(names.get(0));
  54. // 加载opencc需要的动态库
  55. System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
  56. }
  57. catch (Exception e){
  58. e.printStackTrace();
  59. System.exit(0);
  60. }
  61. }
  62. public static Map<Object, Object> predict(String imgPath) throws Exception {
  63. src=Imgcodecs.imread(imgPath);
  64. return predictor();
  65. }
  66. public static Map<Object, Object> predict(Mat mat) throws Exception {
  67. src=mat;
  68. return predictor();
  69. }
  70. public static OnnxTensor transferTensor(Mat dst){
  71. Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
  72. dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
  73. float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
  74. dst.get(0, 0, whc);
  75. float[] chw = whc2cwh(whc);
  76. OnnxTensor tensor = null;
  77. try {
  78. tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight});
  79. }
  80. catch (Exception e){
  81. e.printStackTrace();
  82. System.exit(0);
  83. }
  84. return tensor;
  85. }
  86. //宽 高 类型 to 类 宽 高
  87. public static float[] whc2cwh(float[] src) {
  88. float[] chw = new float[src.length];
  89. int j = 0;
  90. for (int ch = 0; ch < 3; ++ch) {
  91. for (int i = ch; i < src.length; i += 3) {
  92. chw[j] = src[i];
  93. j++;
  94. }
  95. }
  96. return chw;
  97. }
  98. public static Map<Object, Object> predictor() throws Exception{
  99. srcw = src.width();
  100. srch = src.height();
  101. System.out.println("width:"+srcw+" hight:"+srch);
  102. System.out.println("resize: \n width:"+netWidth+" hight:"+netHeight);
  103. float scaleW=srcw/netWidth;
  104. float scaleH=srch/netHeight;
  105. // resize
  106. Mat dst=new Mat();
  107. Imgproc.resize(src, dst, new Size(netWidth, netHeight));
  108. // 转换成Tensor数据格式
  109. OnnxTensor tensor = transferTensor(dst);
  110. OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
  111. System.out.println("res Data: "+result.get(0));
  112. OnnxTensor res = (OnnxTensor)result.get(0);
  113. float[][][] dataRes = (float[][][])res.getValue();
  114. float[][] data = dataRes[0];
  115. // 将矩阵转置
  116. // 先将xywh部分转置
  117. float rawData[][]=new float[data[0].length][6];
  118. System.out.println(data.length-1);
  119. for(int i=0;i<4;i++){
  120. for(int j=0;j<data[0].length;j++){
  121. rawData[j][i]=data[i][j];
  122. }
  123. }
  124. // 保存每个检查框置信值最高的类型置信值和该类型下标
  125. for(int i=0;i<data[0].length;i++){
  126. for(int j=4;j<data.length;j++){
  127. if(rawData[i][4]<data[j][i]){
  128. rawData[i][4]=data[j][i]; //置信值
  129. rawData[i][5]=j-4; //类型编号
  130. }
  131. }
  132. }
  133. List<ArrayList<Float>> boxes=new LinkedList<ArrayList<Float>>();
  134. ArrayList<Float> box=null;
  135. // 置信值过滤,xywh转xyxy
  136. for(float[] d:rawData){
  137. // 置信值过滤
  138. if(d[4]>confThreshold){
  139. // xywh(xy为中心点)转xyxy
  140. d[0]=d[0]-d[2]/2;
  141. d[1]=d[1]-d[3]/2;
  142. d[2]=d[0]+d[2];
  143. d[3]=d[1]+d[3];
  144. // 置信值符合的进行插入法排序保存
  145. box=new ArrayList<Float>();
  146. for(float num:d) {
  147. box.add(num);
  148. }
  149. if(boxes.size()==0){
  150. boxes.add(box);
  151. }else {
  152. int i;
  153. for(i=0;i<boxes.size();i++){
  154. if(box.get(4)>boxes.get(i).get(4)){
  155. boxes.add(i,box);
  156. break;
  157. }
  158. }
  159. // 插入到最后
  160. if(i==boxes.size()){
  161. boxes.add(box);
  162. }
  163. }
  164. }
  165. }
  166. // 每个框分别有x1、x1、x2、y2、conf、class
  167. //System.out.println(boxes);
  168. // 非极大值抑制
  169. int[] indexs=new int[boxes.size()];
  170. Arrays.fill(indexs,1); //用于标记1保留,0删除
  171. for(int cur=0;cur<boxes.size();cur++){
  172. if(indexs[cur]==0){
  173. continue;
  174. }
  175. ArrayList<Float> curMaxConf=boxes.get(cur); //当前框代表该类置信值最大的框
  176. for(int i=cur+1;i<boxes.size();i++){
  177. if(indexs[i]==0){
  178. continue;
  179. }
  180. float classIndex=boxes.get(i).get(5);
  181. // 两个检测框都检测到同一类数据,通过iou来判断是否检测到同一目标,这就是非极大值抑制
  182. if(classIndex==curMaxConf.get(5)){
  183. float x1=curMaxConf.get(0);
  184. float y1=curMaxConf.get(1);
  185. float x2=curMaxConf.get(2);
  186. float y2=curMaxConf.get(3);
  187. float x3=boxes.get(i).get(0);
  188. float y3=boxes.get(i).get(1);
  189. float x4=boxes.get(i).get(2);
  190. float y4=boxes.get(i).get(3);
  191. //将几种不相交的情况排除。提示:x1y1、x2y2、x3y3、x4y4对应两框的左上角和右下角
  192. if(x1>x4||x2<x3||y1>y4||y2<y3){
  193. continue;
  194. }
  195. // 两个矩形的交集面积
  196. float intersectionWidth =Math.max(x1, x3) - Math.min(x2, x4);
  197. float intersectionHeight=Math.max(y1, y3) - Math.min(y2, y4);
  198. float intersectionArea =Math.max(0,intersectionWidth * intersectionHeight);
  199. // 两个矩形的并集面积
  200. float unionArea = (x2-x1)*(y2-y1)+(x4-x3)*(y4-y3)-intersectionArea;
  201. // 计算IoU
  202. float iou = intersectionArea / unionArea;
  203. // 对交并比超过阈值的标记
  204. indexs[i]=iou>nmsThreshold?0:1;
  205. //System.out.println(cur+" "+i+" class"+curMaxConf.get(5)+" "+classIndex+" u:"+unionArea+" i:"+intersectionArea+" iou:"+ iou);
  206. }
  207. }
  208. }
  209. List<ArrayList<Float>> resBoxes=new LinkedList<ArrayList<Float>>();
  210. for(int index=0;index<indexs.length;index++){
  211. if(indexs[index]==1) {
  212. resBoxes.add(boxes.get(index));
  213. }
  214. }
  215. boxes=resBoxes;
  216. System.out.println("boxes.size : "+boxes.size());
  217. for(ArrayList<Float> box1:boxes){
  218. box1.set(0,box1.get(0)*scaleW);
  219. box1.set(1,box1.get(1)*scaleH);
  220. box1.set(2,box1.get(2)*scaleW);
  221. box1.set(3,box1.get(3)*scaleH);
  222. }
  223. System.out.println("boxes: "+boxes);
  224. //detect(boxes);
  225. Map<Object,Object> map=new HashMap<Object,Object>();
  226. map.put("boxes",boxes);
  227. map.put("classNames",names);
  228. return map;
  229. }
  230. public static Mat showDetect(Map<Object,Object> map){
  231. List<ArrayList<Float>> boxes=(List<ArrayList<Float>>)map.get("boxes");
  232. JSONObject names=(JSONObject) map.get("classNames");
  233. Imgproc.resize(src,src,new Size(srcw,srch));
  234. // 画框,加数据
  235. for(ArrayList<Float> box:boxes){
  236. float x1=box.get(0);
  237. float y1=box.get(1);
  238. float x2=box.get(2);
  239. float y2=box.get(3);
  240. float config=box.get(4);
  241. String className=(String)names.get((int)box.get(5).intValue());;
  242. Point point1=new Point(x1,y1);
  243. Point point2=new Point(x2,y2);
  244. Imgproc.rectangle(src,point1,point2,new Scalar(0,0,255),2);
  245. String conf=new DecimalFormat("#.###").format(config);
  246. Imgproc.putText(src,className+" "+conf,new Point(x1,y1-5),0,0.5,new Scalar(255,0,0),1);
  247. }
  248. HighGui.imshow("image",src);
  249. HighGui.waitKey();
  250. return src;
  251. }
  252. public static void main(String[] args) throws Exception {
  253. String modelPath="C:\\Users\\tianyu\\IdeaProjects\\test1\\src\\main\\java\\com\\sky\\best.onnx";
  254. String path="C:\\Users\\tianyu\\IdeaProjects\\test1\\src\\main\\resources\\img\\img.png";
  255. onnxLoadTest01.load(modelPath);
  256. Map<Object,Object> map=onnxLoadTest01.predict(path);
  257. showDetect(map);
  258. }
  259. }

效果:


参考文献:

使用 java-onnx 部署 yolovx 目标检测_java onnx-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/668996
推荐阅读
相关标签
  

闽ICP备14008679号