当前位置:   article > 正文

java spark一元(多元)线性回归_spark dataset try catch

spark dataset try catch

配置

配置请看我的其他文章 点击跳转

spark官方文档

点击跳转官方文档

其它文章

推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转

数据

训练数据

在这里插入图片描述

预测数据


实体类

用了swagger和lombok 不需要的可以删掉


import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;

/**
 * 线性回归参数
 *
 * @author teler
 * @date 2020-09-21
 */
@Data
public class LinearRegressionEntity {
	/**
	 * 训练数据集路径
	 */
	@ApiModelProperty("训练数据集路径")
	@NotEmpty(message = "必须有样本集")
	private String trainFilePath;
	/**
	 * 预测数据集路径
	 */
	@ApiModelProperty("预测数据集路径")
	@NotEmpty(message = "必须有预测集")
	private String dataFilePath;
	/**
	 * 用于测试模型的数据比例,范围[0,1]
	 */
	@ApiModelProperty("用于测试模型的数据比例,范围[0,1]")
	@Max(value = 1L, message = "数据比例最大值为1.0")
	@Min(value = 0L, message = "数据比例最小值为0.0")
	private double testDataPct;
	/**
	 * 迭代次数
	 */
	@ApiModelProperty("迭代次数")
	@NotNull(message = "迭代次数必填")
	@Min(value = 0, message = "迭代次数最小值为0")
	private Integer iter;
	/**
	 * 正则化参数,范围[0,1]
	 */
	@NotNull(message = "正则化参数必填")
	@Max(value = 1L, message = "正则化参数最大值为1.0")
	@Min(value = 0L, message = "正则化参数最小值为0.0")
	@ApiModelProperty("正则化参数,范围[0,1]")
	private double regParam;
	/**
	 * 弹性网络混合参数,范围[0,1]
	 */
	@ApiModelProperty("弹性网络混合参数,范围[0,1]")
	@Max(value = 1L, message = "弹性网络混合参数最大值为1.0")
	@Min(value = 0L, message = "弹性网络混合参数最小值为0.0")
	private double elasticNetParam;
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60

算法实现

里面有些方法是为了保留小数 不需要的自己改


@Resource
private transient SparkSession sparkSession;

@Override
	public Map<String, Object> linearRegression(LinearRegressionEntity record) {
		log.info("========== 线性回归计算开始 ==========");
		Map<String, Object> map = new HashMap<>(16);
		Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath());

		List<Map<String, String>> sourceList = toList(source);
		//训练数据
		map.put("training", sourceList);

		//根据比例从数据源中随机抽取数据 /训练数据和测试数据比例 建议设为0.8
		Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()},
				1234L);
		//训练数据
		Dataset<Row> trainingData = splits[0].cache();

		// 10 / 0.3 / 0.8
		LinearRegression lr = new LinearRegression()
				                      .setMaxIter(record.getIter())
				                      .setRegParam(record.getRegParam())
				                      .setElasticNetParam(record.getElasticNetParam());

		LinearRegressionModel lrModel = lr.fit(trainingData);

		//系数
		map.put("coefficients", Arrays.stream(lrModel.coefficients().toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));
		//截距
		map.put("intercept", NumberUtil.roundDown(lrModel.intercept(), 3));

		//训练数据结果集
		LinearRegressionTrainingSummary trainingSummary = lrModel.summary();

		//迭代次数
		map.put("numIterations", trainingSummary.totalIterations());
		//损失率,一般会逐渐减小
		map.put("objectiveHistory", Arrays.stream(trainingSummary.objectiveHistory()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));
		//均方根误差
		map.put("rmse", NumberUtil.roundDown(trainingSummary.rootMeanSquaredError(), 3));
		//真实误差
		map.put("mae", NumberUtil.roundDown(trainingSummary.meanAbsoluteError(), 3));
		//r平方 越接近1说明效果越好
		map.put("r2", NumberUtil.roundDown(trainingSummary.r2(), 3));

		//预测数据
		Dataset<Row> predictionData = getDataSetByHdfs(record.getDataFilePath());
		Dataset<Row> predictionResult = lrModel.transform(predictionData).selectExpr("label", "features", "round(prediction,3) as prediction");

		predictionResult.show();
		List<Object> predictionFeaturesVal = dataSetToString(predictionResult.select("features"));
		
		map.put("data", toList(predictionResult));
		log.info("========== 线性回归计算结束 ==========");
		return map;
	}



  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61

getDataSetByHdfs方法

这个方法我与上面的方法放在一个类中,所以sparkSession没重复写

/**
	 * 从hdfs中取数据
	 *
	 * @param dataFilePath 数据路径
	 * @return 数据集合
	 */
	private Dataset<Row> getDataSetByHdfs(String dataFilePath) {

		//屏蔽日志
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
		Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

		Dataset<Row> dataset;
		try {
			//我这里的数据是libsvm格式的 如果是其他格式请自行更改
			dataset = sparkSession.read().format("libsvm").load(dataFilePath);
			log.info("获取数据结束 ");
		} catch (Exception e) {
			log.info("读取失败:{} ", e.getMessage());
		}
		return dataset;
	}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

toList

/**
	 * dataset数据转化为list数据
	 *
	 * @param record 数据
	 * @return 数据集合
	 */
	private List<Map<String, String>> toList(Dataset<Row> record) {
		log.info("格式化结果数据集===============================");
		List<Map<String, String>> list = new ArrayList<>();
		String[] columns = record.columns();
		List<Row> rows = record.collectAsList();
		for (Row row : rows) {
			Map<String, String> obj = new HashMap<>(16);
			for (int j = 0; j < columns.length; j++) {
				String col = columns[j];
				Object rowAs = row.getAs(col);
				String val = "";
				//如果是数组 
				//这一段不需要的可以只留下else的内容
				if (rowAs instanceof DenseVector) {
					if (((DenseVector) rowAs).values() instanceof double[]) {
						val = ArrayUtil.join(
								Arrays.stream(((DenseVector) rowAs).values())
										.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
								, ",")
						;
					} else {
						val = rowAs.toString();
					}
				} else {
					val = rowAs.toString();
				}

				obj.put(col, val);
				log.info("列:{},名:{},值:{}", j, col, val);
			}
			list.add(obj);
		}
		return list;
	}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/700433
推荐阅读
相关标签
  

闽ICP备14008679号