当前位置:   article > 正文

将【深度学习】和【Spring Boot】集成:使用 DL4J 的综合指南

将【深度学习】和【Spring Boot】集成:使用 DL4J 的综合指南

1. 什么是DeepLearning4j?

DeepLearning4J (DL4J) 是一个基于 Java 的神经网络工具包,用于构建、训练和部署神经网络。DL4J 与 Hadoop 和Spark集成,支持分布式 CPU 和 GPU,专为商业环境而设计,而非研究工具用途。Skymind是 DL4J 的商业支持组织。Deeplearning4j 拥有先进的技术,旨在实现即插即用,有更多预设可供使用,避免冗余配置,即使非企业也可以快速进行原型设计。DL4J 还可以进行大规模定制。DL4J 在 Apache 2.0 许可下获得许可,所有基于它的衍生作品均为衍生作品。

2. Deeplearning4j 的功能

Deeplearning4j 包括分布式、多线程深度学习框架,以及常见的单线程深度学习框架。训练过程在集群中进行,这意味着 Deeplearning4j 可以快速处理大量数据。神经网络可以通过 [迭代简化] 并行训练,并且可以与 Java、Scala和Clojure并行使用,全部兼容。Deeplearning4j 能够作为开放堆栈中的模块组件,使其成为同类中第一个面向微服务架构的深度学习框架。

3. 场景设想

 示例:使用 Spring Boot、Java 和 DL4J 的贷款审批推荐系统

您想要在“贷款审批”应用程序中构建一个微服务,根据历史数据建议是否批准或拒绝贷款申请。该建议基于使用 DL4J 训练的机器学习模型。

4. 实施步骤

  1. 数据准备:收集和预处理历史贷款申请数据,包括信用评分、收入、贷款金额、就业状况和贷款违约历史等特征。
  2. 模型训练:使用 DL4J 对这些数据训练神经网络模型,将贷款申请分类为“批准”或“拒绝”。
  3. 集成到 Spring Boot:将训练好的模型作为 REST API 公开在 Spring Boot 应用程序中,以提供实时贷款审批建议。

5. 逐步实施细节

5.1 数据准备

创建一个 CSV 文件 (loan_data.csv),其中包含 credit_score、income、loan_amount、employment_status 和 label 等列(其中 label 为 1 表示贷款已获批准,为 0 表示贷款已拒绝)。

  1. csv file
  2. credit_score,income,loan_amount,employment_status,label
  3. 700,50000,20000,1,1
  4. 650,45000,15000,1,1
  5. 600,30000,25000,0,0
  6. 720,60000,22000,1,1
  7. 580,29000,18000,0,0

5.2 项目pom.xml设置

  • 首先创建一个新的 Spring Boot 项目。
  • 将 DL4J 和 ND4J 依赖项添加到项目的构建配置中(例如,在pom.xmlMaven 或build.gradleGradle 中):
  1. <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  2. xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  3. <modelVersion>4.0.0</modelVersion>
  4. <groupId>com.example</groupId>
  5. <artifactId>loan-approval</artifactId>
  6. <version>0.0.1-SNAPSHOT</version>
  7. <packaging>jar</packaging>
  8. <name>loan-approval</name>
  9. <description>Loan Approval Recommendation System using DL4J</description>
  10. <parent>
  11. <groupId>org.springframework.boot</groupId>
  12. <artifactId>spring-boot-starter-parent</artifactId>
  13. <version>2.7.1</version>
  14. <relativePath/> <!-- lookup parent from repository -->
  15. </parent>
  16. <properties>
  17. <java.version>11</java.version>
  18. <dl4j.version>1.0.0-M1.1</dl4j.version>
  19. </properties>
  20. <dependencies>
  21. <!-- Spring Boot Starter Web -->
  22. <dependency>
  23. <groupId>org.springframework.boot</groupId>
  24. <artifactId>spring-boot-starter-web</artifactId>
  25. </dependency>
  26. <!-- Deeplearning4j Dependencies -->
  27. <dependency>
  28. <groupId>org.deeplearning4j</groupId>
  29. <artifactId>deeplearning4j-core</artifactId>
  30. <version>${dl4j.version}</version>
  31. </dependency>
  32. <dependency>
  33. <groupId>org.deeplearning4j</groupId>
  34. <artifactId>deeplearning4j-nn</artifactId>
  35. <version>${dl4j.version}</version>
  36. </dependency>
  37. <dependency>
  38. <groupId>org.nd4j</groupId>
  39. <artifactId>nd4j-native-platform</artifactId>
  40. <version>${dl4j.version}</version>
  41. </dependency>
  42. <!-- DataVec (for CSV reading) -->
  43. <dependency>
  44. <groupId>org.datavec</groupId>
  45. <artifactId>datavec-api</artifactId>
  46. <version>${dl4j.version}</version>
  47. </dependency>
  48. <dependency>
  49. <groupId>org.datavec</groupId>
  50. <artifactId>datavec-local</artifactId>
  51. <version>${dl4j.version}</version>
  52. </dependency>
  53. <!-- Lombok (Optional, for reducing boilerplate code) -->
  54. <dependency>
  55. <groupId>org.projectlombok</groupId>
  56. <artifactId>lombok</artifactId>
  57. <scope>provided</scope>
  58. </dependency>
  59. <!-- Spring Boot DevTools (Optional, for development convenience) -->
  60. <dependency>
  61. <groupId>org.springframework.boot</groupId>
  62. <artifactId>spring-boot-devtools</artifactId>
  63. <scope>runtime</scope>
  64. <optional>true</optional>
  65. </dependency>
  66. <!-- Spring Boot Test (Optional, for unit tests) -->
  67. <dependency>
  68. <groupId>org.springframework.boot</groupId>
  69. <artifactId>spring-boot-starter-test</artifactId>
  70. <scope>test</scope>
  71. </dependency>
  72. </dependencies>
  73. <build>
  74. <plugins>
  75. <plugin>
  76. <groupId>org.springframework.boot</groupId>
  77. <artifactId>spring-boot-maven-plugin</artifactId>
  78. </plugin>
  79. </plugins>
  80. </build>
  81. </project>

