当前位置:   article > 正文

将一个BP神经网络预测模型部署Android studio_怎么把训练好的模型导入到android studio中

怎么把训练好的模型导入到android studio中

       由于读研导师的研究方向的需要,刚接触关于模型导入的问题,这次记录一下我的新手历程,文章参考pytorch的官方文档Android | PyTorch和大佬的文章如何将pytorch模型部署到安卓_pytorch移植到安卓_普通网友的博客-CSDN博客

1.模型的保存

       我们在pycharm里保存的模型为.pth格式,如果需要将其部署到Android studio上需要把格式转换为.pt。我是直接在模型训练完毕之后,直接保存为.pt格式。

  1. model.eval() # Set the model to evaluation mode
  2. example_input = torch.rand(1, X_train.shape[1]) # Example input tensor
  3. traced_model = torch.jit.trace(model, example_input)
  4. traced_model.save("model_mobile.pt")

2.Android studio的环境配置 

        Android studio部署pytorch模型需要导包,在app的build.gradle中导入包,而且你导入的包的版本需要和你的pytorch版本一致,由于我的pytorch版本为1.13.0,所以如下:

  1. dependencies {
  2. implementation fileTree(dir: 'libs', include: ['*.jar'])
  3. implementation 'androidx.appcompat:appcompat:1.6.1'
  4. implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
  5. testImplementation 'junit:junit:4.12'
  6. androidTestImplementation 'androidx.test.ext:junit:1.1.5'
  7. androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
  8. // pytorch包
  9. implementation 'org.pytorch:pytorch_android:1.13.0' // 根据您的实际PyTorch版本更改
  10. }

       如果需要使用图像识别的模型,需要增加另一个包:

implementation 'org.pytorch:pytorch_android_torchvision:1.13.0'

3.将模型复制到Android studio中的特定文件夹中去

       在 app/src/main/目录下,新建assets文件夹,并将模型拷贝到文件夹中:

4.Android studio中的代码实现

4.1 创建一个Empty Activity并编写xml文件

具体代码如下(我这个模型有7个输入,输出3个类别是属于分类):

  1. <?xml version="1.0" encoding="utf-8"?>
  2. <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
  3. xmlns:app="http://schemas.android.com/apk/res-auto"
  4. xmlns:tools="http://schemas.android.com/tools"
  5. android:layout_width="match_parent"
  6. android:layout_height="match_parent"
  7. android:orientation="vertical"
  8. android:padding="16dp"
  9. tools:context=".MainActivity">
  10. <EditText
  11. android:id="@+id/input1"
  12. android:layout_width="match_parent"
  13. android:layout_height="wrap_content"
  14. android:hint="Input 1"
  15. android:inputType="numberDecimal" />
  16. <EditText
  17. android:id="@+id/input2"
  18. android:layout_width="match_parent"
  19. android:layout_height="wrap_content"
  20. android:hint="Input 2"
  21. android:inputType="numberDecimal" />
  22. <EditText
  23. android:id="@+id/input3"
  24. android:layout_width="match_parent"
  25. android:layout_height="wrap_content"
  26. android:hint="Input 3"
  27. android:inputType="numberDecimal" />
  28. <EditText
  29. android:id="@+id/input4"
  30. android:layout_width="match_parent"
  31. android:layout_height="wrap_content"
  32. android:hint="Input 4"
  33. android:inputType="numberDecimal" />
  34. <EditText
  35. android:id="@+id/input5"
  36. android:layout_width="match_parent"
  37. android:layout_height="wrap_content"
  38. android:hint="Input 5"
  39. android:inputType="numberDecimal" />
  40. <EditText
  41. android:id="@+id/input6"
  42. android:layout_width="match_parent"
  43. android:layout_height="wrap_content"
  44. android:hint="Input 6"
  45. android:inputType="numberDecimal" />
  46. <EditText
  47. android:id="@+id/input7"
  48. android:layout_width="match_parent"
  49. android:layout_height="wrap_content"
  50. android:hint="Input 7"
  51. android:inputType="numberDecimal" />
  52. <!-- Repeat for all 7 inputs -->
  53. <Button
  54. android:id="@+id/predictButton"
  55. android:layout_width="match_parent"
  56. android:layout_height="wrap_content"
  57. android:text="Predict"
  58. android:layout_marginTop="16dp" />
  59. <TextView
  60. android:id="@+id/resultTextView"
  61. android:layout_width="match_parent"
  62. android:layout_height="wrap_content"
  63. android:text="Prediction result will appear here."
  64. android:textSize="16sp"
  65. android:layout_marginTop="16dp" />
  66. </LinearLayout>

