赞
踩
在典型的机器学习和深度学习项目中,建模的常规流程是定义问题、数据收集、数据理解、数据处理、构建模型。但是,如果我们想要将模型提供给最终用户,以便用户能够使用它,就需要进行模型部署,模型部署要做的工作就是如何将机器学习模型传递给客户/利益相关者。模型的部署大致分为以下三个步骤:
依赖包:
将训练好的机器学习模型转化为PMML
格式,以供Java
调用。
Python
代码如下
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml
if __name__ == '__main__':
iris = load_iris() # 经典的数据
X = iris.data # 样本特征
y = iris.target # 分类目标
pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier())]) # 用决策树分类
pipeline.fit(X, y) # 训练
sklearn2pmml(pipeline, "iris.pmml", with_repr=True) # 输出PMML文件
Java
读取模型文件并预测,具体代码如下:
import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.*; import org.xml.sax.SAXException; import javax.xml.bind.JAXBException; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.util.*; public class TestPmml { public static void main(String args[]) throws Exception { String fp = "iris.pmml"; TestPmml obj = new TestPmml(); Evaluator model = obj.loadPmml(fp); List<Map<String, Object>> inputs = new ArrayList<>(); inputs.add(obj.getRawMap(5.1, 3.5, 1.4, 0.2)); inputs.add(obj.getRawMap(4.9, 3, 1.4, 0.2)); for (int i = 0; i < inputs.size(); i++) { Map<String, Object> output = obj.predict(model, inputs.get(i)); System.out.println("X=" + inputs.get(i) + " -> y=" + output.get("y")); } } private Evaluator loadPmml(String fp) throws FileNotFoundException, JAXBException, SAXException { InputStream is = new FileInputStream(fp); PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is); try { is.close(); } catch (IOException e) { e.printStackTrace(); } ModelEvaluatorFactory factory = ModelEvaluatorFactory.newInstance(); return factory.newModelEvaluator(pmml); } private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) { Map<String, Object> data = new HashMap<String, Object>(); data.put("x1", a); data.put("x2", b); data.put("x3", c); data.put("x4", d); return data; } /** * 运行模型得到结果。 */ private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) { Map<FieldName, FieldValue> input = getFieldMap(evaluator, data); Map<String, Object> output = evaluate(evaluator, input); return output; } /** * 把原始输入转换成PMML格式的输入。 */ private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) { List<InputField> inputFields = evaluator.getInputFields(); Map<FieldName, FieldValue> map = new LinkedHashMap<FieldName, FieldValue>(); for (InputField field : inputFields) { FieldName fieldName = field.getName(); Object rawValue = input.get(fieldName.getValue()); FieldValue value = field.prepare(rawValue); map.put(fieldName, value); } return map; } /** * 运行模型得到结果。 */ private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) { Map<FieldName, ?> results = evaluator.evaluate(input); List<TargetField> targetFields = evaluator.getTargetFields(); Map<String, Object> output = new LinkedHashMap<String, Object>(); for (int i = 0; i < targetFields.size(); i++) { TargetField field = targetFields.get(i); FieldName fieldName = field.getName(); Object value = results.get(fieldName); if (value instanceof Computable) { Computable computable = (Computable) value; value = computable.getResult(); } output.put(fieldName.getValue(), value); } return output; } }
可通过示例–python模型转换为ONNX格式了解简单的模型转换。详细内容可参考ONNX官方教程
该方法主要是通过一些Web
框架将预测模型打包成Web
服务接口的形式,是一种比较常见的线上部署方式。常用的Web
框架如下所示:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。