赞
踩
需要在java调用python sklearn训练评估的模型,本文介绍使用pmml来实现。
#引入sklearn2pmml包
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline
#使用PMMLPipeline包裹具体评估器
clf = PMMLPipeline([("MLPClassifier", MLPClassifier(hidden_layer_sizes=(25,), random_state=1, max_iter=100, warm_start=True))])
clf.fit(value, label)
#保存模型到指定文件
sklearn2pmml(clf, "MLPClassifier.pmml", with_repr=True)
引用java maven依赖包
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.5.15</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.5.15</version>
</dependency>
java加载模型并评估
Map<String, Object> paramData = new HashMap<>(); paramData.put("x1", 180D); paramData.put("x2", 350D); FileInputStream inputStream = new FileInputStream("MLPClassifier.pmml"); //解析pmml文件,实际上是用JAXB做xml的解析 PMML pmml = PMMLUtil.unmarshal(inputStream); //生成评估器 ModelEvaluator<?> evaluate = new ModelEvaluatorBuilder(pmml).build(); //构建输入参数 Map<FieldName, FieldValue> arguments = new LinkedHashMap<>(); List<InputField> inputFields = evaluate.getInputFields(); for (InputField inputField : inputFields) { //将参数通过模型对应的名称进行添加 FieldName inputFieldName = inputField.getName(); //获取模型中的参数名 Object paramValue = paramData.get(inputFieldName.getValue()); //获取模型参数名对应的参数值 FieldValue fieldValue = inputField.prepare(paramValue); //将参数值填入模型中的参数中 arguments.put(inputFieldName, fieldValue); //存放在map列表中 } //开始评估 Map<FieldName, ?> target = evaluate.evaluate(arguments); //获取评估结果 List<TargetField> targetFields = evaluate.getTargetFields(); Object targetFieldValue = target.get(targetFields.get(0).getFieldName()); System.out.println("targetFieldValue: " + targetFieldValue); System.out.println("target: " + target);
1.注意生成模型的版本和java依赖包的版本要匹配,否则java侧会无法解析该pmml模型
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。