当前位置:   article > 正文

pytorch的安卓部署_pytorch-openpose框架部署在安卓

pytorch-openpose框架部署在安卓

Pytorch

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。

首先是使用pytorch

cuda 和 cudnn

cuDNN是NVIDIA专门针对深度神经网络中的基础操作而设计的基于GPU的加速库。cuDNN为深度神经网络中的标准流程提供了高度优化的实现方式,例如卷积、池化、归一化以及激活层的前向以及后向过程。当开发者们需要用到深度学习GPU加速时,才安装cuDNN库,工作速度相较CPU快很多。

需要安装cuda 和 cudnn 库的卸载与安装,可以参考以下教程:
cuda和cudnn安装

安装完成后,可按照下面教程安装pytorch:
此处需要用到anaconda:Anaconda是一个安装、管理Python相关包的软件,其包含了conda、Python等180多个科学包及其依赖项,可以用于在同一个机器上安装不同版本的软件包及其依赖,并能够在不同的环境之间切换。

pytorch安装

另外,如需在pycharm进行操作,可安装以下教程配置虚拟环境:
pycharm配置讯环境(此步需要完成pytorch环境的安装,有anaconda环境)

所有环境准备好以后,查看配置是否完成:

import torch
# 查看版本
print(torch.__version__)
# 查看gup是否可用
print(torch.cuda.is_available())
# 返回gpu个数
print(torch.cuda.device_count())
# 查看对应cuda版本号
print(torch.backends.cudnn.version())
print(torch.version.cuda)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

在这里插入图片描述
现在就开始准备移植所需的文件吧。

pt模型和pth模型的不同

PT模型是一种完整的模型文件,不仅包含了模型的参数,还包括了模型的结构,可以直接被加载到模型中,开始进行训练和预测。而pth文件则只保存了模型的参数,因此在加载模型时需要重新定义模型结构。
所以,需要一个安卓端的pt模型
这里我们使用到图像识别的resnet101模型,新建一个py文件,运行以下代码:

模型转换

import torch
import torchvision

