赞
踩
目录
1. 用 `torch.jit.script` 导出 TorchScript,不要用 `torch.jit.trace`
(`trace` 只记录示例输入实际走过的那一条执行路径,模型里依赖数据的控制流会被固化;`script` 则直接编译 Python 源码,保留控制流。)详细理由见:[【Pytorch部署】TorchScript - 知乎](https://zhuanlan.zhihu.com/p/135911580)
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

import vision_transformer

# Build ViT-Tiny (16x16 patches, 384x384 input) and switch it to inference mode.
# The Android side must therefore feed 1x3x384x384 tensors to the exported model.
model_vit = vision_transformer._create_vision_transformer('vit_tiny_patch16_384')
model_vit = model_vit.eval()

# torch.jit.script compiles the module from its Python source, so it needs no
# example input. Its second positional parameter is the deprecated `optimize`
# flag: calling torch.jit.script(model, example) passes the tensor as `optimize`
# and triggers "UserWarning: `optimize` is deprecated and has no effect".
scripted_module = torch.jit.script(model_vit)
scripted_module_optimized = optimize_for_mobile(scripted_module)

# Save in the lite-interpreter format that LiteModuleLoader loads on Android.
scripted_module_optimized._save_for_lite_interpreter(
    r"D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\vit2.pt"
)
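运行时会输出如下警告(是警告,不是报错):

    UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead
      warnings.warn(

原因:`torch.jit.script` 的第二个位置参数是已弃用的 `optimize`。如果像 `torch.jit.script(model_vit, example)` 这样传入示例输入,`example` 会被当作 `optimize` 参数,从而触发该警告。`torch.jit.script` 本身不需要示例输入,写成 `torch.jit.script(model_vit)` 即可消除警告。该警告不影响导出结果。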
from PIL import Image

# Sample image bundled into the Android app's assets folder.
IMAGE_PATH = r'D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\image.jpg'

# Image.open keeps the underlying file open; the context manager releases it.
with Image.open(IMAGE_PATH) as img:
    # Optional: resize to the 384x384 input of vit_tiny_patch16_384 and overwrite the asset.
    # img = img.resize((384, 384), Image.BILINEAR)
    # img.save(IMAGE_PATH)
    print(img.size)
- package org.pytorch.helloworld;
-
- import android.content.Context;
- import android.graphics.Bitmap;
- import android.graphics.BitmapFactory;
- import android.os.Bundle;
- import android.util.Log;
- import android.widget.ImageView;
- import android.widget.TextView;
-
- import org.pytorch.IValue;
- import org.pytorch.LiteModuleLoader;
- import org.pytorch.Module;
- import org.pytorch.Tensor;
- import org.pytorch.torchvision.TensorImageUtils;
- import org.pytorch.MemoryFormat;
-
- import java.io.File;
- import java.io.FileOutputStream;
- import java.io.IOException;
- import java.io.InputStream;
- import java.io.OutputStream;
-
- import androidx.appcompat.app.AppCompatActivity;
-
- public class MainActivity extends AppCompatActivity {
-
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
- setContentView(R.layout.activity_main);
-
- Bitmap bitmap = null;
- Module module = null;
- try {
- // creating bitmap from packaged into app android asset 'image.jpg',
- // app/src/main/assets/image.jpg
- bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
- int width = bitmap.getWidth();
- int height = bitmap.getHeight();
- Log.e("width", String.format("width %d ", width)); //总时间
- Log.e("height", String.format("height %d ", height));
- // loading serialized torchscript module from packaged into app android asset model.pt,
- // app/src/model/assets/model.pt
- module = LiteModuleLoader.load(assetFilePath(this, "vit2.pt"));
- } catch (IOException e) {
- Log.e("PytorchHelloWorld", "Error reading assets", e);
- finish();
- }
-
-
- // showing image on UI
- ImageView imageView = findViewById(R.id.image);
- imageView.setImageBitmap(bitmap);
-
- // preparing input tensor
- final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
- TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
-
-
-
-
-
-
-
-
-
- // running the model
- final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
- // getting tensor content as java array of floats
- final float[] scores = outputTensor.getDataAsFloatArray();
-
- // searching for the index with maximum score
- float maxScore = -Float.MAX_VALUE;
- int maxScoreIdx = -1;
- for (int i = 0; i < scores.length; i++) {
- if (scores[i] > maxScore) {
- maxScore = scores[i];
- maxScoreIdx = i;
- }
- }
-
- String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
- TextView textView = findViewById(R.id.text);
- textView.setText(className);
- }
-
-
-
-
-
-
- }
-
- /**
- * Copies specified asset to the file in /files app directory and returns this file absolute path.
- *
- * @return absolute file path
- */
-
- public static String assetFilePath(Context context, String assetName) throws IOException {
- File file = new File(context.getFilesDir(), assetName);
- if (file.exists() && file.length() > 0) {
- return file.getAbsolutePath();
- }
-
- try (InputStream is = context.getAssets().open(assetName)) {
- try (OutputStream os = new FileOutputStream(file)) {
- byte[] buffer = new byte[4 * 1024];
- int read;
- while ((read = is.read(buffer)) != -1) {
- os.write(buffer, 0, read);
- }
- os.flush();
- }
- return file.getAbsolutePath();
- }
- }
- }
-

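随着 PyTorch 版本更新,官方 demo 的 build.gradle 中引入的 dependencies(如 `org.pytorch:pytorch_android_lite`、`org.pytorch:pytorch_android_torchvision_lite`)版本已经过低。用较新版本 PyTorch 导出的模型,需要同等或更高版本的 Android 运行库才能加载,所以加载某些模型时会报错。解决方法:打开 build.gradle,把鼠标悬停在 dependencies 中引用的包上,Android Studio 会提示有新版本,按提示升级即可。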
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。