赞
踩
利用 TensorFlow.js 搭建并训练神经网络模型，完成手写数字识别
简单三层神经网络
// Train a small fully connected network on MNIST with TensorFlow.js in Node,
// plot the per-epoch loss/accuracy, evaluate on a held-out test set, and save
// the trained model under ./my-model.
import tf from "@tensorflow/tfjs";
// tfjs-node supplies the native backend and the file:// handler that
// model.save() relies on below; the binding itself is not referenced directly.
import * as tfjsnode from "@tensorflow/tfjs-node";
// NOTE(review): tfvis and fs are not used anywhere in this script.
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
import pkg from "mnist";

const { set: Dataset } = pkg;

// Number of MNIST samples requested for training and for testing.
const trainDataLen = 3000;
const testDataLen = 2000;
// Length of each flattened input image (28 x 28 pixels).
const inputSize = 784;
// One output unit per digit 0-9.
const numClasses = 10;

// ---- Model -----------------------------------------------------------------
const model = tf.sequential();
// First hidden layer. `inputShape` declares the 784-wide input, so no separate
// input layer object is needed.
model.add(
  tf.layers.dense({ units: 64, inputShape: [inputSize], activation: "relu" }),
);
// Second hidden layer.
model.add(tf.layers.dense({ units: 100, activation: "relu" }));
// Output layer: softmax over the digit classes.
model.add(tf.layers.dense({ units: numClasses, activation: "softmax" }));

model.compile({
  optimizer: "sgd",
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});

/**
 * Split `mnist` samples into flat pixel arrays and integer class labels.
 *
 * @param {Array<{input: number[], output: number[]}>} samples - Samples from
 *   the `mnist` package. `output` is assumed to be one-hot; the label is the
 *   index of its 1.
 * @returns {{xs: number[][], labels: number[]}} Pixel rows and their labels.
 */
function toXsAndLabels(samples) {
  return {
    xs: samples.map((sample) => sample.input),
    labels: samples.map((sample) => sample.output.indexOf(1)),
  };
}

/**
 * Build a [rows, inputSize] input tensor. The row count is taken from the data
 * rather than the requested sample count, so a dataset that returns a
 * different number of samples does not cause a shape-mismatch error.
 *
 * @param {number[][]} xs - Flattened images, one per row.
 * @returns {tf.Tensor2D} Input tensor for the model.
 */
function toInputTensor(xs) {
  return tf.tensor2d(xs, [xs.length, inputSize]);
}

/**
 * Load the data, train the model, plot the training curves, report test-set
 * loss/accuracy and save the model to ./my-model.
 */
async function main() {
  const { training: trainingSet, test: testSet } = Dataset(
    trainDataLen,
    testDataLen,
  );
  const train = toXsAndLabels(trainingSet);
  const test = toXsAndLabels(testSet);

  const trainXsTensor = toInputTensor(train.xs);
  const trainYsOneHot = tf.oneHot(train.labels, numClasses);

  // Per-epoch metrics for the training-curve plot, stored as numbers.
  const accPlot = [];
  const lossPlot = [];

  const history = await model.fit(trainXsTensor, trainYsOneHot, {
    batchSize: 64,
    epochs: 100,
    validationSplit: 0.2,
    callbacks: {
      onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),
      onEpochEnd: async (epoch, logs) => {
        console.log(
          `Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(3)}, Accuracy: ${logs.acc.toFixed(3)}`,
        );
        accPlot.push(logs.acc);
        lossPlot.push(logs.loss);
        await tf.nextFrame(); // yield so training does not block
      },
      onBatchEnd: async (batch, logs) => {
        console.log(
          `Batch ${batch} completed. Loss: ${logs.loss.toFixed(3)}, Accuracy: ${logs.acc.toFixed(3)}`,
        );
        await tf.nextFrame(); // yield so training does not block
      },
    },
  });
  console.log("Training completed!", history);
  tf.dispose([trainXsTensor, trainYsOneHot]);

  // Plot loss and accuracy against the 1-based epoch number.
  const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);
  // NOTE(review): `filename` is passed in the layout argument of
  // nodeplotlib's plot(); confirm it actually writes loss_acc.png.
  plot.plot(
    [
      { x: epochs, y: lossPlot, name: "Loss" },
      { x: epochs, y: accPlot, name: "Accuracy" },
    ],
    {
      filename: "loss_acc.png",
    },
  );

  // Evaluate on the held-out test set. With a metric compiled in, evaluate()
  // yields [loss, accuracy] scalars.
  const testXsTensor = toInputTensor(test.xs);
  const testYsOneHot = tf.oneHot(test.labels, numClasses);
  const [lossScalar, accScalar] = model.evaluate(testXsTensor, testYsOneHot);
  const testLoss = lossScalar.dataSync()[0];
  const testAccuracy = accScalar.dataSync()[0];
  tf.dispose([testXsTensor, testYsOneHot, lossScalar, accScalar]);
  console.log(`Test loss: ${testLoss.toFixed(3)}`);
  console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);

  await model.save("file://./my-model");
  console.log("Model saved!");
}

// Surface failures (data loading, training, saving) instead of leaving an
// unhandled rejection, and exit non-zero.
main().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
{ "name": "neural_network", "version": "1.0.0", "description": "", "type": "module", "main": "mlpTest.js", "scripts": { "test": "echo \"Error: no test specified\" && exit 1" }, "author": "", "license": "ISC", "dependencies": { "@tensorflow/tfjs": "^4.17.0", "@tensorflow/tfjs-node": "^4.17.0", "@tensorflow/tfjs-vis": "^1.0.0", "mnist": "^1.1.0", "nodeplotlib": "^0.7.7" }, "devDependencies": { "@babel/core": "^7.0.0", "@babel/preset-env": "^7.0.0", "babel-loader": "^8.0.0", "webpack": "^5.0.0", "webpack-cli": "^4.0.0" } }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。