赞
踩
提示:
注意文章时效性
,2022.04.02。
最近在搞图像分类模型移植到Android上,本来是准备用Tensorflow来搞的,但是百度到的一些博文案例都有些老,17、18年的,然后找Tensorflow官方实现的例子,发现最开始的例子已经弃用了,换了个地方。但是这新例子里的README也没讲怎么处理模型,Tensorflow官网时常出现Service Unavailable,再加上我用Tensorflow实现的模型跑出的结果很奇怪。Pytorch倒是能找到比较新一点的例子:
果断放弃Tensorflow,改用Pytorch,参考官方的给例子操作,模型还是能够跑出来的。
这里就简单记录下实现过程和遇到的一些错误。
废话结束,正文开始。
使用的环境 | 版本 |
---|---|
训练模型: | ↓ |
Python | 3.7.3 |
Pytorch | 1.11.0 |
导出模型: | ↓ |
Python | 3.9 |
Pytorch | 1.9.0 |
Android部署: | ↓ |
Android Studio | 4.1.1 |
pytorch_android_lite | 1.9.0 |
pytorch_android_torchvision | 1.9.0 |
如果有类似这样的报错:
No toolchains found in the NDK toolchains folder for ABI with prefix: arm-linux-androideabi
可能是NDK的问题,没安装NDK或者安装了ND但K缺少对应的库,可以参考这篇博文安装(完美解决 No toolchains found in the NDK toolchains folder for ABI with prefix: mips64el-linux-android_CodeForCoffee的博客-CSDN博客 )。不过,里面下载NDK的网址进不去了,可以到这里下载(AndroidDevTools - Android开发工具 Android SDK下载 Android Studio下载 Gradle下载 SDK Tools下载 )
按照参考的博文和官方教程讲的,都是要导出自己的模型的。博文里的方法也试了,不过最后我自己成功跑出来的,是在官方的例子上改的,如下:
import torch from torch.utils.mobile_optimizer import optimize_for_mobile from model_v3 import mobilenet_v3_large # 导入自己的模型 model_pth = './MobileNetV3-20220330-01.pth' # 训练得到的模型参数文件的路径 mobile_ptl = './mobilenetV3large.ptl' # 模型保存为Android可以调用的文件的路径 model = mobilenet_v3_large(num_classes=7) # 实例化模型 pre_weights = torch.load(model_pth, map_location='cpu') # 读取参数 model.load_state_dict(pre_weights, strict=True) # 将参数载入到模型 device = torch.device('cpu') # 将torch.Tensor分配到的设备的对象,有cpu和cuda两种 model.to(device) # 将模型加载到指定设备上 model.eval() # 将模型设为验证模式 example = torch.rand(1, 3, 224, 224) # 输入样例的格式为一张224*224的3通道图像 # 上面是准备模型,下面就是转换了 traced_script_module = torch.jit.trace(model, example) traced_script_module_optimized = optimize_for_mobile(traced_script_module) traced_script_module_optimized._save_for_lite_interpreter(mobile_ptl)
Pytorch官方的例子用的模型是预训练好的MobileNetV2,导入torchvision,然后调用。
……
import torchvision
……
model = torchvision.models.mobilenet_v2(pretrained=True)
……
如果只载入参数,会报错;
AttributeError: 'collections.OrderedDict' object has no attribute 'eval' ……
只载入模型网络训练模型,等于没训练,模型没参数。
所以保存模型文件的时候,一般有两种不同的方式:
# Save:
torch.save(model.state_dict(), PATH)
# Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
# Save:
torch.save(model, PATH)
# Load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
更具体的说明可以看官方的文档(SAVING AND LOADING MODELS)
虽然那些参考博文都说是要导出为.pt
文件,但是我在运行载入完整的模型导出的.pt
文件运行会报错:
java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException: PytorchStreamReader failed locating file bytecode.pkl: file not found ()
Exception raised from valid at ../caffe2/serialize/inline_container.cc:157 (most recent call first):
(no backtrace available)
照官方例子写的,导出成.ptl
文件就能成功运行。
这部分参考这篇博文(如何将pytorch模型部署到安卓,实现的和官方的例子差不多)的安卓部署部分,虽然最开始参考这篇博文写,没跑成功。
下面就参考大佬的步骤再走一遍。
直接新建一个Empty Activity
,点击Next。
取个名字,就叫myModel
了,其他保持默认,点击Finish。
导入pytorch_android_lite的包(与pytorch_android不同区分,载入模型的方法不同)。
//Pytorch
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
完整build.gradle(:app)
如下:
plugins { id 'com.android.application' } android { compileSdkVersion 30 buildToolsVersion "30.0.3" defaultConfig { applicationId "com.test.mymodel" minSdkVersion 23 targetSdkVersion 30 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" } buildTypes { release { minifyEnabled false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' } } compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } } dependencies { implementation 'androidx.appcompat:appcompat:1.2.0' implementation 'com.google.android.material:material:1.2.1' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' testImplementation 'junit:junit:4.+' androidTestImplementation 'androidx.test.ext:junit:1.1.2' androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' //Pytorch implementation 'org.pytorch:pytorch_android_lite:1.9.0' implementation 'org.pytorch:pytorch_android_torchvision:1.9.0' }
注意:如果导出模型使用的Pytorch版本与Android项目使用的pytorch_andorid_lite包的版本不一样会报错。
java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException: Lite Interpreter verson number does not match. The model version must be between 3 and 5But the model version is 7 ()
Exception raised from parseMethods at ../torch/csrc/jit/mobile/import.cpp:320 (most recent call first):
(no backtrace available)
我训练模型用的Pytorch版本是1.11.0
,用这个版本导出的来跑会有上面这个错误,换成Android上相同版本的1.9.0
就能跑了。
放了一个TextView
用来显示文字结果,一个ImageView
用来展示图片。
完整activity_main.xml
文件如下:
<?xml version="1.0" encoding="utf-8"?> <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" android:orientation="vertical" tools:context=".MainActivity"> <TextView android:id="@+id/tv" android:layout_weight="1" android:layout_width="match_parent" android:layout_height="0dp" android:layout_margin="10dp" android:layout_gravity="center" android:text="Hello World!" android:textSize="50sp" android:textAlignment="center" android:textStyle="bold"/> <ImageView android:id="@+id/iv" android:layout_weight="4" android:layout_width="match_parent" android:layout_height="0dp" android:layout_margin="10dp" android:background="#f0f0f0" android:layout_gravity="center" android:contentDescription="@string/iv_text" /> </LinearLayout>
新建EmotionClasses.java
类文件,我这里是表情分类,有七个类别,按训练的标签顺序放里面。(顺序不对的话,结果也会错位)
package com.test.mymodel;
public class EmotionClasses {
public static String[] EMOTION_CLASSES = new String[]{
"anger",
"disgust",
"fear",
"happy",
"normal",
"sad",
"surprised"
};
}
在main
文件夹下新建assets
文件夹,并将模型的.ptl
文件和要识别图片放入其中。(图片需要是前面导出模型时设置的example
的大小,我这里是224*224的彩色图片)
在MainActivity.java
载入模型,对图片进行识别。
package com.test.mymodel; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.util.Log; import android.widget.ImageView; import android.widget.TextView; import androidx.appcompat.app.AppCompatActivity; import org.pytorch.IValue; import org.pytorch.LiteModuleLoader; import org.pytorch.MemoryFormat; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Bitmap bitmap = null; Module module = null; try { // creating bitmap from packaged into app android asset 'image.jpg', // app/src/main/assets/image.jpg bitmap = BitmapFactory.decodeStream(getAssets().open("happy01.jpg")); // loading serialized torchscript module from packaged into app android asset model.pt, // app/src/model/assets/model.pt module = LiteModuleLoader.load(assetFilePath(this, "mobilenetV3large.ptl")); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } // showing image on UI ImageView imageView = findViewById(R.id.iv); imageView.setImageBitmap(bitmap); // preparing input tensor final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST); // running the model final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); // getting tensor content as java array of floats final float[] scores = outputTensor.getDataAsFloatArray(); // searching for the index with maximum score float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } String className = EmotionClasses.EMOTION_CLASSES[maxScoreIdx]; // showing className on UI TextView textView = findViewById(R.id.tv); textView.setText(className); } /** * Copies specified asset to the file in /files app directory and returns this file absolute path. * * @return absolute file path */ public static String assetFilePath(Context context, String assetName) throws IOException { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } } }
注意:如果使用的是pytorch_android_lite
依赖库,却使用Module.load()
方法载入模型,会报错,提示找不到libpytorch_jni.so
这个库,就需要使用LiteModuleLoader.load()
方法来载入模型。官方的issue有人提过couldn’t find “libpytorch_jni.so”。
java.lang.UnsatisfiedLinkError: dlopen failed: library "libpytorch_jni.so" not found
运行结果如下:
如果类别顺序错位,识别结果也会错位,如下图所示,将anger
调到第四位,识别结果就成了anger
。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。