当前位置:   article > 正文

在微信小程序部署AI模型的几种方法_小程序 onnx

小程序 onnx

前言

本文只是分享思路,不提供可完整运行的项目代码

onnx部署

以目标检测类模型为例,该类模型会输出类别信息置信度包含检测框的4个坐标信息

但不是所有的onnx模型都能在微信小程序部署,有些算子不支持,这种情况需要点特殊操作,我暂时不能解决。

微信小程序提供的接口相当于使用onnxruntime的接口运行onnx模型,我们要做的就是将视频帧数据(包含RGBA的一维像素数组)转换成对应形状的数组(比如3*224*224的一维Float32Array)然后调用接口并将图像输入得到运行的结果(比如一个1*10*6的一维Float32Array,代表着10个预测框的类别,置信度和框的4个坐标)然后将结果处理(比如行人检测,给置信度设置一个阈值0.5,筛选置信度大于阈值的数组的index,然后按照index取出相应的类别和框坐标)最后在wxml中显示类别名或置信度或在canvas绘制框。

代码框架

这里采用的是实时帧数据,按预设频率调用一帧数据并后处理得到结果

onLoad主体

  1. onLoad(){
  2. // 创建相机上下文
  3. const context = wx.createCameraContext();
  4. // 定义实时帧回调函数
  5. this.listener=context.onCameraFrame((frame)=>this.CamFramCall(frame));
  6. // 初始化session
  7. this.initSession()
  8. },

相机实时帧回调函数 

得到的实时帧数据因系统而异,并不固定(这对后面画追踪框的时候不利)

我的处理方法是把帧数据和<camera>组件的长宽比例统一,这样得到坐标后再乘以一个比例系数即可映射到<camera>因为输进去模型的是帧数据,所以返回的追踪框坐标是基于帧数据的,直接画在<camera>上的canvas有可能出现框的位置有偏差

回调函数里的逻辑是设置<camera>的长(我把宽定死到手机屏幕长度的0.9)预处理图片数据进行推理关闭监听(至此完成一帧)

  1. CamFramCall(frame){
  2. // 根据实时帧的图片长宽比例设置<camera>组件展示大小
  3. this.setData({
  4. windowHeight:frame.height/frame.width*wx.getSystemInfoSync().windowWidth*0.9
  5. })
  6. var dstInput=new Float32Array(3*this.data.imgH*this.data.imgW).fill(255)
  7. // 调用图片预处理函数对实时帧数据进行处理
  8. this.preProcess(frame,dstInput)
  9. // 将处理完的数据进行推理得到结果
  10. this.infer(dstInput)
  11. console.log('完成一次帧循环')
  12. // 关闭监听
  13. this.listener.stop()
  14. },

初始化session

