当前位置:   article > 正文

android 使用 onnxruntime 部署 yolov5_face_landmark 人脸检测_onnxruntime android

onnxruntime android

下面是使用 opencv-camera,实时处理区域内人脸检测 android 推理 demo。

首先是整合 opencv-camera 进去:

为了方便直接将整个 opencv-android-sdk 全部导入:

 然后在原来的项目模块app中添加 opencv 的 java 相关依赖,主要是添加其中与 opencv 相关的两行(原文以红色标注的 fileTree 与 opencvsdk 两行):
app/build.gradle

  1. dependencies {
  2. implementation fileTree(dir: 'libs', include: ['*.jar'])
  3. implementation 'androidx.appcompat:appcompat:1.4.1'
  4. implementation 'com.google.android.material:material:1.5.0'
  5. implementation 'androidx.constraintlayout:constraintlayout:2.1.3'
  6. testImplementation 'junit:junit:4.13.2'
  7. androidTestImplementation 'androidx.test.ext:junit:1.1.3'
  8. androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
  9. implementation project(':opencvsdk')
  10. }

最后在项目中要使用opencv的地方加载jni库,可以添加到 MainActivity 中:

System.loadLibrary("opencv_java4"); 或者 OpenCVLoader.initDebug();

要使用 opencv-camera,MainActivity 继承 CameraActivity,然后在回调函数中获取每一帧进行处理,比如下面对每一帧添加识别区域边框:

  1. // 获取每一帧回调数据
  2. private CameraBridgeViewBase.CvCameraViewListener2 cameraViewListener2 = new CameraBridgeViewBase.CvCameraViewListener2() {
  3. @Override
  4. public void onCameraViewStarted(int width, int height) {
  5. System.out.println("开始预览 width="+width+",height="+height);
  6. // 预览界面是 640*480,模型输入时 320*320,计算识别区域坐标
  7. int detection_x1 = (640 - OnnxUtil.w)/2;
  8. int detection_x2 = (640 - OnnxUtil.w)/2 + OnnxUtil.w;
  9. int detection_y1 = (480 - OnnxUtil.h)/2;
  10. int detection_y2 = (480 - OnnxUtil.h)/2 + OnnxUtil.h;;
  11. System.out.println("识别区域:"+"("+detection_x1+","+detection_y1+")"+"("+detection_x2+","+detection_y2+")");
  12. // 缓存识别区域两个点
  13. detection_p1 = new Point(detection_x1,detection_y1);
  14. detection_p2 = new Point(detection_x2,detection_y2);
  15. detection_box_color = new Scalar(255, 0, 0);
  16. detection_box_tickness = 2;
  17. }
  18. @Override
  19. public void onCameraViewStopped() {}
  20. @Override
  21. public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame frame) {
  22. // 获取 cv::Mat
  23. Mat mat = frame.rgba();
  24. // 标注识别区域
  25. Imgproc.rectangle(mat, detection_p1, detection_p2,detection_box_color,detection_box_tickness);
  26. return mat;
  27. }
  28. };

在界面中开启预览:

  1. ui资源:
  2. <org.opencv.android.JavaCamera2View
  3. android:id="@+id/camera_view"
  4. app:layout_constraintTop_toTopOf="parent"
  5. app:layout_constraintLeft_toLeftOf="parent"
  6. android:layout_width="match_parent"
  7. android:layout_height="match_parent">
  8. </org.opencv.android.JavaCamera2View>
  9. java开启预览:
  10. private BaseLoaderCallback baseLoaderCallback = new BaseLoaderCallback(this) {
  11. @Override
  12. public void onManagerConnected(int status) {
  13. switch (status) {
  14. case LoaderCallbackInterface.SUCCESS: {
  15. if (camera2View != null) {
  16. // 设置前置还是后置摄像头 0后置 1前置
  17. camera2View.setCameraIndex(cameraId);
  18. // 注册每一帧回调
  19. camera2View.setCvCameraViewListener(cameraViewListener2);
  20. // 显示/关闭 帧率 disableFpsMeter/enableFpsMeter
  21. // 要修改字体和颜色直接修改 FpsMeter 类即可
  22. camera2View.enableFpsMeter();
  23. // 设置视图宽高和模型一致减少resize操作,模型输入一般尺寸不大,这样相机渲染fps会更高
  24. camera2View.setMaxFrameSize(win_w,win_h);
  25. // 开启
  26. camera2View.enableView();
  27. }
  28. }
  29. break;
  30. default:
  31. super.onManagerConnected(status);
  32. break;
  33. }
  34. }
  35. };

