当前位置:   article > 正文

Pytorch-Mobile-Android(3) 部署自己的模型_pytorch mobile

pytorch mobile

目录

一、例子:

1.用torch.jit.script转torchscript,不要用torch.jit.trace

2.用PIL将图像的宽(width)和高(height)调整为模型要求的输入尺寸(384×384)

3.套用pytorch-mobile官网的代码运行即可

4.Lite version update


一、例子:

1.用torch.jit.script转torchscript,不要用torch.jit.trace

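理由见:【Pytorch部署】TorchScript - 知乎 (zhihu.com):https://zhuanlan.zhihu.com/p/135911580 。简单来说,torch.jit.trace 只记录示例输入实际走过的那一条执行路径,模型中依赖数据的控制流会被固化;torch.jit.script 则直接编译模型的 Python 代码,能完整保留这些逻辑。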

  1. import vision_transformer
  2. from torch.utils.mobile_optimizer import optimize_for_mobile
  3. import torch
  4. model_vit = vision_transformer._create_vision_transformer('vit_tiny_patch16_384')
  5. model_vit = model_vit.eval()
  6. example = torch.rand(1, 3, 384, 384)
  7. traced_script_module = torch.jit.script(model_vit, example)
  8. traced_script_module_optimized = optimize_for_mobile(traced_script_module)
  9. traced_script_module_optimized._save_for_lite_interpreter(r"D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\vit2.pt")

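运行时会出现警告(不是报错):UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead
  warnings.warn(

原因是 torch.jit.script 的第二个位置参数是已弃用的 optimize,而不是示例输入;上面把 example 作为第二个参数传进去,就触发了这条警告。script 本身不需要示例输入,写成 torch.jit.script(model_vit) 即可消除警告,导出结果不受影响。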

2.用PIL将图像的宽(width)和高(height)调整为模型要求的输入尺寸(384×384)

  1. from PIL import Image
  2. img = Image.open(r'D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\image.jpg')
  3. # img = img.resize((384, 384), Image.BILINEAR)
  4. # img.save(r'D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\image.jpg')
  5. print(img.size)

3.套用pytorch-mobile官网的代码运行即可

  1. package org.pytorch.helloworld;
  2. import android.content.Context;
  3. import android.graphics.Bitmap;
  4. import android.graphics.BitmapFactory;
  5. import android.os.Bundle;
  6. import android.util.Log;
  7. import android.widget.ImageView;
  8. import android.widget.TextView;
  9. import org.pytorch.IValue;
  10. import org.pytorch.LiteModuleLoader;
  11. import org.pytorch.Module;
  12. import org.pytorch.Tensor;
  13. import org.pytorch.torchvision.TensorImageUtils;
  14. import org.pytorch.MemoryFormat;
  15. import java.io.File;
  16. import java.io.FileOutputStream;
  17. import java.io.IOException;
  18. import java.io.InputStream;
  19. import java.io.OutputStream;
  20. import androidx.appcompat.app.AppCompatActivity;
  21. public class MainActivity extends AppCompatActivity {
  22. @Override
  23. protected void onCreate(Bundle savedInstanceState) {
  24. super.onCreate(savedInstanceState);
  25. setContentView(R.layout.activity_main);
  26. Bitmap bitmap = null;
  27. Module module = null;
  28. try {
  29. // creating bitmap from packaged into app android asset 'image.jpg',
  30. // app/src/main/assets/image.jpg
  31. bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
  32. int width = bitmap.getWidth();
  33. int height = bitmap.getHeight();
  34. Log.e("width", String.format("width %d ", width)); //总时间
  35. Log.e("height", String.format("height %d ", height));
  36. // loading serialized torchscript module from packaged into app android asset model.pt,
  37. // app/src/model/assets/model.pt
  38. module = LiteModuleLoader.load(assetFilePath(this, "vit2.pt"));
  39. } catch (IOException e) {
  40. Log.e("PytorchHelloWorld", "Error reading assets", e);
  41. finish();
  42. }
  43. // showing image on UI
  44. ImageView imageView = findViewById(R.id.image);
  45. imageView.setImageBitmap(bitmap);
  46. // preparing input tensor
  47. final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
  48. TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
  49. // running the model
  50. final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
  51. // getting tensor content as java array of floats
  52. final float[] scores = outputTensor.getDataAsFloatArray();
  53. // searching for the index with maximum score
  54. float maxScore = -Float.MAX_VALUE;
  55. int maxScoreIdx = -1;
  56. for (int i = 0; i < scores.length; i++) {
  57. if (scores[i] > maxScore) {
  58. maxScore = scores[i];
  59. maxScoreIdx = i;
  60. }
  61. }
  62. String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
  63. TextView textView = findViewById(R.id.text);
  64. textView.setText(className);
  65. }
  66. }
  67. /**
  68. * Copies specified asset to the file in /files app directory and returns this file absolute path.
  69. *
  70. * @return absolute file path
  71. */
  72. public static String assetFilePath(Context context, String assetName) throws IOException {
  73. File file = new File(context.getFilesDir(), assetName);
  74. if (file.exists() && file.length() > 0) {
  75. return file.getAbsolutePath();
  76. }
  77. try (InputStream is = context.getAssets().open(assetName)) {
  78. try (OutputStream os = new FileOutputStream(file)) {
  79. byte[] buffer = new byte[4 * 1024];
  80. int read;
  81. while ((read = is.read(buffer)) != -1) {
  82. os.write(buffer, 0, read);
  83. }
  84. os.flush();
  85. }
  86. return file.getAbsolutePath();
  87. }
  88. }
  89. }

4.Lite version update

随着 PyTorch 版本更新,官网示例 build.gradle 中 dependencies 引用的 pytorch_android_lite 版本偏低,加载用较新 PyTorch 导出的模型时会报错(新版导出的模型格式,旧版运行时可能无法识别)。此时打开 build.gradle,把鼠标悬停在 dependencies 下引用的依赖上,Android Studio 会提示可用的新版本,按提示升级即可。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/人工智能uu/article/detail/767363
推荐阅读
相关标签
  

闽ICP备14008679号