首先得将onnx上传至云端,获得一个存储路径(比如cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/rtdetrWorker.onnx

当用户首次使用该小程序时,手机里没有onnx模型的存储,需要从云端下载;而已经非第一次使用该小程序的用户手机里已经保存了之前下载的onnx模型,就无需下载。所以此处代码逻辑是需要检测用户的存储里是否有该onnx模型,不存在就下载,下载完并保存模型文件后就执行下一步;存在就直接执行下一步

  1. initSession(){
  2. // onnx云端下载路径
  3. const cloudPath='cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/best.onnx'
  4. const lastIndex=cloudPath.lastIndexOf('/')
  5. const filename=cloudPath.substring(lastIndex+1)
  6. const modelPath=`${wx.env.USER_DATA_PATH}/`+filename
  7. // 检测onnx文件是否存在
  8. wx.getFileSystemManager().access({
  9. path:modelPath,
  10. // 如果存在就创建session,定时开启监听实时帧
  11. success:(res)=>{
  12. console.log('file already exist')
  13. this.createInferenceSession(modelPath)
  14. setInterval(()=>{this.listener.start()},1000)
  15. },
  16. // 如果不存在
  17. fail:(res)=>{
  18. console.error(res)
  19. wx.cloud.init()
  20. console.log('begin download model')
  21. // 下载提示框
  22. wx.showLoading({
  23. title: '加载检测中',
  24. })
  25. // 调用自定义的下载文件函数
  26. this.downloadFile(cloudPath,function(r) {
  27. console.log(`下载进度:${r.progress}%,已下载${r.totalBytesWritten}B,共${r.totalBytesExpectedToWrite}B`)
  28. }).then(result=>{
  29. // 下载文件成功后保存
  30. wx.getFileSystemManager().saveFile({
  31. tempFilePath:result.tempFilePath,
  32. filePath:modelPath,
  33. // 保存文件成功后创建session,定时开启监听实时帧
  34. success:(res)=>{
  35. const modelPath=res.savedFilePath
  36. console.log('save onnx model at path:'+modelPath)
  37. this.createInferenceSession(modelPath)
  38. // 关闭下载提示框
  39. wx.hideLoading()
  40. setInterval(()=>{this.listener.start()},1000)
  41. },
  42. fail:(res)=>{
  43. console.error(res)
  44. }
  45. })
  46. })
  47. }
  48. })
  49. },

自定义的下载文件函数

  1. downloadFile(fileID, onCall = () => {}) {
  2. return new Promise((resolve) => {
  3. const task = wx.cloud.downloadFile({
  4. fileID,
  5. success: res => resolve(res),
  6. })
  7. task.onProgressUpdate((res) => {
  8. if (onCall(res) == false) {
  9. task.abort()
  10. }
  11. })
  12. })
  13. },

自定义创建session的函数

  1. createInferenceSession(modelPath) {
  2. return new Promise((resolve, reject) => {
  3. this.session = wx.createInferenceSession({
  4. model: modelPath,
  5. precisionLevel : 4,
  6. allowNPU : false,
  7. allowQuantize: false,
  8. });
  9. // 监听error事件
  10. this.session.onError((error) => {
  11. console.error(error);
  12. reject(error);
  13. });
  14. this.session.onLoad(() => {
  15. resolve();
  16. });
  17. })
  18. },

自定义的图像预处理函数

该函数接收帧数据(RGBA一维数组)和在外面初始化的Float32Array数组,执行归一化、去除透明度通道。

  1. preProcess(frame,dstInput){
  2. return new Promise(resolve=>{
  3. const origData = new Uint8Array(frame.data);
  4. for(var j=0;j<frame.height;j++){
  5. for(var i=0;i<frame.width;i++){
  6. dstInput[i*3+this.data.imgW*j*3]=origData[i*4+j*frame.width*4]/255
  7. dstInput[i*3+1+this.data.imgW*j*3]=origData[i*4+1+j*frame.width*4]/255
  8. dstInput[i*3+2+this.data.imgW*j*3]=origData[i*4+2+j*frame.width*4]/255
  9. }
  10. }
  11. resolve();
  12. })
  13. },

自定义的推理函数

推理接口接收数个键值对input,具体需要参照自己的onnx模型,在Netron查看相应的模型信息

我这里只有1个输入,对应的名字为"images",接收(1,3,640,640)形状的图像数组

我这里的onnx输出数组是1*6*10的,代表有10个检测框,还有4个坐标信息+类别编号+置信度。我的输出的数组名字叫 output0,注意参照自己的onnx输出名

接着就是获取最大置信度所在索引并按照索引取出其对应框的信息和类别编号

然后绘制在canvas上

为了在没有检测到物体时不绘制出框,检测到物体时绘制检测框,就先获取<canvas>对象,清空画布,再对session输出的数据进行后处理,然后给个阈值判断是否画框。

  1. infer(imgData){
  2. this.session.run({
  3. 'images':{
  4. shape:[1,3,this.data.imgH,this.data.imgW],
  5. data:imgData.buffer,
  6. type:'float32',
  7. }
  8. // 获得运行结果后
  9. }).then((res)=>{
  10. let results=new Float32Array(res.output0.data)
  11. // 获取canvas对象,填上id,这里对应”c1“
  12. wx.createSelectorQuery().select('#c1')
  13. .fields({node:true,size:true})
  14. .exec((res)=>{
  15. const canvas=res[0].node
  16. const ctx=canvas.getContext('2d')
  17. canvas.width=wx.getSystemInfoSync().windowWidth*0.9
  18. canvas.height=this.data.windowHeight
  19. // 对session数据进行后处理
  20. this.postProcess(results).then((index)=>{
  21. // 清空画布
  22. ctx.clearRect(0,0,canvas.width,canvas.height)
  23. // 大于阈值,就认为检测到物体
  24. if(this.data.conf>0.5){
  25. this.setData({
  26. class_name:'检测到苹果'
  27. })
  28. // 这里需要参考自己的session输出的数组上对应位置的具体含义
  29. // 比如我的session输出1*6*10的一维数组,可以看作6*10的二维数组,
  30. // 有6行数据,第一行对应中心点x坐标,第二行对应中心点y坐标,
  31. // 第3行对应检测框的w宽度,第4行对应检测框的h长度,
  32. // 第5行对应置信度,第6行对应类别编号
  33. var x=results[index]
  34. var y=results[10+index]
  35. var w=results[2*10+index]
  36. var h=results[3*10+index]
  37. var x1=Math.round(x-w/2)
  38. var y1=Math.round(y-h/2)
  39. var x2=Math.round(x+w/2)
  40. var y2=Math.round(y+h/2)
  41. ctx.strokeStyle='red'
  42. ctx.lineWidth=2
  43. ctx.strokeRect(x1,y1,x2,y2)
  44. }
  45. })
  46. })
  47. })
  48. },

 自定义的后处理函数

初始化置信度和index,对10个检测框进行遍历,取出置信度最大元素所在index,然后更新到全局变量中,这里设定阈值为0.5. 此函数接收session输出的数组,返回index

  1. postProcess(results){
  2. return new Promise((resolve)=>{
  3. var maxConf=results[10*4]
  4. var index=0
  5. for(var i=1;i<10;i+=1){
  6. var conf=results[10*4+i]
  7. if(conf>0.5 & maxConf<conf){
  8. maxConf=conf
  9. index=i
  10. }
  11. }
  12. this.setData({
  13. conf:maxConf,
  14. class_name:'未检测出苹果'
  15. })
  16. resolve(index)
  17. })
  18. },

代码总览

index.js

  1. Page({
  2. data: {
  3. imagePath: '/images/tree.png',
  4. windowHeight:wx.getSystemInfoSync().windowWidth*1.197,
  5. imgH:640,
  6. imgW:640,
  7. conf:0,
  8. class_name:'未检测到红火蚁',
  9. },
  10. onLoad(){
  11. const context = wx.createCameraContext();
  12. this.listener=context.onCameraFrame((frame)=>this.CamFramCall(frame));
  13. this.initSession()
  14. },
  15. initSession(){
  16. const cloudPath='cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/best.onnx'
  17. const lastIndex=cloudPath.lastIndexOf('/')
  18. const filename=cloudPath.substring(lastIndex+1)
  19. const modelPath=`${wx.env.USER_DATA_PATH}/`+filename
  20. wx.getFileSystemManager().access({
  21. path:modelPath,
  22. success:(res)=>{
  23. console.log('file already exist')
  24. this.createInferenceSession(modelPath)
  25. setInterval(()=>{this.listener.start()},1000)
  26. },
  27. fail:(res)=>{
  28. console.error(res)
  29. wx.cloud.init()
  30. console.log('begin download model')
  31. wx.showLoading({
  32. title: '加载检测中',
  33. })
  34. this.downloadFile(cloudPath,function(r) {
  35. console.log(`下载进度:${r.progress}%,已下载${r.totalBytesWritten}B,共${r.totalBytesExpectedToWrite}B`)
  36. }).then(result=>{
  37. wx.getFileSystemManager().saveFile({
  38. tempFilePath:result.tempFilePath,
  39. filePath:modelPath,
  40. success:(res)=>{
  41. const modelPath=res.savedFilePath
  42. console.log('save onnx model at path:'+modelPath)
  43. this.createInferenceSession(modelPath)
  44. wx.hideLoading()
  45. setInterval(()=>{this.listener.start()},1000)
  46. },
  47. fail:(res)=>{
  48. console.error(res)
  49. }
  50. })
  51. })
  52. }
  53. })
  54. },
  55. createInferenceSession(modelPath){
  56. return new Promise((resolve,reject)=>{
  57. this.session=wx.createInferenceSession({
  58. model: modelPath,
  59. precesionLevel:4,
  60. allowNPU:false,
  61. allowQuantize:false,
  62. })
  63. this.session.onError((error) => {
  64. console.error(error)
  65. reject(error)
  66. })
  67. this.session.onLoad(()=>{
  68. resolve()
  69. })
  70. })
  71. },
  72. CamFramCall(frame){
  73. this.setData({
  74. windowHeight:frame.height/frame.width*wx.getSystemInfoSync().windowWidth*0.9
  75. })
  76. var dstInput=new Float32Array(3*this.data.imgH*this.data.imgW).fill(255)
  77. this.preProcess(frame,dstInput)
  78. this.infer(dstInput)
  79. console.log('完成一次帧循环')
  80. this.listener.stop()
  81. },
  82. preProcess(frame,dstInput){
  83. return new Promise(resolve=>{
  84. const origData = new Uint8Array(frame.data);
  85. for(var j=0;j<frame.height;j++){
  86. for(var i=0;i<frame.width;i++){
  87. dstInput[i*3+this.data.imgW*j*3]=origData[i*4+j*frame.width*4]
  88. dstInput[i*3+1+this.data.imgW*j*3]=origData[i*4+1+j*frame.width*4]
  89. dstInput[i*3+2+this.data.imgW*j*3]=origData[i*4+2+j*frame.width*4]
  90. }
  91. }
  92. resolve();
  93. })
  94. },
  95. postProcess(results){
  96. return new Promise((resolve)=>{
  97. var maxConf=results[10*4]
  98. var index=0
  99. for(var i=1;i<10;i+=1){
  100. var conf=results[10*4+i]
  101. if(conf>0.5 & maxConf<conf){
  102. maxConf=conf
  103. index=i
  104. }
  105. }
  106. this.setData({
  107. conf:maxConf,
  108. class_name:'未检测到红火蚁'
  109. })
  110. resolve(index)
  111. })
  112. },
  113. infer(imgData){
  114. this.session.run({
  115. 'images':{
  116. shape:[1,3,this.data.imgH,this.data.imgW],
  117. data:imgData.buffer,
  118. type:'float32',
  119. }
  120. }).then((res)=>{
  121. let results=new Float32Array(res.output0.data)
  122. wx.createSelectorQuery().select('#c1')
  123. .fields({node:true,size:true})
  124. .exec((res)=>{
  125. const canvas=res[0].node
  126. const ctx=canvas.getContext('2d')
  127. canvas.width=wx.getSystemInfoSync().windowWidth*0.9
  128. canvas.height=this.data.windowHeight
  129. this.postProcess(results).then((index)=>{
  130. ctx.clearRect(0,0,canvas.width,canvas.height)
  131. if(this.data.conf>0.5){
  132. this.setData({
  133. class_name:'检测到红火蚁'
  134. })
  135. var x=results[index]
  136. var y=results[8400+index]
  137. var w=results[2*8400+index]
  138. var h=results[3*8400+index]
  139. var x1=Math.round(x-w/2)
  140. var y1=Math.round(y-h/2)
  141. var x2=Math.round(x+w/2)
  142. var y2=Math.round(y+h/2)
  143. ctx.strokeStyle='red'
  144. ctx.lineWidth=2
  145. ctx.strokeRect(x1,y1,x2,y2)
  146. }
  147. })
  148. })
  149. })
  150. },
  151. downloadFile(fileID,onCall=()=>{}){
  152. return new Promise((resolve)=>{
  153. const task=wx.cloud.downloadFile({
  154. fileID,
  155. success:res=>resolve(res),
  156. })
  157. task.onProgressUpdate((res)=>{
  158. if(onCall(res)==false){
  159. task.abort()
  160. }
  161. })
  162. })
  163. },
  164. })

 index.wxss

  1. .c1{
  2. width: 100%;
  3. align-items: center;
  4. text-align: center;
  5. display: flex;
  6. flex-direction: column;
  7. }
  8. #myCanvas{
  9. width: 100%;
  10. height: 100%;
  11. }