model = torchvision.models.resnet101(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model101.pt")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • import torch 和 import torchvision 是导入PyTorch和PyTorch的计算机视觉库。
  • model = torchvision.models.resnet101(pretrained=True) 加载一个预训练的ResNet-101模型。pretrained=True表示使用预训练的权重,这些权重是在ImageNet数据集上训练得到的。
  • model.eval() 将模型设置为评估模式。这在推理时使用,例如在测试集或生产环境中。
  • example = torch.rand(1, 3, 224, 224) 创建一个随机输入张量,大小为[1, 3, 224, 224],模拟一个输入图像。
  • traced_script_module = torch.jit.trace(model, example) 使用torch.jit.trace将模型转换为一个TorchScript模型。TorchScript是一种可以优化PyTorch模型的方式,使其在没有Python运行环境的情况下运行。
  • traced_script_module.save(“model101.pt”) 将转换后的模型保存为"model101.pt"。

标签获取

这样,我们就获得了一个图像识别的pt模型。
然后,需要找到该模型的标签文件,我们从以下项目中获取
https://github.com/ethereon/caffe-tensorflow/blob/master/examples/imagenet/imagenet-classes.txt

拿到txt标签文件以后,我们需要转换成java中String[]的形式,所以稍微处理一下,

import re

with open('imagenet-classes.txt','r',encoding='utf-8') as f:
    text = f.read()

text = re.sub(r'^', '"', text, flags=re.M)
text = re.sub(r'$', '",', text, flags=re.M)

with open('imagenet-classes.txt','w',encoding='utf-8') as f:
    f.write(text)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

这里是给开头和结尾都加上了"(双引号)以及结尾用,(逗号)分隔。

安卓端代码

完成以后我们打开AndroidStudio
新建一个空的项目
在资源文件中放入我们的model101.pt文件:
在这里插入图片描述
然后Gradle导入pytorch的依赖:

implementation("org.pytorch:pytorch_android:1.12.1")
implementation("org.pytorch:pytorch_android_torchvision:1.12.1")
  • 1
  • 2

完成以后修改activity

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

    <ImageView
        android:id="@+id/image"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:scaleType="centerCrop" />

    <TextView
        android:id="@+id/text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentTop="true"
        android:layout_marginTop="644dp"
        android:padding="8dp"
        android:textColor="@android:color/white"
        android:textSize="20sp" />

    <TextView
        android:id="@+id/textTitle"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentTop="true"
        android:layout_marginTop="6dp"
        android:padding="8dp"
        android:text="默认测试图片"
        android:textColor="@android:color/white"
        android:textSize="18sp" />

    <Button
        android:id="@+id/button"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_centerInParent="true"

        android:textColor="@android:color/white"
        android:text="自行测试"
        android:padding="16dp"/>

</RelativeLayout>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

    <ImageView
        android:id="@+id/image"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:scaleType="fitCenter" />

    <Button
        android:id="@+id/btn_capture"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:text="选择照片" />

    <TextView
        android:id="@+id/text_prediction"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_gravity="bottom"
        android:background="#80000000"
        android:padding="10dp"
        android:textColor="@android:color/white"
        android:textSize="18sp"/>

</FrameLayout>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

MainAvity:

public class MainActivity extends AppCompatActivity {
    Button takePictureBtn=null;
    Bitmap bitmap = null;
    Module module = null;
    // 将图片显示在界面上
    ImageView imageView = null;
    TextView textView = null;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        initView();
        //从assets中获得图片数据
        try {
            bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
            imgClassify(bitmap);
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }


    }

    private void imgClassify(Bitmap img){
        try {
            // 加载PyTorch序列化模型
            module = Module.load(assetFilePath(this, "model101.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }
        imageView.setImageBitmap(img);
        // 建立输入张量
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(img,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
        // 运行模型推理
        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        // 获取推理结果
        final float[] scores = outputTensor.getDataAsFloatArray();
        //  获取概率最高的分类的索引号
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }
        //通过索引号获得分类名称
        String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
        textView.setText("识别结果为:"+className);
    }

    private void initView(){
        takePictureBtn = findViewById(R.id.button);
        // 将图片显示在界面上
        imageView = findViewById(R.id.image);
        // 显示分类名称
        textView = findViewById(R.id.text);

        takePictureBtn.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                Intent intent = new Intent(MainActivity.this,TakePicturesActivity.class);
                startActivity(intent);
            }
        });
    }

    //返回模型文件路径
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91

TakePicturesActivity:

package com.example.newdemo05;

import static com.example.newdemo05.MainActivity.assetFilePath;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.content.ContentResolver;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.hardware.camera2.CameraAccessException;
import android.hardware.camera2.CameraCharacteristics;
import android.hardware.camera2.CameraDevice;
import android.hardware.camera2.CameraManager;
import android.hardware.camera2.params.StreamConfigurationMap;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.TextureView;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import com.example.newdemo05.classify.ImageNetClasses;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;

public class TakePicturesActivity extends AppCompatActivity {
    private TextureView textureView;
    private TextView textPrediction;
    private Button openCam;
    private ImageView v_img;
    Module module = null;
    private CameraDevice cameraDevice;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_take_pictures);
        initView();
    }
    private void initView(){
        v_img = findViewById(R.id.image);
        textPrediction = findViewById(R.id.text_prediction);
        openCam = findViewById(R.id.btn_capture);
        openCam.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {

                startCamera();
            }
        });
    }
    // 打开相机
    public void startCamera(){
        Intent intent = new Intent(Intent.ACTION_PICK, null);
        //调用setDataAndType方法,指定了选择的数据类型为图片
        //设置数据的URI为MediaStore.Images.Media.EXTERNAL_CONTENT_URI,表示选择外部存储中的图片
        intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
        //调用startActivityForResult方法,将Intent发送给系统,并指定一个请求码为2,以便在之后的回调中处理用户选择的图片
        startActivityForResult(intent, 2);

    }
    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (requestCode == 2) {
            // 从相册返回的数据
            Log.e(this.getClass().getName(), "Result:" + data.toString());
            if (data != null) {
                // 得到图片的全路径
                Uri uri = data.getData();
                v_img.setImageURI(uri);
                ContentResolver cr = getContentResolver();
                InputStream inputStream = null;
                try {
                    inputStream = cr.openInputStream(uri);
                } catch (FileNotFoundException e) {
                    throw new RuntimeException(e);
                }

                Bitmap bitmap = BitmapFactory.decodeStream(inputStream);
                imgClassify(bitmap);
                Log.e(this.getClass().getName(), "Uri:" + String.valueOf(uri));

            }
        }
    }
    private void imgClassify(Bitmap img){
        try {
            // 加载PyTorch序列化模型
            module = Module.load(assetFilePath(this, "model101.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }
        v_img.setImageBitmap(img);
        // 建立输入张量
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(img,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
        // 运行模型推理
        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        // 获取推理结果
        final float[] scores = outputTensor.getDataAsFloatArray();
        //  获取概率最高的分类的索引号
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }
        //通过索引号获得分类名称
        String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
        textPrediction.setText("识别结果为:"+className);
    }


}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132

然后新建一个ImageNetClasses,来装我们的标签文件:

public class ImageNetClasses {
  public static String[] IMAGENET_CLASSES = new String[]{};}
  • 1
  • 2

然后负责粘贴处理好的标签txt文件数据:

在这里插入图片描述
具体代码逻辑如下:

  1. TakePicturesActivity: 主Activity,用于打开相机或相册,获取图片并显示。
  2. initView(): 初始化视图控件,如ImageView、TextView和Button。
  3. startCamera(): 打开系统相册,调用startActivityForResult获取图片。
  4. onActivityResult(): 接收选择的图片,将图片显示在ImageView上。
  5. imgClassify(): 使用PyTorch模型对图片进行图像分类,将结果显示在TextView上。
  6. Module: PyTorch的模型类,用于加载pt模型并进行推理。
  7. Tensor: PyTorch的张量类,用于转换Bitmap为模型输入张量。
  8. TensorImageUtils: PyTorch的图片工具类,用于图像数据的预处理。
  9. ImageNetClasses: 包含ImageNet数据集的所有类别名称。
    主要步骤是:
  10. 初始化视图控件
  11. 打开相册,获取图片uri
  12. 将图片解码为Bitmap
  13. 用TensorImageUtils处理为模型输入张量
  14. 加载PyTorch模型并进行推理
  15. 处理输出结果,显示分类类别名称
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/767351
推荐阅读
相关标签
  

闽ICP备14008679号