MNN Introduction and Use Cases
Workflow for applying a deep learning model:
PyTorch example
Offline model training
Exporting the model
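```python
import os

import torch


def save_model(m, loss, base_name, epoch_num, all_model=False):
    save_path = r'/home/hc/trainned_weight_params'
    os.makedirs(save_path, exist_ok=True)  # make sure the output directory exists
    model_name = str(type(m))
    name = '%s/%s_epoch%d_loss%.4f.pth' % (save_path, base_name, epoch_num, loss)
    name_onnx = name + '.onnx'
    print('>>>save model:', model_name, name)
    if all_model:
        # save the whole model object (structure + weights)
        torch.save(m, name)
    else:
        # save only the weights (state_dict)
        torch.save(m.state_dict(), name)
    # create a dummy input with the shape the model expects (batch size 1, 4 features)
    m_input_data = torch.randn(1, 4)
    # convert the PyTorch model to an ONNX model by tracing it with the dummy input
    torch.onnx.export(m, m_input_data, name_onnx, verbose=True)
```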
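The file names produced by `save_model` follow a fixed pattern, and the ONNX file simply appends `.onnx` to the `.pth` name, so the file passed to the converter ends in `.pth.onnx`. The helper below is a minimal, hypothetical sketch (`checkpoint_paths` is not part of the original code) that reproduces only the naming logic, without needing PyTorch, so the resulting paths are easy to check.

```python
def checkpoint_paths(save_path, base_name, epoch_num, loss):
    """Return the (.pth, .onnx) paths built by the naming scheme in save_model."""
    # same format string as save_model: <dir>/<base>_epoch<N>_loss<loss, 4 decimals>.pth
    pth = '%s/%s_epoch%d_loss%.4f.pth' % (save_path, base_name, epoch_num, loss)
    # the ONNX file name just appends ".onnx" to the .pth file name
    return pth, pth + '.onnx'


pth, onnx_path = checkpoint_paths('/tmp/weights', 'als', 12, 0.034567)
print(pth)        # /tmp/weights/als_epoch12_loss0.0346.pth
print(onnx_path)  # /tmp/weights/als_epoch12_loss0.0346.pth.onnx
```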
Model conversion
```shell
./MNNConvert --framework ONNX --modelFile pfld-lite.onnx --MNNModel pfld-lite.mnn --bizCode MNN
```
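When several exported ONNX files need converting, the same command can be driven from Python. This is a sketch under stated assumptions: the `MNNConvert` binary is in the current directory, and the only flags used are the ones in the command above. `mnnconvert_cmd` and `convert` are hypothetical helper names.

```python
import shlex
import subprocess


def mnnconvert_cmd(onnx_path, mnn_path, converter='./MNNConvert', biz_code='MNN'):
    """Build the MNNConvert argument list for an ONNX -> MNN conversion."""
    return [converter,
            '--framework', 'ONNX',
            '--modelFile', onnx_path,
            '--MNNModel', mnn_path,
            '--bizCode', biz_code]


def convert(onnx_path, mnn_path, **kwargs):
    """Run MNNConvert; raises CalledProcessError if the converter exits non-zero."""
    subprocess.run(mnnconvert_cmd(onnx_path, mnn_path, **kwargs), check=True)


# prints the same command line as the shell example above
print(shlex.join(mnnconvert_cmd('pfld-lite.onnx', 'pfld-lite.mnn')))
```

Passing an argument list instead of a single shell string avoids quoting problems when model paths contain spaces.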
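Model deployment
Use the MNN API to load the model and its configuration, then set the inputs and read the outputs.
Example code (4 inputs, 3 outputs):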
```cpp
#include <MNN/Interpreter.hpp>

// load the model and set the config (number of worker threads and platform)
auto ALSnet = MNN::Interpreter::createFromFile((const char*)model_path);
MNN::ScheduleConfig netConfig;
netConfig.type = MNN_FORWARD_CPU;
netConfig.numThread = 4;
// set the model precision
MNN::BackendConfig backendConfig;
backendConfig.precision = MNN::BackendConfig::Precision_High;
netConfig.backendConfig = &backendConfig;
auto session = ALSnet->createSession(netConfig);
auto inputTensor = ALSnet->getSessionInput(session, nullptr);
// on the CPU backend the input can be written directly into host memory
auto net_input_data = inputTensor->host<float>();

// fill in the input data
net_input_data[0] = block_r[round] / (pixel_cnt / 4);
net_input_data[1] = block_g[round] / (pixel_cnt / 4);
net_input_data[2] = block_b[round] / (pixel_cnt / 4);
// net_input_data[3]: the fourth input is not set in this excerpt

// run inference
ALSnet->runSession(session);

// read the output data
auto outputTensor = ALSnet->getSessionOutput(session, nullptr);
predDiff[round].r = outputTensor->host<float>()[0];
predDiff[round].g = outputTensor->host<float>()[1];
predDiff[round].b = outputTensor->host<float>()[2];
```
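The input preparation divides each channel value by `pixel_cnt / 4`, i.e. it averages it over a quarter of the block's pixels. The sketch below reproduces that arithmetic with made-up numbers; the values and the name `make_inputs` are hypothetical. It also shows one thing to watch for: the excerpt does not show the types of `block_r` and `pixel_cnt`, and if they are integers, C++ integer division truncates, which Python's `//` mimics for non-negative values.

```python
def make_inputs(block_r, block_g, block_b, pixel_cnt):
    """Average each channel value over pixel_cnt / 4 pixels, as in the C++ code."""
    n = pixel_cnt / 4  # true (floating-point) division
    return [block_r / n, block_g / n, block_b / n]


print(make_inputs(51200, 38400, 25600, 1024))  # [200.0, 150.0, 100.0]

# when pixel_cnt is not a multiple of 4, integer division drops the fraction
print(1030 / 4, 1030 // 4)  # 257.5 257
```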