index.wxml

  1. <view class="c1">
  2. <camera class="camera" binderror="error" mode="normal" style="width: 90%; height: {{windowHeight}}px;">
  3. <canvas id="c1" type="2d"></canvas>
  4. </camera>
  5. <view>结果:{{class_name}}</view>
  6. <view>置信度:{{conf}}</view>
  7. </view>

flask部署

微信小程序负责把图像数据或帧数据传到服务器,在服务器用flask搭建相关模型运行环境,将接收到的图像数据或帧数据预处理后输入模型里,在将结果返回给微信小程序,微信小程序再显示结果。

我这里给的例子是传送帧数据的,也就是实时检测。

但是目前存在一个问题,速度,检测框的速度跟不上物体移动,只能慢速检测,一旦提高频率小程序就抓不到实时帧。

前端

在前端,获得帧数据后,因为帧数据的格式是一维RGBA数组,为了将其转成png,方便服务器处理,把帧数据绘制到画布上,再导出为png送入服务器。接收到服务器的结果后,将检测框绘制到相机的界面,需要在<camera>标签里加上<canvas>标签,然后画上矩形框,并在下方显示分类结果。

主体代码框架

初始化页面数据,camH是<camera>组件在展示页面的高度,k是比例系数

每个系统运行小程序,所导出的frame大小是不同的,为了更好的画检测框:首先模型接收的是frame,其运行的结果的检测框坐标数据是基于frame的。而<camera>组件的展示大小也是要设定的,我把<camera>的宽度定死在整个页面宽度的0.9(在wxss中设定),然后使<camera>与frame成比例(就只需要设定camH,我这里给了一个初始值1.2,后面的程序会更精确的调),<camera>与frame的比例系数为k,再让camera中的画布<canvas>完全贴合于其父元素<camera>,只要把模型跑出的坐标乘以比例系数k即可映射到<camera>上。

  1. onLoad(){
  2. // 执行自定义的初始化函数
  3. this.init().then(()=>{
  4. // 创建相机上下文
  5. const context = wx.createCameraContext();
  6. // 设定监听回调函数
  7. this.listener=context.onCameraFrame(frame=>this.CamFramCall(frame));
  8. // 每500ms开启一次监听
  9. setInterval(()=>{this.listener.start()}, 500);
  10. })
  11. },