下面是全部推理 MainActivity 代码:

  1. package com.example.camera_opencv;
  2. import android.content.pm.ActivityInfo;
  3. import android.os.Bundle;
  4. import android.view.WindowManager;
  5. import com.example.camera_opencv.databinding.ActivityMainBinding;
  6. import org.opencv.android.*;
  7. import org.opencv.core.Mat;
  8. import org.opencv.core.Point;
  9. import org.opencv.core.Scalar;
  10. import org.opencv.imgproc.Imgproc;
  11. import java.util.Arrays;
  12. import java.util.List;
  13. public class MainActivity extends CameraActivity{
  14. // 动态库
  15. static {
  16. // 我们自己的jni
  17. System.loadLibrary("camera_opencv");
  18. // 新加的 opencv 的jni
  19. System.loadLibrary("opencv_java4");
  20. }
  21. private ActivityMainBinding binding;
  22. // 预览界面
  23. private JavaCamera2View camera2View;
  24. // 相机编号 0后置 1前置
  25. private int cameraId = 1;
  26. // 设置预览界面宽高,在次宽高基础上限制识别区域
  27. private int win_w = 640;
  28. private int win_h = 480;
  29. // 识别区域两个点
  30. private Point detection_p1;
  31. private Point detection_p2;
  32. private Scalar detection_box_color;
  33. private int detection_box_tickness;
  34. @Override
  35. protected void onCreate(Bundle savedInstanceState) {
  36. super.onCreate(savedInstanceState);
  37. binding = ActivityMainBinding.inflate(getLayoutInflater());
  38. setContentView(binding.getRoot());
  39. // 加载模型
  40. OnnxUtil.loadModule(getAssets());
  41. // 强制横屏
  42. setRequestedOrientation(ActivityInfo.SCREEN_ORIENTATION_LANDSCAPE);
  43. // 隐藏上方状态栏
  44. getWindow().setFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN, WindowManager.LayoutParams.FLAG_FULLSCREEN);
  45. // 预览界面
  46. camera2View = findViewById(R.id.camera_view);
  47. }
  48. @Override
  49. protected List<? extends CameraBridgeViewBase> getCameraViewList() {
  50. return Arrays.asList(camera2View);
  51. }
  52. @Override
  53. public void onPause() {
  54. super.onPause();
  55. if (camera2View != null) {
  56. // 关闭预览
  57. camera2View.disableView();
  58. }
  59. }
  60. @Override
  61. public void onResume() {
  62. super.onResume();
  63. if (OpenCVLoader.initDebug()) {
  64. baseLoaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
  65. } else {
  66. OpenCVLoader.initAsync(OpenCVLoader.OPENCV_VERSION, this, baseLoaderCallback);
  67. }
  68. }
  69. // 获取每一帧回调数据
  70. private CameraBridgeViewBase.CvCameraViewListener2 cameraViewListener2 = new CameraBridgeViewBase.CvCameraViewListener2() {
  71. @Override
  72. public void onCameraViewStarted(int width, int height) {
  73. System.out.println("开始预览 width="+width+",height="+height);
  74. // 预览界面是 640*480,模型输入时 320*320,计算识别区域坐标
  75. int detection_x1 = (640 - OnnxUtil.w)/2;
  76. int detection_x2 = (640 - OnnxUtil.w)/2 + OnnxUtil.w;
  77. int detection_y1 = (480 - OnnxUtil.h)/2;
  78. int detection_y2 = (480 - OnnxUtil.h)/2 + OnnxUtil.h;;
  79. System.out.println("识别区域:"+"("+detection_x1+","+detection_y1+")"+"("+detection_x2+","+detection_y2+")");
  80. // 缓存识别区域两个点
  81. detection_p1 = new Point(detection_x1,detection_y1);
  82. detection_p2 = new Point(detection_x2,detection_y2);
  83. detection_box_color = new Scalar(255, 0, 0);
  84. detection_box_tickness = 2;
  85. }
  86. @Override
  87. public void onCameraViewStopped() {}
  88. @Override
  89. public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame frame) {
  90. // 获取 cv::Mat
  91. Mat mat = frame.rgba();
  92. // 标注识别区域
  93. Imgproc.rectangle(mat, detection_p1, detection_p2,detection_box_color,detection_box_tickness);
  94. // 推理并标注
  95. OnnxUtil.inference(mat,detection_p1,detection_p2);
  96. return mat;
  97. }
  98. };
  99. // 开启预览
  100. private BaseLoaderCallback baseLoaderCallback = new BaseLoaderCallback(this) {
  101. @Override
  102. public void onManagerConnected(int status) {
  103. switch (status) {
  104. case LoaderCallbackInterface.SUCCESS: {
  105. if (camera2View != null) {
  106. // 设置前置还是后置摄像头 0后置 1前置
  107. camera2View.setCameraIndex(cameraId);
  108. // 注册每一帧回调
  109. camera2View.setCvCameraViewListener(cameraViewListener2);
  110. // 显示/关闭 帧率 disableFpsMeter/enableFpsMeter
  111. // 要修改字体和颜色直接修改 FpsMeter 类即可
  112. camera2View.enableFpsMeter();
  113. // 设置视图宽高和模型一致减少resize操作,模型输入一般尺寸不大,这样相机渲染fps会更高
  114. camera2View.setMaxFrameSize(win_w,win_h);
  115. // 开启
  116. camera2View.enableView();
  117. }
  118. }
  119. break;
  120. default:
  121. super.onManagerConnected(status);
  122. break;
  123. }
  124. }
  125. };
  126. }

