当前位置:   article > 正文

java使用pmml调用sklearn算法模型_java语言实现sklearn

java语言实现sklearn

场景

需要在java调用python sklearn训练评估的模型,本文介绍使用pmml来实现。

生成pmml文件

#引入sklearn2pmml包
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

#使用PMMLPipeline包裹具体评估器
clf = PMMLPipeline([("MLPClassifier", MLPClassifier(hidden_layer_sizes=(25,), random_state=1, max_iter=100, warm_start=True))])
clf.fit(value, label)

#保存模型到指定文件
sklearn2pmml(clf, "MLPClassifier.pmml", with_repr=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

JAVA调用模型

引用java maven依赖包

        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.5.15</version>
        </dependency>

        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>1.5.15</version>
        </dependency>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

java加载模型并评估

        Map<String, Object> paramData = new HashMap<>();
        paramData.put("x1", 180D);
        paramData.put("x2", 350D);

        FileInputStream inputStream = new FileInputStream("MLPClassifier.pmml");
        //解析pmml文件,实际上是用JAXB做xml的解析
        PMML pmml = PMMLUtil.unmarshal(inputStream);
        //生成评估器
        ModelEvaluator<?> evaluate = new ModelEvaluatorBuilder(pmml).build();

        //构建输入参数
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        List<InputField> inputFields = evaluate.getInputFields();
        for (InputField inputField : inputFields) {            //将参数通过模型对应的名称进行添加
            FieldName inputFieldName = inputField.getName();   //获取模型中的参数名
            Object paramValue = paramData.get(inputFieldName.getValue());   //获取模型参数名对应的参数值
            FieldValue fieldValue = inputField.prepare(paramValue);   //将参数值填入模型中的参数中
            arguments.put(inputFieldName, fieldValue);          //存放在map列表中
        }

        //开始评估
        Map<FieldName, ?> target = evaluate.evaluate(arguments);

        //获取评估结果
        List<TargetField> targetFields = evaluate.getTargetFields();
        Object targetFieldValue = target.get(targetFields.get(0).getFieldName());
        System.out.println("targetFieldValue: " + targetFieldValue);
        System.out.println("target: " + target);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

注意事项

1.注意生成模型的版本和java依赖包的版本要匹配,否则java侧会无法解析该pmml模型

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/979349
推荐阅读
相关标签
  

闽ICP备14008679号