自定义的初始化函数

 为了设定<camera>的高和比例系数,需要知道frame的尺寸,所以这里调用了一次相机帧。而后面调用相机帧是为了获取帧数据,两者目的不同。

  1. init(){
  2. return new Promise(resolve=>{
  3. const context = wx.createCameraContext();
  4. const listener=context.onCameraFrame(frame=>{
  5. this.setData({
  6. camH:wx.getSystemInfoSync().windowWidth*0.9*frame.height/frame.width,
  7. k:wx.getSystemInfoSync().windowWidth*0.9/frame.width
  8. })
  9. listener.stop()
  10. })
  11. listener.start()
  12. resolve()
  13. })
  14. },

 实时帧回调函数

在回调函数里,将接收到的数据转base64,然后将数据传到服务器,最后停止监听,至此完成一帧

  1. CamFramCall(frame){
  2. this.base64ToPNG(frame).then(result=>{
  3. this.interWithServer({'img':result})
  4. console.log('完成一次帧循环')
  5. this.listener.stop()
  6. })
  7. },

自定义帧数据转base64的函数

参考http://t.csdnimg.cn/2hc7k

这里增加了异步编程的语句,更合理

  1. base64ToPNG(frame){
  2. return new Promise(resolve=>{
  3. const query = wx.createSelectorQuery()
  4. query.select('#canvas')
  5. .fields({node:true,size:true})
  6. .exec((res)=>{
  7. const canvas=res[0].node
  8. const ctx=canvas.getContext('2d')
  9. canvas.width=frame.width
  10. canvas.height=frame.height
  11. var imageData=ctx.createImageData(canvas.width,canvas.height)
  12. var ImgU8Array = new Uint8ClampedArray(frame.data);
  13. for(var i=0;i<ImgU8Array.length;i+=4){
  14. imageData.data[0+i]=ImgU8Array[i+0]
  15. imageData.data[1+i]=ImgU8Array[i+1]
  16. imageData.data[2+i]=ImgU8Array[i+2]
  17. imageData.data[3+i]=ImgU8Array[i+3]
  18. }
  19. ctx.putImageData(imageData,0,0,0,0,canvas.width,canvas.height)
  20. resolve(canvas.toDataURL())
  21. })
  22. })
  23. },

