赞
踩
配置请看我的其他文章 点击跳转
推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转
用了swagger和lombok 不需要的可以删掉
import io.swagger.annotations.ApiModelProperty; import lombok.Data; import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.validation.constraints.NotEmpty; import javax.validation.constraints.NotNull; /** * 线性回归参数 * * @author teler * @date 2020-09-21 */ @Data public class LinearRegressionEntity { /** * 训练数据集路径 */ @ApiModelProperty("训练数据集路径") @NotEmpty(message = "必须有样本集") private String trainFilePath; /** * 预测数据集路径 */ @ApiModelProperty("预测数据集路径") @NotEmpty(message = "必须有预测集") private String dataFilePath; /** * 用于测试模型的数据比例,范围[0,1] */ @ApiModelProperty("用于测试模型的数据比例,范围[0,1]") @Max(value = 1L, message = "数据比例最大值为1.0") @Min(value = 0L, message = "数据比例最小值为0.0") private double testDataPct; /** * 迭代次数 */ @ApiModelProperty("迭代次数") @NotNull(message = "迭代次数必填") @Min(value = 0, message = "迭代次数最小值为0") private Integer iter; /** * 正则化参数,范围[0,1] */ @NotNull(message = "正则化参数必填") @Max(value = 1L, message = "正则化参数最大值为1.0") @Min(value = 0L, message = "正则化参数最小值为0.0") @ApiModelProperty("正则化参数,范围[0,1]") private double regParam; /** * 弹性网络混合参数,范围[0,1] */ @ApiModelProperty("弹性网络混合参数,范围[0,1]") @Max(value = 1L, message = "弹性网络混合参数最大值为1.0") @Min(value = 0L, message = "弹性网络混合参数最小值为0.0") private double elasticNetParam; }
里面有些方法是为了保留小数 不需要的自己改
@Resource private transient SparkSession sparkSession; @Override public Map<String, Object> linearRegression(LinearRegressionEntity record) { log.info("========== 线性回归计算开始 =========="); Map<String, Object> map = new HashMap<>(16); Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath()); List<Map<String, String>> sourceList = toList(source); //训练数据 map.put("training", sourceList); //根据比例从数据源中随机抽取数据 /训练数据和测试数据比例 建议设为0.8 Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()}, 1234L); //训练数据 Dataset<Row> trainingData = splits[0].cache(); // 10 / 0.3 / 0.8 LinearRegression lr = new LinearRegression() .setMaxIter(record.getIter()) .setRegParam(record.getRegParam()) .setElasticNetParam(record.getElasticNetParam()); LinearRegressionModel lrModel = lr.fit(trainingData); //系数 map.put("coefficients", Arrays.stream(lrModel.coefficients().toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue())); //截距 map.put("intercept", NumberUtil.roundDown(lrModel.intercept(), 3)); //训练数据结果集 LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); //迭代次数 map.put("numIterations", trainingSummary.totalIterations()); //损失率,一般会逐渐减小 map.put("objectiveHistory", Arrays.stream(trainingSummary.objectiveHistory()).map(val -> NumberUtil.roundDown(val, 3).doubleValue())); //均方根误差 map.put("rmse", NumberUtil.roundDown(trainingSummary.rootMeanSquaredError(), 3)); //真实误差 map.put("mae", NumberUtil.roundDown(trainingSummary.meanAbsoluteError(), 3)); //r平方 越接近1说明效果越好 map.put("r2", NumberUtil.roundDown(trainingSummary.r2(), 3)); //预测数据 Dataset<Row> predictionData = getDataSetByHdfs(record.getDataFilePath()); Dataset<Row> predictionResult = lrModel.transform(predictionData).selectExpr("label", "features", "round(prediction,3) as prediction"); predictionResult.show(); List<Object> predictionFeaturesVal = dataSetToString(predictionResult.select("features")); map.put("data", toList(predictionResult)); log.info("========== 线性回归计算结束 =========="); return map; }
这个方法我与上面的方法放在一个类中,所以sparkSession没重复写
/** * 从hdfs中取数据 * * @param dataFilePath 数据路径 * @return 数据集合 */ private Dataset<Row> getDataSetByHdfs(String dataFilePath) { //屏蔽日志 Logger.getLogger("org.apache.spark").setLevel(Level.WARN); Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF); Dataset<Row> dataset; try { //我这里的数据是libsvm格式的 如果是其他格式请自行更改 dataset = sparkSession.read().format("libsvm").load(dataFilePath); log.info("获取数据结束 "); } catch (Exception e) { log.info("读取失败:{} ", e.getMessage()); } return dataset; }
/** * dataset数据转化为list数据 * * @param record 数据 * @return 数据集合 */ private List<Map<String, String>> toList(Dataset<Row> record) { log.info("格式化结果数据集==============================="); List<Map<String, String>> list = new ArrayList<>(); String[] columns = record.columns(); List<Row> rows = record.collectAsList(); for (Row row : rows) { Map<String, String> obj = new HashMap<>(16); for (int j = 0; j < columns.length; j++) { String col = columns[j]; Object rowAs = row.getAs(col); String val = ""; //如果是数组 //这一段不需要的可以只留下else的内容 if (rowAs instanceof DenseVector) { if (((DenseVector) rowAs).values() instanceof double[]) { val = ArrayUtil.join( Arrays.stream(((DenseVector) rowAs).values()) .map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray() , ",") ; } else { val = rowAs.toString(); } } else { val = rowAs.toString(); } obj.put(col, val); log.info("列:{},名:{},值:{}", j, col, val); } list.add(obj); } return list; }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。