当前位置:   article > 正文

Python模型部署工具_模型工程化部署

模型工程化部署

什么是模型部署?

在典型的机器学习和深度学习项目中,建模的常规流程是定义问题、数据收集、数据理解、数据处理、构建模型。但是,如果我们想要将模型提供给最终用户,以便用户能够使用它,就需要进行模型部署,模型部署要做的工作就是如何将机器学习模型传递给客户/利益相关者。模型的部署大致分为以下三个步骤:

  1. 模型持久化;
    持久化,通俗得讲,就是临时数据(比如内存中的数据,是不能永久保存的)持久化为持久数据(比如持久化至数据库中,能够长久保存)。那我们训练好的模型一般都是存储在内存中,这个时候就需要用到持久化方式,在 Python 中,常用的模型持久化方式一般都是以文件的方式持久化。
  2. 选择适合的服务器加载已经持久化的模型;
  3. 提高服务接口,拉通前后端数据交流;

模型部署工具介绍

MLflow

MLflow的Python接口

MLeap

PMML

依赖包:

  • sklearn
  • sklearn2pmml

将训练好的机器学习模型转化为PMML格式,以供Java调用。
Python代码如下

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml

if __name__ == '__main__':
    iris = load_iris() # 经典的数据
    X = iris.data  # 样本特征
    y = iris.target  # 分类目标
    pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier())]) # 用决策树分类
    pipeline.fit(X, y)  # 训练
    sklearn2pmml(pipeline, "iris.pmml", with_repr=True)  # 输出PMML文件
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

Java读取模型文件并预测,具体代码如下:

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;


public class TestPmml {
    public static void main(String args[]) throws Exception {
        String fp = "iris.pmml";
        TestPmml obj = new TestPmml();
        Evaluator model = obj.loadPmml(fp);
        List<Map<String, Object>> inputs = new ArrayList<>();
        inputs.add(obj.getRawMap(5.1, 3.5, 1.4, 0.2));
        inputs.add(obj.getRawMap(4.9, 3, 1.4, 0.2));
        for (int i = 0; i < inputs.size(); i++) {
            Map<String, Object> output = obj.predict(model, inputs.get(i));
            System.out.println("X=" + inputs.get(i) + " -> y=" + output.get("y"));
        }
    }

    private Evaluator loadPmml(String fp) throws FileNotFoundException, JAXBException, SAXException {
        InputStream is = new FileInputStream(fp);
        PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        try {
            is.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        ModelEvaluatorFactory factory = ModelEvaluatorFactory.newInstance();
        return factory.newModelEvaluator(pmml);
    }

    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        return data;
    }

    /**
     * 运行模型得到结果。
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
        Map<FieldName, FieldValue> input = getFieldMap(evaluator, data);
        Map<String, Object> output = evaluate(evaluator, input);
        return output;
    }

    /**
     * 把原始输入转换成PMML格式的输入。
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> map = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField field : inputFields) {
            FieldName fieldName = field.getName();
            Object rawValue = input.get(fieldName.getValue());
            FieldValue value = field.prepare(rawValue);
            map.put(fieldName, value);
        }
        return map;
    }

    /**
     * 运行模型得到结果。
     */
    private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
        Map<FieldName, ?> results = evaluator.evaluate(input);
        List<TargetField> targetFields = evaluator.getTargetFields();
        Map<String, Object> output = new LinkedHashMap<String, Object>();
        for (int i = 0; i < targetFields.size(); i++) {
            TargetField field = targetFields.get(i);
            FieldName fieldName = field.getName();
            Object value = results.get(fieldName);
            if (value instanceof Computable) {
                Computable computable = (Computable) value;
                value = computable.getResult();
            }
            output.put(fieldName.getValue(), value);
        }
        return output;
    }

}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93

Pyspark

Sklearn

ONNX介绍

可通过示例–python模型转换为ONNX格式了解简单的模型转换。详细内容可参考ONNX官方教程

TensorRT

TensorFlow Serving

Web服务化部署

该方法主要是通过一些Web框架将预测模型打包成Web服务接口的形式,是一种比较常见的线上部署方式。常用的Web框架如下所示:

Docker菜鸟教程

参考

  1. 深度学习模型部署技术方案
  2. 谈谈机器学习模型的部署
  3. MLflow:一种机器学习生命周期管理平台
  4. 总结一下模型工程化部署的几种方式
  5. 用PMML实现机器学习模型的跨平台上线
  6. 机器学习模型之PMML
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/412101?site
推荐阅读
相关标签
  

闽ICP备14008679号