自定义传数据到服务器函数 

  1. interWithServer(imgData){
  2. const header = {
  3. 'content-type': 'application/x-www-form-urlencoded'
  4. };
  5. wx.request({
  6. // 填上自己的服务器地址
  7. url: 'http://172.16.3.186:5000/predict',
  8. method: 'POST',
  9. header: header,
  10. data: imgData,
  11. success: (res) => {
  12. // 返回的坐标数据,调用自定义的画检测框函数
  13. this.drawRect(res.data['conf'],res.data['x'],res.data['y'],res.data['w'],res.data['h'])
  14. },
  15. fail: () => {
  16. wx.showToast({
  17. title: 'Failed to connect server!',
  18. icon: 'none',
  19. });
  20. }
  21. });
  22. },

自定义的画检测框函数 

  1. drawRect(conf,x,y,w,h){
  2. // 填上<camera>内<canvas>的id
  3. wx.createSelectorQuery().select('#c1')
  4. .fields({node:true,size:true})
  5. .exec((res)=>{
  6. const canvas=res[0].node
  7. const ctx=canvas.getContext('2d')
  8. // 设置宽高,完全填充于<camera>组件的大小
  9. canvas.width=wx.getSystemInfoSync().windowWidth*0.9
  10. canvas.height=this.data.camH
  11. // 清空画布,避免遗留上次的检测框
  12. ctx.clearRect(0,0,canvas.width,canvas.height)
  13. // 如果置信度大于0.5,才画框
  14. if(conf>0.5){
  15. ctx.strokeStyle='red'
  16. ctx.lineWidth=2
  17. const k =this.data.k
  18. // 经过真机测试,发现在x和y上乘以比例系数即可,较为精确
  19. // 虽然理论上要按比例计算,但可以根据实际的情况做出一点调整,对检测框进行修正
  20. ctx.strokeRect(k*x,k*y,x+w,y+h)
  21. }
  22. })
  23. },

