赞
踩
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。
cuDNN是NVIDIA专门针对深度神经网络中的基础操作而设计的基于GPU的加速库。cuDNN为深度神经网络中的标准流程提供了高度优化的实现方式,例如卷积、池化、归一化以及激活层的前向以及后向过程。当开发者们需要用到深度学习GPU加速时,才安装cuDNN库,工作速度相较CPU快很多。
需要安装cuda 和 cudnn 库的卸载与安装,可以参考以下教程:
cuda和cudnn安装
安装完成后,可按照下面教程安装pytorch:
此处需要用到anaconda:Anaconda是一个安装、管理Python相关包的软件,其包含了conda、Python等180多个科学包及其依赖项,可以用于在同一个机器上安装不同版本的软件包及其依赖,并能够在不同的环境之间切换。
另外,如需在pycharm进行操作,可安装以下教程配置虚拟环境:
pycharm配置讯环境(此步需要完成pytorch环境的安装,有anaconda环境)
所有环境准备好以后,查看配置是否完成:
import torch
# 查看版本
print(torch.__version__)
# 查看gup是否可用
print(torch.cuda.is_available())
# 返回gpu个数
print(torch.cuda.device_count())
# 查看对应cuda版本号
print(torch.backends.cudnn.version())
print(torch.version.cuda)
现在就开始准备移植所需的文件吧。
PT模型是一种完整的模型文件,不仅包含了模型的参数,还包括了模型的结构,可以直接被加载到模型中,开始进行训练和预测。而pth文件则只保存了模型的参数,因此在加载模型时需要重新定义模型结构。
所以,需要一个安卓端的pt模型
这里我们使用到图像识别的resnet101模型,新建一个py文件,运行以下代码:
import torch
import torchvision
model = torchvision.models.resnet101(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model101.pt")
这样,我们就获得了一个图像识别的pt模型。
然后,需要找到该模型的标签文件,我们从以下项目中获取
https://github.com/ethereon/caffe-tensorflow/blob/master/examples/imagenet/imagenet-classes.txt
拿到txt标签文件以后,我们需要转换成java中String[]的形式,所以稍微处理一下,
import re
with open('imagenet-classes.txt','r',encoding='utf-8') as f:
text = f.read()
text = re.sub(r'^', '"', text, flags=re.M)
text = re.sub(r'$', '",', text, flags=re.M)
with open('imagenet-classes.txt','w',encoding='utf-8') as f:
f.write(text)
这里是给开头和结尾都加上了"(双引号)以及结尾用,(逗号)分隔。
完成以后我们打开AndroidStudio
新建一个空的项目
在资源文件中放入我们的model101.pt文件:
然后Gradle导入pytorch的依赖:
implementation("org.pytorch:pytorch_android:1.12.1")
implementation("org.pytorch:pytorch_android_torchvision:1.12.1")
完成以后修改activity
<?xml version="1.0" encoding="utf-8"?> <RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" android:layout_width="match_parent" android:layout_height="match_parent"> <ImageView android:id="@+id/image" android:layout_width="match_parent" android:layout_height="match_parent" android:scaleType="centerCrop" /> <TextView android:id="@+id/text" android:layout_width="match_parent" android:layout_height="wrap_content" android:layout_alignParentTop="true" android:layout_marginTop="644dp" android:padding="8dp" android:textColor="@android:color/white" android:textSize="20sp" /> <TextView android:id="@+id/textTitle" android:layout_width="match_parent" android:layout_height="wrap_content" android:layout_alignParentTop="true" android:layout_marginTop="6dp" android:padding="8dp" android:text="默认测试图片" android:textColor="@android:color/white" android:textSize="18sp" /> <Button android:id="@+id/button" android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_centerInParent="true" android:textColor="@android:color/white" android:text="自行测试" android:padding="16dp"/> </RelativeLayout>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android" android:layout_width="match_parent" android:layout_height="match_parent"> <ImageView android:id="@+id/image" android:layout_width="match_parent" android:layout_height="match_parent" android:scaleType="fitCenter" /> <Button android:id="@+id/btn_capture" android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_gravity="center" android:text="选择照片" /> <TextView android:id="@+id/text_prediction" android:layout_width="match_parent" android:layout_height="wrap_content" android:layout_gravity="bottom" android:background="#80000000" android:padding="10dp" android:textColor="@android:color/white" android:textSize="18sp"/> </FrameLayout>
MainAvity:
public class MainActivity extends AppCompatActivity { Button takePictureBtn=null; Bitmap bitmap = null; Module module = null; // 将图片显示在界面上 ImageView imageView = null; TextView textView = null; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); initView(); //从assets中获得图片数据 try { bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg")); imgClassify(bitmap); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } } private void imgClassify(Bitmap img){ try { // 加载PyTorch序列化模型 module = Module.load(assetFilePath(this, "model101.pt")); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } imageView.setImageBitmap(img); // 建立输入张量 final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(img, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB); // 运行模型推理 final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); // 获取推理结果 final float[] scores = outputTensor.getDataAsFloatArray(); // 获取概率最高的分类的索引号 float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } //通过索引号获得分类名称 String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx]; textView.setText("识别结果为:"+className); } private void initView(){ takePictureBtn = findViewById(R.id.button); // 将图片显示在界面上 imageView = findViewById(R.id.image); // 显示分类名称 textView = findViewById(R.id.text); takePictureBtn.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { Intent intent = new Intent(MainActivity.this,TakePicturesActivity.class); startActivity(intent); } }); } //返回模型文件路径 public static String assetFilePath(Context context, String assetName) throws IOException { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } } }
TakePicturesActivity:
package com.example.newdemo05; import static com.example.newdemo05.MainActivity.assetFilePath; import androidx.annotation.NonNull; import androidx.appcompat.app.AppCompatActivity; import androidx.core.app.ActivityCompat; import androidx.core.content.ContextCompat; import android.content.ContentResolver; import android.content.Intent; import android.content.pm.PackageManager; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.hardware.camera2.CameraAccessException; import android.hardware.camera2.CameraCharacteristics; import android.hardware.camera2.CameraDevice; import android.hardware.camera2.CameraManager; import android.hardware.camera2.params.StreamConfigurationMap; import android.net.Uri; import android.os.Bundle; import android.provider.MediaStore; import android.util.Log; import android.view.TextureView; import android.view.View; import android.widget.Button; import android.widget.ImageView; import android.widget.TextView; import com.example.newdemo05.classify.ImageNetClasses; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; public class TakePicturesActivity extends AppCompatActivity { private TextureView textureView; private TextView textPrediction; private Button openCam; private ImageView v_img; Module module = null; private CameraDevice cameraDevice; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_take_pictures); initView(); } private void initView(){ v_img = findViewById(R.id.image); textPrediction = findViewById(R.id.text_prediction); openCam = findViewById(R.id.btn_capture); openCam.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { startCamera(); } }); } // 打开相机 public void startCamera(){ Intent intent = new Intent(Intent.ACTION_PICK, null); //调用setDataAndType方法,指定了选择的数据类型为图片 //设置数据的URI为MediaStore.Images.Media.EXTERNAL_CONTENT_URI,表示选择外部存储中的图片 intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*"); //调用startActivityForResult方法,将Intent发送给系统,并指定一个请求码为2,以便在之后的回调中处理用户选择的图片 startActivityForResult(intent, 2); } @Override protected void onActivityResult(int requestCode, int resultCode, Intent data) { super.onActivityResult(requestCode, resultCode, data); if (requestCode == 2) { // 从相册返回的数据 Log.e(this.getClass().getName(), "Result:" + data.toString()); if (data != null) { // 得到图片的全路径 Uri uri = data.getData(); v_img.setImageURI(uri); ContentResolver cr = getContentResolver(); InputStream inputStream = null; try { inputStream = cr.openInputStream(uri); } catch (FileNotFoundException e) { throw new RuntimeException(e); } Bitmap bitmap = BitmapFactory.decodeStream(inputStream); imgClassify(bitmap); Log.e(this.getClass().getName(), "Uri:" + String.valueOf(uri)); } } } private void imgClassify(Bitmap img){ try { // 加载PyTorch序列化模型 module = Module.load(assetFilePath(this, "model101.pt")); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } v_img.setImageBitmap(img); // 建立输入张量 final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(img, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB); // 运行模型推理 final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); // 获取推理结果 final float[] scores = outputTensor.getDataAsFloatArray(); // 获取概率最高的分类的索引号 float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } //通过索引号获得分类名称 String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx]; textPrediction.setText("识别结果为:"+className); } }
然后新建一个ImageNetClasses,来装我们的标签文件:
public class ImageNetClasses {
public static String[] IMAGENET_CLASSES = new String[]{};}
然后负责粘贴处理好的标签txt文件数据:
具体代码逻辑如下:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。