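My graduate advisor's research direction requires it, so I have only recently started working on deploying models. This post records my experience as a beginner. It is based on the official PyTorch documentation (Android | PyTorch) and the CSDN blog post "如何将pytorch模型部署到安卓" (How to deploy a PyTorch model to Android) by 普通网友.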
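The model we save in PyCharm is a .pth file. To deploy it in Android Studio it has to be converted to .pt. What really matters is not the extension but the content: PyTorch Android's Module.load expects a TorchScript module produced by torch.jit.trace or torch.jit.script, not a plain state_dict. I export the model to TorchScript and save it as .pt right after training finishes: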
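- import torch
-
- # `model` and `X_train` come from your training script
- model.eval()  # inference mode: disables dropout, BatchNorm uses running statistics
- example_input = torch.rand(1, X_train.shape[1])  # dummy input of shape [1, num_features], [1, 7] here
- traced_model = torch.jit.trace(model, example_input)  # record forward() as a TorchScript graph
- traced_model.save("model_mobile.pt")  # this file goes into the app's assets folder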
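The export step does not carry everything over. torch.jit.trace only records model.forward(), so any preprocessing applied to X_train outside the model is not part of model_mobile.pt. Standardizing features with scikit-learn's StandardScaler is one example of such preprocessing; that is an assumption here, because the post does not say how the data was prepared. In that case the app has to apply the same transform itself before it builds the input tensor. The sketch below assumes per-feature standardization. FeatureScaler and the MEAN/STD placeholder values are hypothetical; replace them with the statistics from your own training code:

```java
/**
 * Hypothetical helper: re-applies on the phone the per-feature standardization
 * that (by assumption) was applied to X_train in Python before training.
 */
public final class FeatureScaler {
    // Placeholder statistics: replace with the real per-column mean and std of your training data
    static final float[] MEAN = {0f, 0f, 0f, 0f, 0f, 0f, 0f};
    static final float[] STD = {1f, 1f, 1f, 1f, 1f, 1f, 1f};

    private FeatureScaler() {}

    /** Returns (x[i] - mean[i]) / std[i] for every feature; the input array is left unchanged. */
    public static float[] standardize(float[] x, float[] mean, float[] std) {
        if (x.length != mean.length || x.length != std.length) {
            throw new IllegalArgumentException(
                    "expected " + mean.length + " features, got " + x.length);
        }
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (x[i] - mean[i]) / std[i];
        }
        return out;
    }
}
```

In MainActivity you would call standardize(modelInput, FeatureScaler.MEAN, FeatureScaler.STD) on the parsed values before Tensor.fromBlob. If the model was trained on raw features, skip this step.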
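To deploy a PyTorch model in Android Studio, add the PyTorch Android library to the app module's build.gradle. Its version should match the PyTorch version you exported the model with, because a model saved by a newer PyTorch may fail to load in an older runtime. My PyTorch version is 1.13.0, so the dependencies are: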
- dependencies {
- implementation fileTree(dir: 'libs', include: ['*.jar'])
-
- implementation 'androidx.appcompat:appcompat:1.6.1'
- implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
- testImplementation 'junit:junit:4.12'
- androidTestImplementation 'androidx.test.ext:junit:1.1.5'
- androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
-
- // PyTorch Android runtime
- implementation 'org.pytorch:pytorch_android:1.13.0' // change this to match your PyTorch version
-
- }
If you are deploying an image recognition model, also add the torchvision helper package:
implementation 'org.pytorch:pytorch_android_torchvision:1.13.0'
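The torchvision package mainly provides TensorImageUtils. Its call TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB) turns a Bitmap into a [1, 3, H, W] float tensor. To show what that conversion does to the pixels, here is an illustrative plain-Java re-implementation. It is not the library's source. ImageNormalizer is a hypothetical name, and the constants are the ImageNet mean/std values that TensorImageUtils uses:

```java
/**
 * Illustrative re-implementation (not the library's code) of a Bitmap-to-tensor conversion:
 * ARGB pixels -> floats in [0, 1] -> per-channel (value - mean) / std, laid out as CHW.
 */
public final class ImageNormalizer {
    // Same values as TensorImageUtils.TORCHVISION_NORM_MEAN_RGB / TORCHVISION_NORM_STD_RGB
    static final float[] MEAN_RGB = {0.485f, 0.456f, 0.406f};
    static final float[] STD_RGB = {0.229f, 0.224f, 0.225f};

    private ImageNormalizer() {}

    /**
     * pixels: row-major ARGB ints, as returned by Bitmap.getPixels.
     * Returns a flattened [3, height, width] array: all R values, then all G, then all B.
     */
    public static float[] toChwFloat(int[] pixels, int width, int height, float[] mean, float[] std) {
        if (pixels.length != width * height) {
            throw new IllegalArgumentException("pixels.length must equal width * height");
        }
        int plane = width * height;
        float[] out = new float[3 * plane];
        for (int i = 0; i < plane; i++) {
            int p = pixels[i];
            float r = ((p >> 16) & 0xFF) / 255.0f;
            float g = ((p >> 8) & 0xFF) / 255.0f;
            float b = (p & 0xFF) / 255.0f;
            out[i] = (r - mean[0]) / std[0];             // R plane
            out[plane + i] = (g - mean[1]) / std[1];     // G plane
            out[2 * plane + i] = (b - mean[2]) / std[2]; // B plane
        }
        return out;
    }
}
```

The channel order and the mean/std used on the phone must match the transforms used during training. Otherwise the model receives inputs from a different distribution than it was trained on.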
Create an assets folder under app/src/main/ and copy the model file (model_mobile.pt) into it.
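Module.load() takes a file system path, but files under assets/ live inside the APK rather than as ordinary files. That is why assetFilePath in MainActivity first copies the model into getFilesDir(). The copy is skipped whenever the file already exists, and files in getFilesDir() survive app updates. So after you replace model_mobile.pt in assets, an updated install can keep loading the old model. One possible workaround is to put a version number into the cached file name. The sketch below assumes that approach; ModelCache, materialize and the _v<version> naming are hypothetical and not part of PyTorch's API:

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical helper: caches the model under a versioned name so a new model replaces the old one. */
public final class ModelCache {
    private ModelCache() {}

    /**
     * Copies `source` to dir/<baseName>_v<version>.pt unless that file already exists and is
     * non-empty, then returns its absolute path (what Module.load expects).
     * The caller is responsible for closing `source`.
     */
    public static String materialize(InputStream source, File dir, String baseName, int version)
            throws IOException {
        File target = new File(dir, baseName + "_v" + version + ".pt");
        if (target.exists() && target.length() > 0) {
            return target.getAbsolutePath(); // this model version was already copied
        }
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Cannot create directory " + dir);
        }
        // Copy into a temporary file first so an interrupted copy never looks like a finished model
        File tmp = new File(dir, target.getName() + ".tmp");
        try (OutputStream os = new FileOutputStream(tmp)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = source.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
        }
        if (!tmp.renameTo(target)) {
            tmp.delete();
            throw new IOException("Cannot rename " + tmp + " to " + target);
        }
        return target.getAbsolutePath();
    }
}
```

On the device you might call ModelCache.materialize(in, getFilesDir(), "model_mobile", 2) inside a try (InputStream in = getAssets().open("model_mobile.pt")) block, and bump the version number whenever you ship a new model. Old versions stay on disk unless you delete them yourself.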
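The full code follows (my model is a classifier with 7 input features and 3 output classes). First, the layout file res/layout/activity_main.xml: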
- <?xml version="1.0" encoding="utf-8"?>
- <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
- xmlns:app="http://schemas.android.com/apk/res-auto"
- xmlns:tools="http://schemas.android.com/tools"
- android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:orientation="vertical"
- android:padding="16dp"
- tools:context=".MainActivity">
-
- <EditText
- android:id="@+id/input1"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 1"
- android:inputType="numberDecimal" />
-
- <EditText
- android:id="@+id/input2"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 2"
- android:inputType="numberDecimal" />
-
- <EditText
- android:id="@+id/input3"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 3"
- android:inputType="numberDecimal" />
-
- <EditText
- android:id="@+id/input4"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 4"
- android:inputType="numberDecimal" />
-
- <EditText
- android:id="@+id/input5"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 5"
- android:inputType="numberDecimal" />
-
- <EditText
- android:id="@+id/input6"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 6"
- android:inputType="numberDecimal" />
-
- <EditText
- android:id="@+id/input7"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:hint="Input 7"
- android:inputType="numberDecimal" />
-
- <Button
- android:id="@+id/predictButton"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:text="Predict"
- android:layout_marginTop="16dp" />
-
- <TextView
- android:id="@+id/resultTextView"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:text="Prediction result will appear here."
- android:textSize="16sp"
- android:layout_marginTop="16dp" />
-
-
- </LinearLayout>
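Two input-related pitfalls with this layout can crash the app or block valid values. First, Float.parseFloat in the activity's click handler throws NumberFormatException when a field is empty or not a number, and an uncaught exception on the UI thread crashes the app. Second, android:inputType="numberDecimal" on its own does not let the user type a minus sign, so if any feature can be negative the fields need "numberDecimal|numberSigned". Below is a minimal validation helper in plain Java. InputParser is hypothetical and does not come from the original post:

```java
import java.util.Locale;

/**
 * Hypothetical helper: parses the EditText strings into floats and reports which field is
 * invalid, instead of letting Float.parseFloat throw inside the click handler.
 */
public final class InputParser {
    private InputParser() {}

    /** Carries the 1-based number of the offending field. */
    public static final class InvalidInputException extends Exception {
        public final int field;

        InvalidInputException(int field, String text) {
            super(String.format(Locale.ROOT, "Input %d is not a valid number: \"%s\"", field, text));
            this.field = field;
        }
    }

    public static float[] parse(String[] texts) throws InvalidInputException {
        float[] values = new float[texts.length];
        for (int i = 0; i < texts.length; i++) {
            String t = texts[i] == null ? "" : texts[i].trim();
            try {
                values[i] = Float.parseFloat(t); // throws on "" or non-numeric text
            } catch (NumberFormatException e) {
                throw new InvalidInputException(i + 1, t);
            }
            // parseFloat also accepts "NaN" and "Infinity", which are useless as model input
            if (Float.isNaN(values[i]) || Float.isInfinite(values[i])) {
                throw new InvalidInputException(i + 1, t);
            }
        }
        return values;
    }
}
```

In onClick you would collect inputs[i].getText().toString() into a String[], call InputParser.parse, and on failure show e.getMessage() in resultTextView instead of running the model.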
Then MainActivity.java:
- import android.annotation.SuppressLint;
- import android.content.Context;
- import android.os.Bundle;
- import android.util.Log;
- import android.view.View;
- import android.widget.Button;
- import android.widget.EditText;
- import android.widget.TextView;
- import androidx.appcompat.app.AppCompatActivity;
- import org.pytorch.IValue;
- import org.pytorch.Module;
- import org.pytorch.Tensor;
- import java.io.File;
- import java.io.FileOutputStream;
- import java.io.IOException;
- import java.io.InputStream;
- import java.io.OutputStream;
-
- public class MainActivity extends AppCompatActivity {
-
- private EditText[] inputs = new EditText[7];
- private Button predictButton;
- private Module model;
- private TextView resultTextView;
-
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
- setContentView(R.layout.activity_main);
-
- // 1. Reference the 7 EditTexts, the Button and the TextView
- int[] inputIds = {R.id.input1, R.id.input2, R.id.input3, R.id.input4,
- R.id.input5, R.id.input6, R.id.input7};
- for (int i = 0; i < inputs.length; i++) {
- inputs[i] = findViewById(inputIds[i]);
- }
- predictButton = findViewById(R.id.predictButton);
- // Reference the TextView
- resultTextView = findViewById(R.id.resultTextView);
-
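- // Load the TorchScript model (Module.load needs a real file path, hence assetFilePath)
- try {
- model = Module.load(assetFilePath(this, "model_mobile.pt"));
- } catch (IOException e) {
- Log.e("PyTorchAndroid", "Error loading model", e);
- resultTextView.setText("Failed to load model_mobile.pt, see Logcat.");
- }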
-
-
- // 2. Set up the button click listener
- predictButton.setOnClickListener(new View.OnClickListener() {
- @SuppressLint("SetTextI18n")
- @Override
- public void onClick(View v) {
- if (model == null) {
- Log.e("PyTorchAndroid", "Model is null!");
- resultTextView.setText("Model is not loaded.");
- return;
- }
-
- float[] modelInput = new float[inputs.length];
- try {
- for (int i = 0; i < inputs.length; i++) {
- // Float.parseFloat throws NumberFormatException on empty or non-numeric text
- modelInput[i] = Float.parseFloat(inputs[i].getText().toString().trim());
- }
- } catch (NumberFormatException e) {
- resultTextView.setText("Please enter a valid number in all 7 fields.");
- return;
- }
-
- // Run the model: a [1, 7] float tensor in, one score per class out
- Tensor inputTensor = Tensor.fromBlob(modelInput, new long[]{1, inputs.length});
- IValue output = model.forward(IValue.from(inputTensor));
- float[] scores = output.toTensor().getDataAsFloatArray();
-
- // Display the result (you can adjust this to your needs)
- int predictedClass = argMax(scores);
- resultTextView.setText("Predicted Class: " + (predictedClass + 1)); // Adjust based on your class labels
- }
- });
- }
-
- public static String assetFilePath(Context context, String assetName) throws IOException {
- File file = new File(context.getFilesDir(), assetName);
- if (file.exists() && file.length() > 0) {
- return file.getAbsolutePath();
- }
-
- try (InputStream is = context.getAssets().open(assetName)) {
- try (OutputStream os = new FileOutputStream(file)) {
- byte[] buffer = new byte[4 * 1024];
- int read;
- while ((read = is.read(buffer)) != -1) {
- os.write(buffer, 0, read);
- }
- os.flush();
- }
- return file.getAbsolutePath();
- }
- }
-
- public static int argMax(float[] scores) {
- int maxIndex = -1;
- float maxScore = Float.NEGATIVE_INFINITY;
- for (int i = 0; i < scores.length; i++) {
- if (scores[i] > maxScore) {
- maxScore = scores[i];
- maxIndex = i;
- }
- }
- return maxIndex;
- }
- }
I left the predicted class as a plain number. Replace it with your actual class names.
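To show a class name together with a confidence value, you can turn the raw scores into probabilities with softmax. The following sketch assumes the model's last layer returns unnormalized scores (logits), which means there is no softmax inside the model. Prediction and the LABELS values are hypothetical placeholders:

```java
import java.util.Locale;

/**
 * Sketch: turns the model's raw output into probabilities and a readable label.
 * Assumes the network returns unnormalized scores (logits), i.e. no softmax inside the model.
 */
public final class Prediction {
    // Hypothetical class names: replace with yours, in the same order as the training labels
    static final String[] LABELS = {"class A", "class B", "class C"};

    private Prediction() {}

    /** Numerically stable softmax: subtracting the max score keeps Math.exp from overflowing. */
    public static float[] softmax(float[] scores) {
        float max = Float.NEGATIVE_INFINITY;
        for (float s : scores) {
            max = Math.max(max, s);
        }
        double[] exps = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            exps[i] = Math.exp(scores[i] - max);
            sum += exps[i];
        }
        float[] probs = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            probs[i] = (float) (exps[i] / sum);
        }
        return probs;
    }

    /** Returns "<label> (<probability>%)" for the highest-scoring class. */
    public static String describe(float[] scores, String[] labels) {
        if (scores.length == 0 || scores.length != labels.length) {
            throw new IllegalArgumentException("need exactly one label per score");
        }
        float[] probs = softmax(scores);
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return String.format(Locale.ROOT, "%s (%.1f%%)", labels[best], probs[best] * 100f);
    }
}
```

In onClick, resultTextView.setText(Prediction.describe(scores, Prediction.LABELS)) would replace the "Predicted Class: N" line. Softmax preserves the order of the scores, so it picks the same class as argMax on the raw scores. The probability only adds a confidence figure.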
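I first followed the article mentioned above, but when I deployed my own model the app crashed immediately on launch. I traced it to the following causes: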
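1. My Android Studio setup was too old to run the model, which made the app crash. I recommend a recent version of Android Studio. Here is my configuration.
App-level build.gradle (app/build.gradle):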
- android {
- compileSdkVersion 33
- buildToolsVersion '33.0.0'
-
- defaultConfig {
- applicationId "com.example.testmodel"
- minSdkVersion 23
- targetSdkVersion 33
- versionCode 1
- versionName "1.0"
-
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
- }
-
- buildTypes {
- release {
- minifyEnabled false
- proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
- }
- }
-
- }
Project-level build.gradle (Android Gradle Plugin version):
- dependencies {
- classpath 'com.android.tools.build:gradle:7.4.2'
-
-
- // NOTE: Do not place your application dependencies here; they belong
- // in the individual module build.gradle files
- }
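2. The PyTorch version did not match the version of the org.pytorch:pytorch_android package in Android Studio. Check which PyTorch version you exported the model with (print torch.__version__ in Python) and use the same package version. 1.13.0 seemed to be the newest version available when I wrote this.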
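3. The model must be exported as TorchScript (torch.jit.trace or torch.jit.script) and saved as .pt. A regular .pth checkpoint or state_dict cannot be loaded by Module.load.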