index.js

  1. Page({
  2. data: {
  3. camH:wx.getSystemInfoSync().windowWidth*1.2,
  4. k:1
  5. },
  6. onLoad(){
  7. this.init().then(()=>{
  8. const context = wx.createCameraContext();
  9. this.listener=context.onCameraFrame(frame=>this.CamFramCall(frame));
  10. setInterval(()=>{this.listener.start()}, 500);
  11. })
  12. },
  13. init(){
  14. return new Promise(resolve=>{
  15. const context = wx.createCameraContext();
  16. const listener=context.onCameraFrame(frame=>{
  17. this.setData({
  18. camH:wx.getSystemInfoSync().windowWidth*0.9*frame.height/frame.width,
  19. k:wx.getSystemInfoSync().windowWidth*0.9/frame.width
  20. })
  21. listener.stop()
  22. })
  23. listener.start()
  24. resolve()
  25. })
  26. },
  27. CamFramCall(frame){
  28. this.base64ToPNG(frame).then(result=>{
  29. this.interWithServer({'img':result})
  30. console.log('完成一次帧循环')
  31. this.listener.stop()
  32. })
  33. },
  34. drawRect(conf,x,y,w,h){
  35. wx.createSelectorQuery().select('#c1')
  36. .fields({node:true,size:true})
  37. .exec((res)=>{
  38. const canvas=res[0].node
  39. const ctx=canvas.getContext('2d')
  40. canvas.width=wx.getSystemInfoSync().windowWidth*0.9
  41. canvas.height=this.data.camH
  42. ctx.clearRect(0,0,canvas.width,canvas.height)
  43. if(conf>0.5){
  44. ctx.strokeStyle='red'
  45. ctx.lineWidth=2
  46. const k =this.data.k
  47. ctx.strokeRect(k*x,k*y,x+w,y+h)
  48. }
  49. })
  50. },
  51. interWithServer(imgData){
  52. const header = {
  53. 'content-type': 'application/x-www-form-urlencoded'
  54. };
  55. wx.request({
  56. url: 'http://172.16.3.186:5000/predict',
  57. method: 'POST',
  58. header: header,
  59. data: imgData,
  60. success: (res) => {
  61. this.drawRect(res.data['conf'],res.data['x'],res.data['y'],res.data['w'],res.data['h'])
  62. },
  63. fail: () => {
  64. wx.showToast({
  65. title: 'Failed to connect server!',
  66. icon: 'none',
  67. });
  68. }
  69. });
  70. },
  71. base64ToPNG(frame){
  72. return new Promise(resolve=>{
  73. const query = wx.createSelectorQuery()
  74. query.select('#tranPng')
  75. .fields({node:true,size:true})
  76. .exec((res)=>{
  77. const canvas=res[0].node
  78. const ctx=canvas.getContext('2d')
  79. canvas.width=frame.width
  80. canvas.height=frame.height
  81. var imageData=ctx.createImageData(canvas.width,canvas.height)
  82. var ImgU8Array = new Uint8ClampedArray(frame.data);
  83. for(var i=0;i<ImgU8Array.length;i+=4){
  84. imageData.data[0+i]=ImgU8Array[i+0]
  85. imageData.data[1+i]=ImgU8Array[i+1]
  86. imageData.data[2+i]=ImgU8Array[i+2]
  87. imageData.data[3+i]=ImgU8Array[i+3]
  88. }
  89. ctx.putImageData(imageData,0,0,0,0,canvas.width,canvas.height)
  90. resolve(canvas.toDataURL())
  91. })
  92. })
  93. },
  94. })

 index.wxml