onnx 模型加载和推理代码:
使用的微软onnx推理框架:

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'

  1. package com.example.camera_opencv;
  2. import ai.onnxruntime.*;
  3. import android.content.res.AssetManager;
  4. import org.opencv.core.*;
  5. import org.opencv.dnn.Dnn;
  6. import org.opencv.imgproc.Imgproc;
  7. import java.io.ByteArrayOutputStream;
  8. import java.io.InputStream;
  9. import java.nio.FloatBuffer;
  10. import java.util.*;
  11. public class OnnxUtil {
  12. // onnxruntime 环境
  13. public static OrtEnvironment env;
  14. public static OrtSession session;
  15. // 模型输入
  16. public static int w = 0;
  17. public static int h = 0;
  18. public static int c = 3;
  19. // 标注颜色
  20. public static Scalar green = new Scalar(0, 255, 0);
  21. public static int tickness = 2;
  22. // 模型加载
  23. public static void loadModule(AssetManager assetManager){
  24. // 下面包含了多个模型
  25. // yolov5face-blazeface-640x640.onnx 3.4Mb
  26. // yolov5face-l-640x640.onnx 181Mb
  27. // yolov5face-m-640x640.onnx 83Mb
  28. // yolov5face-n-0.5-320x320.onnx 2.5Mb
  29. // yolov5face-n-0.5-640x640.onnx 4.6Mb
  30. // yolov5face-n-640x640.onnx 9.5Mb
  31. // yolov5face-s-640x640.onnx 30Mb
  32. w = 320;
  33. h = 320;
  34. c = 3;
  35. try {
  36. // 模型输入: input -> [1, 3, 320, 320] -> FLOAT
  37. // 模型输出: output -> [1, 6300, 16] -> FLOAT
  38. InputStream inputStream = assetManager.open("yolov5face-n-0.5-320x320.onnx");
  39. ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  40. int nRead;
  41. byte[] data = new byte[1024];
  42. while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
  43. buffer.write(data, 0, nRead);
  44. }
  45. buffer.flush();
  46. byte[] module = buffer.toByteArray();
  47. System.out.println("开始加载模型");
  48. env = OrtEnvironment.getEnvironment();
  49. session = env.createSession(module, new OrtSession.SessionOptions());
  50. session.getInputInfo().entrySet().stream().forEach(n -> {
  51. String inputName = n.getKey();
  52. NodeInfo inputInfo = n.getValue();
  53. long[] shape = ((TensorInfo) inputInfo.getInfo()).getShape();
  54. String javaType = ((TensorInfo) inputInfo.getInfo()).type.toString();
  55. System.out.println("模型输入: "+inputName + " -> " + Arrays.toString(shape) + " -> " + javaType);
  56. });
  57. session.getOutputInfo().entrySet().stream().forEach(n -> {
  58. String outputName = n.getKey();
  59. NodeInfo outputInfo = n.getValue();
  60. long[] shape = ((TensorInfo) outputInfo.getInfo()).getShape();
  61. String javaType = ((TensorInfo) outputInfo.getInfo()).type.toString();
  62. System.out.println("模型输出: "+outputName + " -> " + Arrays.toString(shape) + " -> " + javaType);
  63. });
  64. } catch (Exception e) {
  65. e.printStackTrace();
  66. }
  67. }
  68. // 模型推理,输入原始图片和识别区域两个点
  69. public static void inference(Mat mat,Point detection_p1,Point detection_p2){
  70. int px = Double.valueOf(detection_p1.x).intValue();
  71. int py = Double.valueOf(detection_p1.y).intValue();
  72. // 提取rgb(chw存储)并做归一化,也就是 rrrrr bbbbb ggggg
  73. float[] chw = new float[c*h*w];
  74. // 像素点索引
  75. int index = 0;
  76. for(int j=0 ; j<h ; j++){
  77. for(int i=0 ; i<w ; i++){
  78. // 第j行,第i列,根据识别区域p1得到xy坐标的偏移,直接加就行
  79. double[] rgb = mat.get(j+py,i+px);
  80. // 缓存到 chw 中,mat 是 rgba 数据对应的下标 2103
  81. chw[index] = (float)(rgb[2]/255);//r
  82. chw[index + w * h * 1 ] = (float)(rgb[1]/255);//G
  83. chw[index + w * h * 2 ] = (float)(rgb[0]/255);//b
  84. index ++;
  85. }
  86. }
  87. // 创建张量并进行推理
  88. try {
  89. OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{1,c,h,w});
  90. OrtSession.Result output = session.run(Collections.singletonMap("input", tensor));
  91. float[][] out = ((float[][][])(output.get(0)).getValue())[0];
  92. ArrayList<float[]> datas = new ArrayList<>();
  93. for(int i=0;i<out.length;i++){
  94. float[] data = out[i];
  95. float score1 = data[4]; // 边框置信度
  96. float score2 = data[15];// 人脸置信度
  97. if( score1 >= 0.2 && score2>= 0.2){
  98. // xywh 转 x1y1x2y2
  99. float xx = data[0];
  100. float yy = data[1];
  101. float ww = data[2];
  102. float hh = data[3];
  103. float[] xyxy = xywh2xyxy(new float[]{xx,yy,ww,hh},w,h);
  104. data[0] = xyxy[0];
  105. data[1] = xyxy[1];
  106. data[2] = xyxy[2];
  107. data[3] = xyxy[3];
  108. datas.add(data);
  109. }
  110. }
  111. // nms
  112. ArrayList<float[]> datas_after_nms = new ArrayList<>();
  113. while (!datas.isEmpty()){
  114. float[] max = datas.get(0);
  115. datas_after_nms.add(max);
  116. Iterator<float[]> it = datas.iterator();
  117. while (it.hasNext()) {
  118. // nsm阈值
  119. float[] obj = it.next();
  120. double iou = calculateIoU(max,obj);
  121. if (iou > 0.5f) {
  122. it.remove();
  123. }
  124. }
  125. }
  126. // 标注
  127. datas_after_nms.stream().forEach(n->{
  128. // x y w h score 中心点坐标和分数
  129. // x y 关键点坐标
  130. // x y 关键点坐标
  131. // x y 关键点坐标
  132. // x y 关键点坐标
  133. // x y 关键点坐标
  134. // cls_conf 人脸置信度
  135. // 画边框和关键点需要添加偏移
  136. int x1 = Float.valueOf(n[0]).intValue() + px;
  137. int y1 = Float.valueOf(n[1]).intValue() + py;
  138. int x2 = Float.valueOf(n[2]).intValue() + px;
  139. int y2 = Float.valueOf(n[3]).intValue() + py;
  140. Imgproc.rectangle(mat, new Point(x1, y1), new Point(x2, y2), green, tickness);
  141. float point1_x = Float.valueOf(n[5]).intValue() + px;// 关键点1
  142. float point1_y = Float.valueOf(n[6]).intValue() + py;//
  143. float point2_x = Float.valueOf(n[7]).intValue() + px;// 关键点2
  144. float point2_y = Float.valueOf(n[8]).intValue() + py;//
  145. float point3_x = Float.valueOf(n[9]).intValue() + px;// 关键点3
  146. float point3_y = Float.valueOf(n[10]).intValue() + py;//
  147. float point4_x = Float.valueOf(n[11]).intValue() + px;// 关键点4
  148. float point4_y = Float.valueOf(n[12]).intValue() + py;//
  149. float point5_x = Float.valueOf(n[13]).intValue() + px;// 关键点5
  150. float point5_y = Float.valueOf(n[14]).intValue() + py;//
  151. Imgproc.circle(mat, new Point(point1_x, point1_y), 1, green, tickness);
  152. Imgproc.circle(mat, new Point(point2_x, point2_y), 1, green, tickness);
  153. Imgproc.circle(mat, new Point(point3_x, point3_y), 1, green, tickness);
  154. Imgproc.circle(mat, new Point(point4_x, point4_y), 1, green, tickness);
  155. Imgproc.circle(mat, new Point(point5_x, point5_y), 1, green, tickness);
  156. });
  157. }
  158. catch (Exception e){
  159. e.printStackTrace();
  160. }
  161. }
  162. // 中心点坐标转 xin xmax ymin ymax
  163. public static float[] xywh2xyxy(float[] bbox,float maxWidth,float maxHeight) {
  164. // 中心点坐标
  165. float x = bbox[0];
  166. float y = bbox[1];
  167. float w = bbox[2];
  168. float h = bbox[3];
  169. // 计算
  170. float x1 = x - w * 0.5f;
  171. float y1 = y - h * 0.5f;
  172. float x2 = x + w * 0.5f;
  173. float y2 = y + h * 0.5f;
  174. // 限制在图片区域内
  175. return new float[]{
  176. x1 < 0 ? 0 : x1,
  177. y1 < 0 ? 0 : y1,
  178. x2 > maxWidth ? maxWidth:x2,
  179. y2 > maxHeight? maxHeight:y2};
  180. }
  181. // 计算两个框的交并比
  182. private static double calculateIoU(float[] box1, float[] box2) {
  183. // getXYXY() 返回 xmin-0 ymin-1 xmax-2 ymax-3
  184. double x1 = Math.max(box1[0], box2[0]);
  185. double y1 = Math.max(box1[1], box2[1]);
  186. double x2 = Math.min(box1[2], box2[2]);
  187. double y2 = Math.min(box1[3], box2[3]);
  188. double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
  189. double box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1);
  190. double box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1);
  191. double unionArea = box1Area + box2Area - intersectionArea;
  192. return intersectionArea / unionArea;
  193. }
  194. }

项目详细代码:

https://github.com/TangYuFan/deeplearn-mobile

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/747120
推荐阅读
相关标签
  

闽ICP备14008679号