当前位置:   article > 正文

双色球预测算法(Java)——随机森林机器学习、时间序列_java实现deeplearning4j做数据预测分析

java实现deeplearning4j做数据预测分析

最近AI很火,我总想着利用AI的某种算法干点有意义的事情,其中之一便是双色球。于是我让AI帮我预测,结果它基本都是简单地用随机算法列出了几个数字。

额……咋说呢,双色球确实是随机的;但如果只是随机,我用你AI干嘛,直接写个随机数生成器就行了嘛。

于是乎,我又问了问市面上有哪些预测算法,它给出了两个:一个是随机森林(机器学习),一个是时间序列。

然后,我让它把这两个算法写出来。代码给是给了,但是……实在无力吐槽。

于是,在我和它的共同配合下,这两个算法的Java版诞生了,仅供参考:

随机森林(机器学习):
  1. package com.ruoyi.web.controller.test;
  2. import java.io.BufferedReader;
  3. import java.io.FileInputStream;
  4. import java.io.InputStreamReader;
  5. import java.nio.charset.StandardCharsets;
  6. import java.util.ArrayList;
  7. import java.util.HashSet;
  8. import java.util.List;
  9. import java.util.Set;
  10. import lombok.val;
  11. import org.apache.commons.csv.CSVFormat;
  12. import org.apache.commons.csv.CSVParser;
  13. import org.apache.commons.csv.CSVRecord;
  14. import weka.classifiers.Classifier;
  15. import weka.classifiers.trees.RandomForest;
  16. import weka.core.Attribute;
  17. import weka.core.DenseInstance;
  18. import weka.core.Instances;
  19. public class LotteryPredictor {
  20. public static void main(String[] args) throws Exception {
  21. String csvFilePath = "D:\\12.csv"; // 请替换为你的CSV文件的绝对路径
  22. // Step 1: Read historical data from CSV
  23. List<int[]> historicalData = readCSV(csvFilePath);
  24. // Step 2: Prepare data for Weka
  25. Instances trainingData = prepareTrainingData(historicalData);
  26. // Step 3: Train RandomForest model
  27. Classifier model = new RandomForest();
  28. model.buildClassifier(trainingData);
  29. // Step 4: Make a prediction
  30. int[] prediction = predictNextNumbers(model, trainingData);
  31. // Output the prediction
  32. System.out.println("Predicted numbers: ");
  33. for (int num : prediction) {
  34. System.out.print(num + " ");
  35. }
  36. }
  37. private static List<int[]> readCSV(String csvFilePath) throws Exception {
  38. List<int[]> data = new ArrayList<>();
  39. try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(csvFilePath), StandardCharsets.UTF_8))) {
  40. CSVParser csvParser = new CSVParser(reader, CSVFormat.DEFAULT.withDelimiter(',').withTrim());
  41. for (CSVRecord record : csvParser) {
  42. if(record.size() == 1) {
  43. val rec = record.get(0).split(","); // Remove non-numeric characters
  44. int[] row = new int[rec.length];
  45. for (int i = 0; i < rec.length; i++) {
  46. String value = rec[i].replaceAll("[^0-9]", ""); // Remove non-numeric characters
  47. if (!value.isEmpty()) {
  48. row[i] = Integer.parseInt(value);
  49. }
  50. }
  51. data.add(row);
  52. }
  53. else {
  54. int[] row = new int[record.size()];
  55. for (int i = 0; i < record.size(); i++) {
  56. String value = record.get(i).replaceAll("[^0-9]", ""); // Remove non-numeric characters
  57. if (!value.isEmpty()) {
  58. row[i] = Integer.parseInt(value);
  59. }
  60. }
  61. data.add(row);
  62. }
  63. }
  64. }
  65. return data;
  66. }
  67. private static Instances prepareTrainingData(List<int[]> historicalData) {
  68. // Define attributes
  69. ArrayList<Attribute> attributes = new ArrayList<>();
  70. for (int i = 0; i < historicalData.get(0).length; i++) {
  71. attributes.add(new Attribute("num" + (i + 1)));
  72. }
  73. // Create dataset
  74. Instances dataset = new Instances("LotteryData", attributes, historicalData.size());
  75. dataset.setClassIndex(dataset.numAttributes() - 1);
  76. // Add data
  77. for (int[] row : historicalData) {
  78. dataset.add(new DenseInstance(1.0, toDoubleArray(row)));
  79. }
  80. return dataset;
  81. }
  82. private static double[] toDoubleArray(int[] intArray) {
  83. double[] doubleArray = new double[intArray.length];
  84. for (int i = 0; i < intArray.length; i++) {
  85. doubleArray[i] = intArray[i];
  86. }
  87. return doubleArray;
  88. }
  89. private static int[] predictNextNumbers(Classifier model, Instances trainingData) throws Exception {
  90. int numAttributes = trainingData.numAttributes();
  91. Set<Integer> predictedNumbers = new HashSet<>();
  92. while (predictedNumbers.size() < numAttributes) {
  93. DenseInstance instance = new DenseInstance(numAttributes);
  94. instance.setDataset(trainingData);
  95. for (int i = 0; i < numAttributes; i++) {
  96. instance.setValue(i, Math.random() * 33 + 1); // Random values for prediction
  97. }
  98. double prediction = model.classifyInstance(instance);
  99. int predictedNumber = (int) Math.round(prediction);
  100. // Ensure the predicted number is within the valid range and not a duplicate
  101. if (predictedNumber >= 1 && predictedNumber <= 33) {
  102. predictedNumbers.add(predictedNumber);
  103. }
  104. }
  105. int[] predictionArray = new int[numAttributes];
  106. int index = 0;
  107. for (int num : predictedNumbers) {
  108. predictionArray[index++] = num;
  109. }
  110. return predictionArray;
  111. }
  112. }