注意,<camera>中的<canvas>是为了画检测框,另一个<canvas>是为了将frame数据转base64Png。

  1. <view class="c1">
  2. <camera class="camera" binderror="error" style="width: 90%; height: {{camH}}px;">
  3. <canvas id="c1" type="2d"></canvas>
  4. </camera>
  5. <canvas id="tranPng" hidden="true" type="2d"></canvas>
  6. </view>

index.wxss

  1. .c1{
  2. width: 100%;
  3. align-items: center;
  4. text-align: center;
  5. display: flex;
  6. flex-direction: column;
  7. }
  8. #c1{
  9. width: 100%;
  10. height: 100%;
  11. }
  12. #canvas{
  13. width: 100%;
  14. }

后端

接收数据,预处理图像,送入模型,得到初始结果,转化初始结果得到最终结果,返回数据到前端

这里仅作演示,不提供完整项目运行代码和依赖项

  1. from PIL import Image
  2. from gevent import monkey
  3. from flask import Flask, jsonify, request
  4. from gevent.pywsgi import WSGIServer
  5. import cv2
  6. import paddle
  7. import numpy as np
  8. from ppdet.core.workspace import load_config
  9. from ppdet.engine import Trainer
  10. from ppdet.metrics import get_infer_results
  11. from ppdet.data.transform.operators import NormalizeImage, Permute
  12. import base64
  13. import io
  14. app = Flask(__name__)
  15. monkey.patch_all()
  16. # 准备基础的参数
  17. config_path = 'face_detection\\blazeface_1000e.yml'
  18. cfg = load_config(config_path)
  19. weight_path = '202.pdparams'
  20. infer_img_path = '1.png'
  21. cfg.weights = weight_path
  22. bbox_thre = 0.8
  23. paddle.set_device('cpu')
  24. # 创建所需的类
  25. trainer = Trainer(cfg, mode='test')
  26. trainer.load_weights(cfg.weights)
  27. trainer.model.eval()
  28. normaler = NormalizeImage(mean=[123, 117, 104], std=[127.502231, 127.502231, 127.502231], is_scale=False)
  29. permuter = Permute()
  30. model_dir = "face_detection\\blazeface_1000e.yml" # 模型路径
  31. save_path = "output" # 推理结果保存路径
  32. def infer(img, threshold=0.2):
  33. img = img.replace("data:image/png;base64,", "")
  34. img = base64.b64decode(img)
  35. img = Image.open(io.BytesIO(img))
  36. img = img.convert('RGB')
  37. img = np.array(img)
  38. # 准备数据字典
  39. data_dict = {'image': img}
  40. data_dict = normaler(data_dict)
  41. data_dict = permuter(data_dict)
  42. h, w, c = img.shape
  43. data_dict['im_id'] = paddle.Tensor(np.array([[0]]))
  44. data_dict['im_shape'] = paddle.Tensor(np.array([[h, w]], dtype=np.float32))
  45. data_dict['scale_factor'] = paddle.Tensor(np.array([[1., 1.]], dtype=np.float32))
  46. data_dict['image'] = paddle.Tensor(data_dict['image'].reshape((1, c, h, w)))
  47. data_dict['curr_iter'] = paddle.Tensor(np.array([0]))
  48. # 进行预测
  49. outs = trainer.model(data_dict)
  50. # 对预测的数据进行后处理得到最终的bbox信息
  51. for key in ['im_shape', 'scale_factor', 'im_id']:
  52. outs[key] = data_dict[key]
  53. for key, value in outs.items():
  54. outs[key] = value.numpy()
  55. clsid2catid, catid2name = {0: 'face'}, {0: 0}
  56. batch_res = get_infer_results(outs, clsid2catid)
  57. for sub_dict in batch_res['bbox']:
  58. if sub_dict['score'] > bbox_thre:
  59. image_id=sub_dict['image_id']
  60. category_id=sub_dict['category_id']
  61. x,y,w,h=[int(i) for i in sub_dict['bbox']]
  62. conf=sub_dict['score']
  63. print(x,y,w,h,conf)
  64. return jsonify({'conf':conf,'x':x,'y':y,'w':w,'h':h})
  65. else:
  66. return jsonify({'conf':0,'x':0,'y':0,'w':0,'h':0})
  67. @app.route('/predict', methods=['POST'])
  68. def predict():
  69. if request.method == 'POST':
  70. img = request.form.get('img')
  71. w=request.form.get('w')
  72. h=request.form.get('h')
  73. return infer(img)
  74. if __name__ == '__main__':
  75. server = WSGIServer(('0.0.0.0', 5000), app)
  76. server.serve_forever()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/空白诗007/article/detail/1002835
推荐阅读
  

闽ICP备14008679号