当前位置:   article > 正文

数学建模——LSTM时间序列预测的Matlab实现_数学建模lstm

数学建模lstm

  1. data = kw;
  2. %导入数据,数据必须是一维向量
  3. figure
  4. plot(data)
  5. xlabel("15min Period")
  6. ylabel("Kw")
  7. title("15-minute load")
  8. numTimeStepsTrain = floor(0.9*numel(data));
  9. %划分训练集与测试集,这里选了0.9
  10. %确定训练次数
  11. dataTrain = data(1:numTimeStepsTrain+1);
  12. dataTest = data(numTimeStepsTrain+1:end);
  13. %创建两个数据集 dataTrain和dataTest
  14. mu = mean(dataTrain);
  15. sig = std(dataTrain);
  16. dataTrainStandardized = (dataTrain - mu) / sig;
  17. XTrain = dataTrainStandardized(1:end-1);
  18. YTrain = dataTrainStandardized(2:end);
  19. numFeatures = 1;
  20. numResponses = 1;
  21. numHiddenUnits = 200;
  22. layers = [ ...
  23. sequenceInputLayer(numFeatures)
  24. lstmLayer(numHiddenUnits)
  25. fullyConnectedLayer(numResponses)
  26. regressionLayer];
  27. %接下来设置求解的各项参数,指定训练选项。将求解器设置为 'adam' 并进行 250 轮训练。要防止梯度爆炸,请将梯度阈值设置为 1。指定初始学习率 0.005,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
  28. options = trainingOptions('adam', ...
  29. 'MaxEpochs',800, ...
  30. 'GradientThreshold',1, ...
  31. 'InitialLearnRate',0.005, ...
  32. 'LearnRateSchedule','piecewise', ...
  33. 'LearnRateDropPeriod',400, ...
  34. 'LearnRateDropFactor',0.2, ...
  35. 'Verbose',0, ...
  36. 'Plots','training-progress');
  37. net = trainNetwork(XTrain,YTrain,layers,options);
  38. dataTestStandardized = (dataTest - mu) / sig;
  39. XTest = dataTestStandardized(1:end-1);
  40. net = predictAndUpdateState(net,XTrain);
  41. [net,YPred] = predictAndUpdateState(net,YTrain(end));
  42. numTimeStepsTest = numel(XTest);
  43. for i = 2:numTimeStepsTest
  44. [net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');
  45. end
  46. YPred = sig*YPred + mu;
  47. YTest = dataTest(2:end);
  48. rmse = sqrt(mean((YPred-YTest).^2))
  49. figure
  50. plot(dataTrain(1:end-1))
  51. hold on
  52. idx = numTimeStepsTrain:(numTimeStepsTrain+numTimeStepsTest);
  53. plot(idx,[data(numTimeStepsTrain) YPred],'.-')
  54. hold off
  55. xlabel("Periods of 15-mins")
  56. ylabel("Kw")
  57. title("Forecast")
  58. legend(["Observed" "Forecast"])
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/97227
推荐阅读
相关标签
  

闽ICP备14008679号