赞
踩
将python中训练的深度学习模型(图像分类、目标检测、语义分割等)部署到Android中使用。
1、下载Pytorch Android库。
在Pytorch的官网pytorch.org上找到最新版本的库。下载后,将其解压缩到项目的某个目录下。
2、配置项目gradle文件
配置项目的gradle文件,向项目添加Pytorch Android库的依赖项。打开项目的build.gradle文件,添加以下代码:
repositories {
// 添加以下两行代码
maven {
url "https://oss.sonatype.org/content/repositories/snapshots/"
}
}
dependencies {
// 添加以下两行代码
implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
}
3、将库文件添加到项目中
将Pytorch Android库的库文件添加到项目中。可以将其复制到“libs”文件夹中,并在项目的gradle文件中添加以下代码:
android {
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
}
4、配置NDK版本
确保项目使用了支持Pytorch Android库的NDK版本。打开项目的local.properties文件,添加以下代码:
//NDK目录
ndk.dir=/path/to/your/ndk
5、同步gradle文件
在Android Studio中,点击“Sync Project with Gradle Files”按钮,等待同步完成。
到这就集成了Pytorch Android库。可以在应用程序中使用Pytorch Android库提供的API加载模型文件并进行预测。
假如我们的深度学习模型输入图片大小尺寸为(640,640,3),并且已经在python中训练好了my_model.pth,那么我们需要将其转换为.pt格式:
import torch
# 加载PyTorch模型
model = torch.load("my_model.pth")
# 将PyTorch模型转换为TorchScript格式
traced_script_module = torch.jit.trace(model, torch.randn(1, 3, 640, 640))
traced_script_module.save("my_model.pt")
转换Pytorch模型为TorchScript格式时,需要确保使用的所有操作都是TorchScript支持的。否则,在转换模型时可能会出现错误。
要在Android Studio中创建新项目并将m_model.pt模型文件放入该项目中,包含以下步骤:
1、打开Android Studio,并选择“Create New Project”选项。
2、在“Create New Project”向导中,输入项目名称,选择项目保存位置,并选择“Phone and Tablet”作为您的应用程序目标设备。然后,单击“Next”继续。
3、选择“Empty Activity”模板,并单击“Next”继续。
4、在“Configure Activity”对话框中,输入活动名称并单击“Finish”完成项目创建过程。
5、在项目中创建一个名为“assets”的文件夹。要创建该文件夹,请右键单击项目根目录,选择“New” -> “Folder” -> “Assets Folder”。
6、将m_model.pt模型文件复制到“assets”文件夹中。要将文件复制到“assets”文件夹中,右键单击该文件夹,选择“Show in Explorer”或“Show in Finder”,然后将文件复制到打开的文件夹中。
7、在代码中加载模型文件使用以下代码示例加载模型文件:
AssetManager assetManager = getAssets(); String modelPath = "m_model.pt"; File modelFile = new File(getCacheDir(), modelPath); try (InputStream inputStream = assetManager.open(modelPath); FileOutputStream outputStream = new FileOutputStream(modelFile)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = inputStream.read(buffer)) != -1) { outputStream.write(buffer, 0, read); } outputStream.flush(); } catch (IOException e) { e.printStackTrace(); } // 加载PyTorch模型 Module model = Module.load(modelFile.getAbsolutePath());
在这里需要注意将模型文件保存到应用程序的缓存目录中,而不是将其保存在项目资源中。这是因为在运行时,Android应用程序不能直接读取项目资源,而是需要使用AssetManager类从“assets”文件夹中读取文件。
接下来示例运行模型、获取模型输出和在主线程中更新UI的代码:
import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.os.Handler; import android.os.Looper; import android.util.Log; import androidx.appcompat.app.AppCompatActivity; import androidx.camera.core.CameraX; import androidx.camera.core.ImageAnalysis; import androidx.camera.core.ImageProxy; import androidx.camera.core.Preview; import androidx.camera.lifecycle.ProcessCameraProvider; import androidx.camera.view.PreviewView; import androidx.core.content.ContextCompat; import androidx.lifecycle.LifecycleOwner; import java.io.IOException; import java.io.InputStream; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class MainActivity extends AppCompatActivity { private static final String MODEL_PATH = "m_model.pt"; private static final int INPUT_SIZE = 224; private Module mModule; private ExecutorService mExecutorService; private Handler mHandler; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); // 加载PyTorch模型和创建执行线程池 loadModel(); // 创建主线程处理程序 mHandler = new Handler(Looper.getMainLooper()); // 启动相机 startCamera(); } private void loadModel() { // 加载PyTorch模型 try { AssetManager assetManager = getAssets(); InputStream inputStream = assetManager.open(MODEL_PATH); mModule = Module.load(inputStream); } catch (IOException e) { Log.e("MainActivity", "Error reading model file: " + e.getMessage()); finish(); } // 创建执行线程池 mExecutorService = Executors.newSingleThreadExecutor(); } private void startCamera() { // 创建PreviewView PreviewView previewView = findViewById(R.id.preview_view); // 配置相机生命周期所有者 LifecycleOwner lifecycleOwner = this; // 配置相机预览 Preview preview = new Preview.Builder().build(); preview.setSurfaceProvider(previewView.getSurfaceProvider()); // 配置图像分析 ImageAnalysis imageAnalysis = new ImageAnalysis.Builder() .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST) .build(); // 设置图像分析的处理程序 imageAnalysis.setAnalyzer( mExecutorService, new ImageAnalysis.Analyzer() { @Override public void analyze(ImageProxy image, int rotationDegrees) { // 将ImageProxy转换为Bitmap Bitmap bitmap = Bitmap.createScaledBitmap( image.getImage(), INPUT_SIZE, INPUT_SIZE, false); // 将Bitmap转换为Tensor Tensor tensor = TensorImageUtils.bitmapToFloat32Tensor( bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB); // 创建输入列表 final IValue[] inputs = {IValue.from(tensor)}; // 运行模型 Tensor outputTensor = mModule.forward(inputs).toTensor(); // 获取模型输出 float[] scores = outputTensor.getDataAsFloatArray(); // 查找最高分数 float maxScore = -Float.MAX_VALUE; int maxScoreIndex = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIndex = i; } } // 获取分类标签 String[] labels = getLabels(); String predictedLabel = labels[maxScoreIndex]; // 更新UI updateUI(predictedLabel); } }); // 绑定相机生命周期所有者 CameraX.bindToLifecycle(lifecycleOwner, preview, imageAnalysis); } private String[] getLabels() { // 在此处替换为标签文件 return new String[]{ "tench", "goldfish", "great white shark", "tiger shark", // ... }; } private void updateUI(String predictedLabel) { mHandler.post( new Runnable() { @Override public void run() { // 更新UI // 例如,将预测标签写入TextView // TextView textView = findViewById(R.id.text_view); // textView.setText(predictedLabel); } }); } @Override protected void onDestroy() { super.onDestroy(); // 释放模型和执行线程池 mModule.destroy(); mExecutorService.shutdown(); } }
当模型预测输入图像时,它将返回一个整数,该整数表示模型预测的图像类型的索引。可以使用该索引来查找对应的标签并更新UI。例如:
// 查找最高分数 float maxScore = -Float.MAX_VALUE; int maxScoreIndex = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIndex = i; } } // 获取分类标签 String[] labels = getLabels(); String predictedLabel = labels[maxScoreIndex]; // 更新UI updateUI(predictedLabel);
编译和运行应用程序,并在Android Studio调试上测试图像识别功能。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。