时间序列算法:
  1. package com.ruoyi.web.controller.test;
  2. import lombok.val;
  3. import org.apache.commons.csv.CSVFormat;
  4. import org.apache.commons.csv.CSVParser;
  5. import org.apache.commons.csv.CSVRecord;
  6. import org.deeplearning4j.nn.api.Model;
  7. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  8. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  9. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  10. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  11. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  12. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  13. import org.nd4j.linalg.activations.Activation;
  14. import org.nd4j.linalg.api.ndarray.INDArray;
  15. import org.nd4j.linalg.dataset.DataSet;
  16. import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
  17. import org.nd4j.linalg.factory.Nd4j;
  18. import org.nd4j.linalg.learning.config.Adam;
  19. import org.nd4j.linalg.lossfunctions.LossFunctions;
  20. import java.io.BufferedReader;
  21. import java.io.FileInputStream;
  22. import java.io.InputStreamReader;
  23. import java.nio.charset.StandardCharsets;
  24. import java.util.ArrayList;
  25. import java.util.HashSet;
  26. import java.util.List;
  27. import java.util.Set;
  28. public class LotteryPredictor3 {
  29. public static void main(String[] args) throws Exception {
  30. String csvFilePath = "D:\\12.csv"; // 请替换为你的CSV文件的绝对路径
  31. // Step 1: Read historical data from CSV
  32. List<int[]> historicalData = readCSV(csvFilePath);
  33. // Step 2: Prepare data for time series analysis
  34. double[][] timeSeriesData = prepareTimeSeriesData(historicalData);
  35. // Step 3: Train neural network model
  36. MultiLayerNetwork model = trainModel(timeSeriesData);
  37. // Step 4: Make a prediction
  38. int[] redBallPrediction = predictRedBalls(model, timeSeriesData);
  39. int blueBallPrediction = predictBlueBall(model, timeSeriesData);
  40. // Output the prediction
  41. System.out.println("Predicted numbers: ");
  42. for (int num : redBallPrediction) {
  43. System.out.print(num + " ");
  44. }
  45. System.out.println("Blue ball: " + blueBallPrediction);
  46. }
  47. private static List<int[]> readCSV(String csvFilePath) throws Exception {
  48. List<int[]> data = new ArrayList<>();
  49. try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(csvFilePath), StandardCharsets.UTF_8))) {
  50. CSVParser csvParser = new CSVParser(reader, CSVFormat.DEFAULT.withDelimiter(',').withTrim());
  51. for (CSVRecord record : csvParser) {
  52. if(record.size() == 1) {
  53. val rec = record.get(0).split(","); // Remove non-numeric characters
  54. int[] row = new int[rec.length];
  55. for (int i = 0; i < rec.length; i++) {
  56. String value = rec[i].replaceAll("[^0-9]", ""); // Remove non-numeric characters
  57. if (!value.isEmpty()) {
  58. row[i] = Integer.parseInt(value);
  59. }
  60. }
  61. data.add(row);
  62. }else {
  63. int[] row = new int[record.size()];
  64. for (int i = 0; i < record.size(); i++) {
  65. String value = record.get(i).replaceAll("[^0-9]", ""); // Remove non-numeric characters
  66. if (!value.isEmpty()) {
  67. row[i] = Integer.parseInt(value);
  68. }
  69. }
  70. data.add(row);
  71. }
  72. }
  73. }
  74. return data;
  75. }
  76. private static double[][] prepareTimeSeriesData(List<int[]> historicalData) {
  77. // Flatten the historical data into a 2D array
  78. double[][] timeSeriesData = new double[historicalData.size()][];
  79. for (int i = 0; i < historicalData.size(); i++) {
  80. timeSeriesData[i] = new double[historicalData.get(i).length];
  81. for (int j = 0; j < historicalData.get(i).length; j++) {
  82. timeSeriesData[i][j] = historicalData.get(i)[j];
  83. }
  84. }
  85. return timeSeriesData;
  86. }
  87. private static MultiLayerNetwork trainModel(double[][] timeSeriesData) {
  88. int numInputs = timeSeriesData[0].length;
  89. int numOutputs = numInputs;
  90. int numHiddenNodes = 10;
  91. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  92. .updater(new Adam(0.01))
  93. .list()
  94. .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
  95. .activation(Activation.RELU)
  96. .build())
  97. .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
  98. .activation(Activation.IDENTITY)
  99. .nIn(numHiddenNodes).nOut(numOutputs).build())
  100. .build();
  101. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  102. model.init();
  103. model.setListeners(new ScoreIterationListener(10));
  104. // Prepare the data
  105. INDArray input = Nd4j.create(timeSeriesData);
  106. INDArray output = Nd4j.create(timeSeriesData);
  107. DataSet dataSet = new DataSet(input, output);
  108. // Normalize the data
  109. NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(0, 1);
  110. scaler.fit(dataSet);
  111. scaler.transform(dataSet);
  112. // Train the model
  113. for (int i = 0; i < 2000; i++) {
  114. model.fit(dataSet);
  115. }
  116. return model;
  117. }
  118. private static int[] predictRedBalls(MultiLayerNetwork model, double[][] timeSeriesData) {
  119. INDArray input = Nd4j.create(timeSeriesData);
  120. INDArray output = model.output(input);
  121. double[] lastPrediction = output.getRow(output.rows() - 1).toDoubleVector();
  122. Set<Integer> predictedNumbers = new HashSet<>();
  123. for (double num : lastPrediction) {
  124. int scaledNum = (int) Math.round(num * 32) + 1; // Scale back to 1-33 range
  125. if (scaledNum >= 1 && scaledNum <= 33) {
  126. predictedNumbers.add(scaledNum);
  127. }
  128. if (predictedNumbers.size() == 6) {
  129. break;
  130. }
  131. }
  132. // Ensure we have exactly 6 unique numbers
  133. while (predictedNumbers.size() < 6) {
  134. int randomNum = (int) (Math.random() * 33) + 1;
  135. predictedNumbers.add(randomNum);
  136. }
  137. int[] predictionArray = new int[6];
  138. int index = 0;
  139. for (int num : predictedNumbers) {
  140. predictionArray[index++] = num;
  141. }
  142. return predictionArray;
  143. }
  144. private static int predictBlueBall(MultiLayerNetwork model, double[][] timeSeriesData) {
  145. INDArray input = Nd4j.create(timeSeriesData);
  146. INDArray output = model.output(input);
  147. double lastPrediction = output.getDouble(output.rows() - 1);
  148. // Predict blue ball number
  149. int blueBallPrediction = (int) Math.round(lastPrediction * 15) + 1; // Scale back to 1-16 range
  150. if (blueBallPrediction < 1) blueBallPrediction = 1;
  151. if (blueBallPrediction > 16) blueBallPrediction = 16;
  152. return blueBallPrediction;
  153. }
  154. }

对比了一下,时间序列的结果相对更容易让人相信;随机森林的结果就不知道该怎么评价了,大家可以自己试试。需要提醒的是:双色球每期开奖相互独立,历史号码在统计上并不能提高下一期的中奖概率,以上代码只适合用来学习Weka和Deeplearning4j的用法,请勿当作投注依据。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/758507
推荐阅读
相关标签
  

闽ICP备14008679号