当前位置:   article > 正文

Maltab基于长短期记忆神经网络(LSTM)的多输入多输出分类任务实现——附代码_lstm多输入多输出

lstm多输入多输出

目录

简介:

基本结构:

LSTM网络的训练方法:

使用LSTM的多输入特征分类

加载数据并预处理:

定义LSTM网络:

训练LSTM网络:

测试LSTM网络:

本文Matlab代码分享


简介:

此示例说明如何使用长短期记忆(LSTM)网络对序列数据进行分类。要训练深度神经网络以对序列数据进行分类,可以使用LSTM网络。LSTM网络允许您将序列数据输入网络,并根据序列数据的各个时间步进行预测。此示例使用Matlab自带的数据集。此示例训练一个LSTM网络,旨在根据表示连续说出的两个日语元音的时间序列数据来识别说话者(多特征输入的分类任务)。训练数据包含九个说话者的时间序列数据。每个序列有12个特征,且长度不同。该数据集包含270个训练观测值和370个测试观测值。本代码基于Matlab编写,注释详细,可改性强,适合初学者学习。程序已标准化处理,方便使用者替换数据实现不同的功能。

其余关于CNNLSTM或其他神经网络执行不同任务的文章可以看我的主页。

长短期记忆人工神经网络LSTM)介绍:

基本结构:

长短期记忆网络(LSTMLong Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN循环神经网络)存在的长期依赖问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层。

LSTM是一种含有LSTM区块(blocks)或其他的一种类神经网络,文献或其他资料中LSTM区块可能被描述成智能网络单元,因为它可以记忆不定时间长度的数值,区块中有一个gate能够决定input是否重要到能被记住及能不能被输出output

1底下是四个S函数单元,最左边函数依情况可能成为区块的input,右边三个会经过gate决定input是否能传入区块,左边第二个为input gate,如果这里产出近似于零,将把这里的值挡住,不会进到下一层。左边第三个是forget gate,当这产生值近似于零,将把区块里记住的值忘掉。第四个也就是最右边的inputoutput gate,他可以决定在区块记忆中的input是否能输出。

LSTM有很多个版本,其中一个重要的版本是GRUGated Recurrent Unit),根据谷歌的测试表明,LSTM中最重要的是Forget gate,其次是Input gate,最次是Output gate

LSTM网络的训练方法:

为了最小化训练误差,梯度下降法(Gradient descent)如:应用时序性倒传递算法,可用来依据错误修改每次的权重。梯度下降法在递回神经网络(RNN)中主要的问题初次在1991年发现,就是误差梯度随着事件间的时间长度成指数般的消失。当设置了LSTM 区块时,误差也随着倒回计算,从output影响回input阶段的每一个gate,直到这个数值被过滤掉。因此正常的倒传递类神经是一个有效训练LSTM区块记住长时间数值的方法。

使用LSTM的多输入特征分类

加载数据并预处理:

加载日语元音训练数据。 是包含 270 个不同长度的 12 维序列的元胞数组。 是对应于九个说话者的标签 "1""2"..."9" 的分类向量。 中的条目是具有 12 行(每个特征一行)和不同列数(每个时间步一列)的矩阵。在绘图中可视化第一个时间序列。每行对应一个特征。

在训练过程中,默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度。过多填充会对网络性能产生负面影响。

为了防止训练过程添加过多填充,您可以按序列长度对训练数据进行排序,并选择合适的小批量大小,以使同一小批量中的序列长度相近。下图显示了对数据进行排序之前和之后填充序列的效果。

定义LSTM网络:

定义LSTM网络架构。将输入大小指定为序列大小12(输入数据的维度)。指定具有 100 个隐含单元的双向LSTM层,并输出序列的最后一个元素。最后,通过包含大小为 9 的全连接层,后跟 softmax 层和分类层,来指定九个类。

如果您可以在预测时访问完整序列,则可以在网络中使用双向LSTM层。双向LSTM层在每个时间步从完整序列学习。如果您不能在预测时访问完整序列,例如,您正在预测值或一次预测一个时间步时,则改用LSTM层。

训练LSTM网络:

使用以指定的训练选项训练LSTM网络。trainNetwork

测试LSTM网络:

加载测试集并将序列分类到不同的说话者。

加载日语元音测试数据。是包含 370 个不同长度的 12 维序列的元胞数组。是由对应于九个说话者的标签 "1""2"..."9" 组成的分类向量。

计算预测值的分类准确度。

可以看到,所构建的LSTM经过充分的训练后分类的准确度达到了96.22%,很好的完成了多输入特征的分类任务。

本文Matlab代码分享

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/525544
推荐阅读
相关标签
  

闽ICP备14008679号