当前位置:   article > 正文

Java调用pyhton训练的机器学习模型_java调用python训练好的模型

java调用python训练好的模型

python广泛用于机器学习训练模型,java又被大量开发者所使用的,因此存在跨语言调用的问题。幸好有pmml的出现,将python模型直接保存为“.pmml”结尾的文件使其调用。最简单的流程基本走通。存在的问题:如何将PMML模型文件用于AS中,使模型能够用于App的使用,目前还在寻找,太难了。。。。。

python模型直接保存为pmml: 使用的为sklearn2pmml

from sklearn.ensemble import RandomForestRegressor
from sklearn import model_selection,metrics,cross_validation
import numpy as np
import xlrd
import csv
import pandas as pd
from openpyxl.workbook import Workbook
from sklearn.metrics import accuracy_score,recall_score,precision_score,f1_score
import time
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

#读取excle
def readexcle(filename):
    fh=xlrd.open_workbook(filename)
    table=fh.sheets()[0]
    rows=table.nrows
    exdata=[]
    for row in range(rows):          #获取行数
        data=table.row_values(row) #读取每行的数值
        exdata.append(data)
    return exdata

#分离数据以及label
def splitnumandlab(data):
    culm=len(data[0][:])
    row=len(data)
    
    
    yslabel=[]
    ysnumdata=[]
    for j in range(row):
        yslabel.append(data[j][culm-1])
        
        ysnumdata.append(data[j][0:culm-12])
        print(len(data[j][0:culm-12]))
    return yslabel,ysnumdata
#获取excle原始数据
y_s_data=readexcle(r"C:\Users\Rui Kong\Desktop\ceshi.xlsx")

label1,numdata1=splitnumandlab(y_s_data) 
x_num=np.array(numdata1)
y_lab=np.array(label1)
x_train,x_test,y_train,y_test=model_selection.train_test_split(x_num,y_lab,train_size=0.85)
clf=RandomForestRegressor(n_estimators=20,max_depth=200)
pipline=PMMLPipeline([("classifier",clf)])
pipline.fit(x_train,y_train)
#pipline.score(x_train,y_train)
#pipline.score(x_test,y_test)
sklearn2pmml(pipline,r"C:\Users\Rui Kong\Desktop\ceshi.pmml")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51

Eclipse中调用,需要导入的Jar包为:pmml-evaluator-example-executable-1.4.13.jar

package pmml;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.xml.bind.JAXBException;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.xml.sax.SAXException;

public class main_enter {
	public static void main(String args[]) {
		Evaluator evaluator = loadPmml();
		ArrayList<Float> arraylist = new ArrayList<>();
		arraylist.add(10.1f);
		arraylist.add(158.1f);
		arraylist.add(1009.1f);
		arraylist.add(1800.1f);
		arraylist.add(158.1f);
		Object ab = predict(evaluator,arraylist);
		System.out.println(ab);
	}
	public static  Evaluator loadPmml(){
        PMML pmml  = new PMML();
        InputStream inputStream = null;
        File file = new File("C:\\Users\\Rui Kong\\Desktop\\ceshi.pmml");
        try{
            inputStream = new FileInputStream(file);//在非activity类中使用getResource需要传入context
           
        }catch (Exception e){
            e.printStackTrace();
        }
        if(inputStream==null){
            return null;
        }
        try{
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (JAXBException e1) {
            e1.printStackTrace();
        } catch (SAXException e2) {
            e2.printStackTrace();
        }
        try{
            inputStream.close();
        }catch (IOException e){
            e.printStackTrace();
            }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
        return evaluator;
    }
    public static Object predict(Evaluator evaluator, ArrayList<Float>a){
        int m=a.size();
        HashMap<String, Float>map = new HashMap<>();
        //将数组储存在map中
        for (int i =1;i<m+1;i++){
            map.put("x"+i,a.get(i-1));
        }
        System.out.println(map);
        List<InputField> inputFields = evaluator.getInputFields();
        //从画像中获取数据,作为模型的输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField :inputFields){
            FieldName inputFieldName = inputField.getName();
            Object rawValue = map.get(inputFieldName.getValue());
            FieldValue inputValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName,inputValue);
        }
        System.out.println(arguments);
        Map<FieldName,?> results =evaluator.evaluate(arguments);//模型识别结果文件
        System.out.println(results);
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);//返回第一个数作为预测结果,预测只有一个结果,分类则有几个结果
        Object targetValue = results.get(targetField);

        return targetValue;
    }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91

反正最后能跑通,给出了一个预测结果:
结果示意图
[github上面,evaluator的jar包]https://github.com/jpmml/jpmml-evaluator/releases,也可以好好看看大佬写的example

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/926239
推荐阅读
相关标签
  

闽ICP备14008679号