当前位置:   article > 正文

matlab深度学习之LSTM预测_lstm 随机数预测

lstm 随机数预测

 matlab深度学习之LSTM

利用历史序列进行预测

  1. clc
  2. clear
  3. %% 加载示例数据。
  4. %chickenpox_dataset 包含一个时序,其时间步对应于月份,值对应于病例数。
  5. %输出是一个元胞数组,其中每个元素均为单一时间步。将数据重构为行向量。
  6. data = chickenpox_dataset;
  7. data = [data{:}];
  8. figure
  9. plot(data)
  10. xlabel("Month")
  11. ylabel("Cases")
  12. title("Monthly Cases of Chickenpox")
  13. %% 对训练数据和测试数据进行分区。
  14. %序列的前 90% 用于训练,后 10% 用于测试。
  15. numTimeStepsTrain = floor(0.9*numel(data));
  16. dataTrain = data(1:numTimeStepsTrain+1);
  17. dataTest = data(numTimeStepsTrain+1:end);
  18. %% 标准化数据
  19. %为了获得较好的拟合并防止训练发散,将训练数据标准化为具有零均值和单位方差。
  20. %在预测时,您必须使用与训练数据相同的参数来标准化测试数据。
  21. mu = mean(dataTrain);
  22. sig = std(dataTrain);
  23. dataTrainStandardized = (dataTrain - mu) / sig;
  24. %% 准备预测变量和响应
  25. %要预测序列在将来时间步的值,请将响应指定为将值移位了一个时间步的训练序列。
  26. %也就是说,在输入序列的每个时间步,LSTM 网络都学习预测下一个时间步的值。
  27. %预测变量是没有最终时间步的训练序列。
  28. XTrain = dataTrainStandardized(1:end-1);
  29. YTrain = dataTrainStandardized(2:end);
  30. %% 定义 LSTM 网络架构
  31. %创建 LSTM 回归网络。指定 LSTM 层有 200 个隐含单元
  32. numFeatures = 1;
  33. numResponses = 1;
  34. numHiddenUnits = 200;
  35. layers = [ ...
  36. sequenceInputLayer(numFeatures)
  37. lstmLayer(numHiddenUnits)
  38. fullyConnectedLayer(numResponses)
  39. regressionLayer];
  40. %指定训练选项。
  41. %将求解器设置为 'adam' 并进行 250 轮训练。
  42. %要防止梯度爆炸,请将梯度阈值设置为 1
  43. %指定初始学习率 0.005,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
  44. options = trainingOptions('adam', ...
  45. 'MaxEpochs',250, ...
  46. 'GradientThreshold',1, ...
  47. 'InitialLearnRate',0.005, ...
  48. 'LearnRateSchedule','piecewise', ...
  49. 'LearnRateDropPeriod',125, ...
  50. 'LearnRateDropFactor',0.2, ...
  51. 'Verbose',0, ...
  52. 'Plots','training-progress');
  53. %% 训练 LSTM 网络
  54. %使用 trainNetwork 以指定的训练选项训练 LSTM 网络。
  55. net = trainNetwork(XTrain,YTrain,layers,options);
  56. %% 预测将来时间步
  57. %要预测将来多个时间步的值,请使用 predictAndUpdateState 函数一次预测一个时间步,并在每次预测时更新网络状态。对于每次预测,使用前一次预测作为函数的输入。
  58. %使用与训练数据相同的参数来标准化测试数据。
  59. dataTestStandardized = (dataTest - mu) / sig;
  60. XTest = dataTestStandardized(1:end-1);
  61. %要初始化网络状态,请先对训练数据 XTrain 进行预测。
  62. %接下来,使用训练响应的最后一个时间步 YTrain(end) 进行第一次预测。
  63. %循环其余预测并将前一次预测输入到 predictAndUpdateState。
  64. %对于大型数据集合、长序列或大型网络,在 GPU 上进行预测计算通常比在 CPU 上快。
  65. %其他情况下,在 CPU 上进行预测计算通常更快。
  66. %对于单时间步预测,请使用 CPU。
  67. %使用 CPU 进行预测,请将 predictAndUpdateState 的 'ExecutionEnvironment' 选项设置为 'cpu'
  68. net = predictAndUpdateState(net,XTrain);
  69. [net,YPred] = predictAndUpdateState(net,YTrain(end));
  70. numTimeStepsTest = numel(XTest);
  71. for i = 2:numTimeStepsTest
  72. [net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');
  73. end
  74. %使用先前计算的参数对预测去标准化。
  75. YPred = sig*YPred + mu;
  76. %训练进度图会报告根据标准化数据计算出的均方根误差 (RMSE)。根据去标准化的预测值计算 RMSE。
  77. YTest = dataTest(2:end);
  78. rmse = sqrt(mean((YPred-YTest).^2))
  79. %使用预测值绘制训练时序。
  80. figure
  81. plot(dataTrain(1:end-1))
  82. hold on
  83. idx = numTimeStepsTrain:(numTimeStepsTrain+numTimeStepsTest);
  84. plot(idx,[data(numTimeStepsTrain) YPred],'.-')
  85. hold off
  86. xlabel("Month")
  87. ylabel("Cases")
  88. title("Forecast")
  89. legend(["Observed" "Forecast"])
  90. %将预测值与测试数据进行比较。
  91. figure
  92. subplot(2,1,1)
  93. plot(YTest)
  94. hold on
  95. plot(YPred,'.-')
  96. hold off
  97. legend(["Observed" "Forecast"])
  98. ylabel("Cases")
  99. title("Forecast")
  100. subplot(2,1,2)
  101. stem(YPred - YTest)
  102. xlabel("Month")
  103. ylabel("Error")
  104. title("RMSE = " + rmse)
  105. %% 使用观测值更新网络状态
  106. %如果您可以访问预测之间的时间步的实际值,则可以使用观测值而不是预测值更新网络状态。
  107. %首先,初始化网络状态。要对新序列进行预测,请使用 resetState 重置网络状态。
  108. %重置网络状态可防止先前的预测影响对新数据的预测。重
  109. %置网络状态,然后通过对训练数据进行预测来初始化网络状态。
  110. net = resetState(net);
  111. net = predictAndUpdateState(net,XTrain);
  112. %对每个时间步进行预测。对于每次预测,使用前一时间步的观测值预测下一个时间步。
  113. %将 predictAndUpdateState 的 'ExecutionEnvironment' 选项设置为 'cpu'
  114. YPred = [];
  115. numTimeStepsTest = numel(XTest);
  116. for i = 1:numTimeStepsTest
  117. [net,YPred(:,i)] = predictAndUpdateState(net,XTest(:,i),'ExecutionEnvironment','cpu');
  118. end
  119. %使用先前计算的参数对预测去标准化。
  120. YPred = sig*YPred + mu;
  121. %计算均方根误差 (RMSE)。
  122. rmse = sqrt(mean((YPred-YTest).^2))
  123. %将预测值与测试数据进行比较。
  124. figure
  125. subplot(2,1,1)
  126. plot(YTest)
  127. hold on
  128. plot(YPred,'.-')
  129. hold off
  130. legend(["Observed" "Predicted"])
  131. ylabel("Cases")
  132. title("Forecast with Updates")
  133. subplot(2,1,2)
  134. stem(YPred - YTest)
  135. xlabel("Month")
  136. ylabel("Error")
  137. title("RMSE = " + rmse)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/518759
推荐阅读
相关标签
  

闽ICP备14008679号