当前位置:   article > 正文

LSTM时间序列回归matlab实现(附代码+数据集)_lstm matlab

lstm matlab

原理部分

  LSTM在1997年被提出,从发表时间上来看已经是个"老"方法了。和其他的神经网络一样,LSTM可用于分类、回归以及时间序列预测等。原理部分的介绍可参考这篇博客。本文主要涉及利用matlab实现LSTM。
在这里插入图片描述

代码部分

  任务:以青霉素发酵过程仿真数据为例,利用LSTM建模预测质量变量。
  青霉素发酵过程仿真过程简介:共有18个过程变量,其中15个可测变量,剩余3个一般作为质量变量。共生成30个批次数据,每批次运行时长为400小时,采样时间为1小时,其中25批次用于训练,5批次用于测试。
  本文所用数据下载,基于matlab深度学习工具箱实现青霉素浓度的预测。

数据标准化

XTrain_mu = mean([XTrain{:}],2);
XTrain_sig = std([XTrain{:}],0,2);
XTest_mu = mean([XTest{:}],2);
XTest_sig = std([XTest{:}],0,2);
YTrain_mu = mean([YTrain{:}],2);
YTrain_sig = std([YTrain{:}],0,2);
YTest_mu = mean([YTest{:}],2);
YTest_sig = std([YTest{:}],0,2);

for i = 1:numel(XTrain)
    XTrain{i} = (XTrain{i} - XTrain_mu) ./ XTrain_sig ;
    YTrain{i}=(YTrain{i} - YTrain_mu) ./ YTrain_sig;
end

for i = 1:numel(XTest)
    XTest{i}=(XTest{i} - XTest_mu) ./ XTest_sig;
    YTest{i}=(YTest{i} - YTest_mu) ./ YTest_sig;
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

定义网络结构

numResponses = size(YTrain{1},1);
numHiddenUnits = 200;
numFeatures=15;%变量个数
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(50)
    dropoutLayer(0.5)
    fullyConnectedLayer(numResponses)
    regressionLayer];
maxEpochs = 90;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

设定超参数

options = trainingOptions('adam', ...
    'MaxEpochs',maxEpochs, ...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress',...
    'Verbose',0);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

模型训练

net = trainNetwork(XTrain,YTrain,layers,options);
  • 1

回归预测

YPred = predict(net,XTest);
  • 1

输出可视化

idx = randperm(numel(YPred),4);
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    plot(YTest{idx(i)},'--')
    hold on
    plot(YPred{idx(i)},'.-')
    hold off
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("青霉素浓度")
    rmse = sqrt(mean((YPred{i} - YTest{i}).^2))
end
legend(["True" "Predicted"],'Location','southeast')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

结果

训练过程:请添加图片描述

回归预测:

在这里插入图片描述

整体代码

XTrain_mu = mean([XTrain{:}],2);
XTrain_sig = std([XTrain{:}],0,2);
XTest_mu = mean([XTest{:}],2);
XTest_sig = std([XTest{:}],0,2);
YTrain_mu = mean([YTrain{:}],2);
YTrain_sig = std([YTrain{:}],0,2);
YTest_mu = mean([YTest{:}],2);
YTest_sig = std([YTest{:}],0,2);

for i = 1:numel(XTrain)
    XTrain{i} = (XTrain{i} - XTrain_mu) ./ XTrain_sig ;
    YTrain{i}=(YTrain{i} - YTrain_mu) ./ YTrain_sig;
end

for i = 1:numel(XTest)
    XTest{i}=(XTest{i} - XTest_mu) ./ XTest_sig;
    YTest{i}=(YTest{i} - YTest_mu) ./ YTest_sig;
end
numResponses = size(YTrain{1},1);
numHiddenUnits = 200;
numFeatures=15;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(50)
    dropoutLayer(0.5)
    fullyConnectedLayer(numResponses)
    regressionLayer];
maxEpochs = 90;
options = trainingOptions('adam', ...
    'MaxEpochs',maxEpochs, ...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress',...
    'Verbose',0);
net = trainNetwork(XTrain,YTrain,layers,options);
YPred = predict(net,XTest);
idx = randperm(numel(YPred),4);
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    plot(YTest{idx(i)},'--')
    hold on
    plot(YPred{idx(i)},'.-')
    hold off
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("青霉素浓度")
    rmse = sqrt(mean((YPred{i} - YTest{i}).^2))
end
legend(["True" "Predicted"],'Location','southeast')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52

备注:市面上主流的网络都可以使用matlab的深度学习工具箱自行搭建,避免复杂的环境配置,如果不搞算法研究的话还是很好用的,强烈推荐。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/112520
推荐阅读
相关标签
  

闽ICP备14008679号