4.2 Mainactivity的代码实现

  1. import android.annotation.SuppressLint;
  2. import android.content.Context;
  3. import android.os.Bundle;
  4. import android.util.Log;
  5. import android.view.View;
  6. import android.widget.Button;
  7. import android.widget.EditText;
  8. import android.widget.TextView;
  9. import androidx.appcompat.app.AppCompatActivity;
  10. import org.pytorch.IValue;
  11. import org.pytorch.Module;
  12. import org.pytorch.Tensor;
  13. import java.io.File;
  14. import java.io.FileOutputStream;
  15. import java.io.IOException;
  16. import java.io.InputStream;
  17. import java.io.OutputStream;
  18. public class MainActivity extends AppCompatActivity {
  19. private EditText[] inputs = new EditText[7];
  20. private Button predictButton;
  21. private Module model;
  22. private TextView resultTextView;
  23. @Override
  24. protected void onCreate(Bundle savedInstanceState) {
  25. super.onCreate(savedInstanceState);
  26. setContentView(R.layout.activity_main);
  27. // 1. Reference the EditText and Button components
  28. for (int i = 0; i < 7; i++) {
  29. int resourceId = getResources().getIdentifier("input" + (i + 1), "id", getPackageName());
  30. inputs[i] = findViewById(resourceId);
  31. }
  32. // similarly for other EditTexts
  33. predictButton = findViewById(R.id.predictButton);
  34. // Reference the TextView
  35. resultTextView = findViewById(R.id.resultTextView);
  36. // Load the model
  37. try {
  38. model = Module.load(assetFilePath(this, "model_mobile.pt"));
  39. } catch (IOException e) {
  40. e.printStackTrace();
  41. Log.e("PyTorchAndroid", "Error loading model", e);
  42. }
  43. // 2. Set up the button click listener
  44. predictButton.setOnClickListener(new View.OnClickListener() {
  45. @SuppressLint("SetTextI18n")
  46. @Override
  47. public void onClick(View v) {
  48. if (model ==null) {
  49. Log.e("PyTorchAndroid", "Model is null!");
  50. return;
  51. } else {
  52. float[] modelInput = new float[7];
  53. for (int i = 0; i < 7; i++) {
  54. modelInput[i] = Float.parseFloat(inputs[i].getText().toString());
  55. }
  56. // Run the model and get prediction
  57. Tensor inputTensor = Tensor.fromBlob(modelInput, new long[]{1, 7});
  58. IValue output = model.forward(IValue.from(inputTensor));
  59. float[] scores = output.toTensor().getDataAsFloatArray();
  60. // Display the result (you can adjust this to your needs)
  61. int predictedClass = argMax(scores);
  62. resultTextView.setText("Predicted Class: " + (predictedClass + 1)); // Adjust based on your class labels
  63. }
  64. }
  65. });
  66. }
  67. public static String assetFilePath(Context context, String assetName) throws IOException {
  68. File file = new File(context.getFilesDir(), assetName);
  69. if (file.exists() && file.length() > 0) {
  70. return file.getAbsolutePath();
  71. }
  72. try (InputStream is = context.getAssets().open(assetName)) {
  73. try (OutputStream os = new FileOutputStream(file)) {
  74. byte[] buffer = new byte[4 * 1024];
  75. int read;
  76. while ((read = is.read(buffer)) != -1) {
  77. os.write(buffer, 0, read);
  78. }
  79. os.flush();
  80. }
  81. return file.getAbsolutePath();
  82. }
  83. }
  84. public static int argMax(float[] scores) {
  85. int maxIndex = -1;
  86. float maxScore = Float.NEGATIVE_INFINITY;
  87. for (int i = 0; i < scores.length; i++) {
  88. if (scores[i] > maxScore) {
  89. maxScore = scores[i];
  90. maxIndex = i;
  91. }
  92. }
  93. return maxIndex;
  94. }
  95. }

4.3 运行结果

那个结果分类我没有写具体,可以根据实际的类别更改

5.我遇到的问题

我最开始也是参考大佬的文章,发现当我部署自己的模型时,app闪退,我总结有以下两种问题导致:

1.Android studio的版本太低运行不了模型所以闪退,推荐使用较高版本的Android studio,附图我的:

app中的build gradle配置:

  1. android {
  2. compileSdkVersion 33
  3. buildToolsVersion '33.0.0'
  4. defaultConfig {
  5. applicationId "com.example.testmodel"
  6. minSdkVersion 23
  7. targetSdkVersion 33
  8. versionCode 1
  9. versionName "1.0"
  10. testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
  11. }
  12. buildTypes {
  13. release {
  14. minifyEnabled false
  15. proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
  16. }
  17. }
  18. }

项目中的 build gradle

  1. dependencies {
  2. classpath 'com.android.tools.build:gradle:7.4.2'
  3. // NOTE: Do not place your application dependencies here; they belong
  4. // in the individual module build.gradle files
  5. }

2.pytorch的版本与 Android studio中的org.pytorch:pytorch_android:1.13.0包的本版本不一致,需要查看pytorch的版本进行导包(我这个好像是最高的版本了)

3.模型的格式转换一定要保存为.pt的格式

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/767355
推荐阅读
相关标签
  

闽ICP备14008679号