当前位置:   article > 正文

matlab2019a中LSTM网络使用方法及源码示例(Deep Learning Toolbox系列篇6)_matlab中lstmnet

matlab中lstmnet

此示例说明如何使用长短期记忆 (LSTM) 网络对序列数据进行分类。

要训练深度神经网络以对序列数据进行分类,可以使用 LSTM 网络。LSTM 网络允许您将序列数据输入网络,并根据序列数据的各个时间步进行预测。

此示例使用 [1] 和 [2] 中所述的日语元音数据集。此示例训练一个 LSTM 网络,旨在根据表示连续说出的两个日语元音的时序数据来识别说话者。训练数据包含九个说话者的时序数据。每个序列有 12 个特征,且长度不同。该数据集包含 270 个训练观测值和 370 个测试观测值。

源码:

  1. %% 通用matlab脚本三连
  2. clear
  3. clc
  4. close all
  5. %% 加载序列数据
  6. % 加载日语元音训练数据。XTrain 是包含 270 个不同长度的 12 维序列的元胞数组。
  7. % Y 是对应于九个说话者的标签 "1"、"2"、...、"9" 的分类向量。
  8. % XTrain 中的条目是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。
  9. [XTrain,YTrain] = japaneseVowelsTrainData;
  10. XTrain(1:5)
  11. %% 在绘图中可视化第一个时序。每行对应一个特征。
  12. figure
  13. plot(XTrain{1}')
  14. xlabel("Time Step")
  15. title("Training Observation 1")
  16. legend("Feature " + string(1:12),'Location','northeastoutside')
  17. %% 准备要填充的数据
  18. numObservations = numel(XTrain);
  19. for i=1:numObservations
  20. sequence = XTrain{i};
  21. sequenceLengths(i) = size(sequence,2);
  22. end
  23. %% 按序列长度对数据进行排序。
  24. [sequenceLengths,idx] = sort(sequenceLengths);
  25. XTrain = XTrain(idx);
  26. YTrain = YTrain(idx);
  27. %% 在条形图中查看排序的序列长度。
  28. figure
  29. bar(sequenceLengths)
  30. ylim([0 30])
  31. xlabel("Sequence")
  32. ylabel("Length")
  33. title("Sorted Data")
  34. %% 选择小批量大小 27 以均匀划分训练数据,并减少小批量中的填充量。下图说明了添加到序列中的填充。
  35. miniBatchSize = 27;
  36. %% 定义 LSTM 网络架构
  37. inputSize = 12;
  38. numHiddenUnits = 100;
  39. numClasses = 9;
  40. layers = [ ...
  41. sequenceInputLayer(inputSize)
  42. bilstmLayer(numHiddenUnits,'OutputMode','last')
  43. fullyConnectedLayer(numClasses)
  44. softmaxLayer
  45. classificationLayer]
  46. maxEpochs = 100;
  47. miniBatchSize = 27;
  48. %% 指定训练选项
  49. options = trainingOptions('adam', ...
  50. 'ExecutionEnvironment','cpu', ...
  51. 'GradientThreshold',1, ...
  52. 'MaxEpochs',maxEpochs, ...
  53. 'MiniBatchSize',miniBatchSize, ...
  54. 'SequenceLength','longest', ...
  55. 'Shuffle','never', ...
  56. 'Verbose',0, ...
  57. 'Plots','training-progress');
  58. %% 训练 LSTM 网络
  59. net = trainNetwork(XTrain,YTrain,layers,options);
  60. %% 测试 LSTM 网络
  61. [XTest,YTest] = japaneseVowelsTestData;
  62. XTest(1:3)
  63. %% LSTM 网络 net 已使用相似长度的小批量序列进行训练
  64. numObservationsTest = numel(XTest);
  65. for i=1:numObservationsTest
  66. sequence = XTest{i};
  67. sequenceLengthsTest(i) = size(sequence,2);
  68. end
  69. [sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
  70. XTest = XTest(idx);
  71. YTest = YTest(idx);
  72. %% 对测试数据进行分类
  73. miniBatchSize = 27;
  74. YPred = classify(net,XTest, ...
  75. 'MiniBatchSize',miniBatchSize, ...
  76. 'SequenceLength','longest');
  77. %% 计算预测值的分类准确度。
  78. acc = sum(YPred == YTest)./numel(YTest)

运行结果:

本文着重讲解一下matlab的深度学习训练方式

在matlab中,训练一个网络有一个常用的函数trainNetwork;此函数通常有四个参数

  1. %% 训练 LSTM 网络
  2. net = trainNetwork(XTrain,YTrain,layers,options);

 其中,XTrain为训练的输入数据集,YTrain为训练数据对应的标签;

layers为网络的架构,也就是指网络每层的处理模式。

options为超参数的设置,包括学习率,优化方法,迭代次数以及批量大小问题等。

1. 构建网络模型;

  1. inputSize = 12;
  2. numHiddenUnits = 100;
  3. numClasses = 9;
  4. layers = [ ...
  5. sequenceInputLayer(inputSize) %输入层
  6. bilstmLayer(numHiddenUnits,'OutputMode','last') %第一层隐层,LSTM架构
  7. fullyConnectedLayer(numClasses) % 第二层隐层,全连接层
  8. softmaxLayer %softmax处理
  9. classificationLayer] %分类层

 如上代码片段所示,逐行对每一层的网络处理模式进行说明,并将此构成的数组形式赋予一个变量(此中为layers)。

2. 指定训练的超参数

  1. %% 指定训练选项
  2. options = trainingOptions('adam', ...
  3. 'ExecutionEnvironment','cpu', ...
  4. 'GradientThreshold',1, ...
  5. 'MaxEpochs',maxEpochs, ...
  6. 'MiniBatchSize',miniBatchSize, ...
  7. 'SequenceLength','longest', ...
  8. 'Shuffle','never', ...
  9. 'Verbose',0, ...
  10. 'Plots','training-progress');

其中参数设定的格式必须由trainingOptions进行格式打包,即options的数据类型必须为TrainingOptions派,此示例中options的数据类型为TrainingOptionsADAM。另外本示例数据集方面的XTrain与XTest都是matlab中的cell数据类型。

 更详细的matlab深度学习训练方法(trainNetwork用法)请参考matlab官方文档或本博客Deep Learning ToolBox系列7。

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/112540
推荐阅读
相关标签
  

闽ICP备14008679号