赞
踩
这篇文章演示如何将训练好的pytorch模型部署到安卓设备上。我也是刚开始学安卓,代码写的简单。
环境:
pytorch版本:1.10.0
pytorch_android支持的模型是.pt模型,我们训练出来的模型是.pth。所以需要转化才可以用。先看官网上给的转化方式:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")
这个模型在安卓对应的包:
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
注:pytorch_android_lite版本和转化模型用的版本要一致,不一致就会报各种错误。
目前用这种方法有点问题,我采用的另一种方法。
转化代码如下:
import torch import torch.utils.data.distributed # pytorch环境中 model_pth = 'model_31_0.96.pth' #模型的参数文件 mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件 model = torch.load(model_pth) model.eval() # 模型设为评估模式 device = torch.device('cpu') model.to(device) # 1张3通道224*224的图片 input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式 mobile = torch.jit.trace(model, input_tensor) # 模型转化 mobile.save(mobile_pt) # 保存文件
对应的包:
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
定义模型文件和转化后的文件路径。
load模型。这里要注意,如果保存模型
torch.save(model,'models.pth')
加载模型则是
model=torch.load('models.pth')
如果保存模型是
torch.save(model.state_dict(),"models.pth")
加载模型则是
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。