5.3 使用 DL4J 进行模型训练

使用 Java 中的 DL4J 创建一个简单的神经网络模型。

  1. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  2. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  3. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  4. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  5. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  6. import org.deeplearning4j.nn.weights.WeightInit;
  7. import org.deeplearning4j.optimize.api.IterationListener;
  8. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  9. import org.nd4j.linalg.activations.Activation;
  10. import org.nd4j.linalg.dataset.DataSet;
  11. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  12. import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
  13. import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
  14. import org.nd4j.linalg.factory.Nd4j;
  15. import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
  16. import org.nd4j.linalg.dataset.api.iterator.RecordReaderDataSetIterator;
  17. import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
  18. import org.datavec.api.split.FileSplit;
  19. import org.datavec.api.util.ClassPathResource;
  20. public class LoanApprovalModel {
  21. public static void main(String[] args) throws Exception {
  22. // Load dataset
  23. int numLinesToSkip = 0;
  24. char delimiter = ',';
  25. CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
  26. recordReader.initialize(new FileSplit(new ClassPathResource("loan_data.csv").getFile()));
  27. int labelIndex = 4; // Index of the label (approve/reject)
  28. int numClasses = 2; // Approve or Reject
  29. int batchSize = 5;
  30. DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
  31. // Normalize the data
  32. DataNormalization normalizer = new NormalizerStandardize();
  33. normalizer.fit(iterator);
  34. iterator.setPreProcessor(normalizer);
  35. // Define the network configuration
  36. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  37. .iterations(1000)
  38. .activation(Activation.RELU)
  39. .weightInit(WeightInit.XAVIER)
  40. .learningRate(0.01)
  41. .list()
  42. .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
  43. .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
  44. .activation(Activation.SOFTMAX)
  45. .nIn(3).nOut(2).build())
  46. .backprop(true).pretrain(false).build();
  47. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  48. model.init();
  49. model.setListeners(new ScoreIterationListener(100));
  50. // 训练模型
  51. for (int i = 0; i < 1000; i++) {
  52. iterator.reset();
  53. model.fit(iterator);
  54. }
  55. // 保存模型
  56. model.save(new File("loan_approval_model.zip"), true);
  57. }
  58. }

5.4 与 Spring Boot 集成

创建一个 Spring Boot REST API,加载经过训练的模型并使用它来对新的贷款申请进行预测。

  1. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  2. import org.nd4j.linalg.factory.Nd4j;
  3. import org.nd4j.linalg.api.ndarray.INDArray;
  4. import org.springframework.web.bind.annotation.*;
  5. import java.io.File;
  6. import java.io.IOException;
  7. @RestController
  8. @RequestMapping("/loan")
  9. public class LoanApprovalController {
  10. private MultiLayerNetwork model;
  11. public LoanApprovalController() throws IOException {
  12. // Load the trained model
  13. model = MultiLayerNetwork.load(new File("loan_approval_model.zip"), true);
  14. }
  15. @PostMapping("/approve")
  16. public String approveLoan(@RequestBody LoanApplication loanApplication) {
  17. // Prepare input data
  18. INDArray input = Nd4j.create(new double[]{
  19. loanApplication.getCreditScore(),
  20. loanApplication.getIncome(),
  21. loanApplication.getLoanAmount(),
  22. loanApplication.getEmploymentStatus()
  23. }, 1, 4);
  24. // Make prediction
  25. INDArray output = model.output(input);
  26. int prediction = Nd4j.argMax(output, 1).getInt(0);
  27. return prediction == 1 ? "Approved" : "Rejected";
  28. }
  29. }
  30. class LoanApplication {
  31. private double creditScore;
  32. private double income;
  33. private double loanAmount;
  34. private int employmentStatus;
  35. // Getters and setters
  36. }

5.5 运行应用程序

构建并运行您的 Spring Boot 应用程序。
向上面的web服务地址 /loan/approve 发送一个 POST 请求,其中包含代表贷款申请的 JSON 主体:

  1. {
  2. "creditScore": 710,
  3. "income": 55000,
  4. "loanAmount": 20000,
  5. "employmentStatus": 1
  6. }

API 将根据模型的预测返回“已批准”或“已拒绝”。

6. 结论

“使用 Spring Boot 和 DL4J 为‘贷款审批’应用程序开发了贷款审批推荐系统。实施了一个神经网络模型,根据申请人数据预测贷款审批,大大增强了决策过程并降低了违约风险。”


此示例演示了您使用现代 springboot框架将机器学习集成到生产环境的能力。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/1018748
推荐阅读
相关标签
  

闽ICP备14008679号