当前位置:   article > 正文

时间序列预测——LSTM模型(附代码实现)_lstm模型代码

lstm模型代码

目录

模型原理

模型实现

导入所需要的库

设置随机数种子

导入数据集

打印前五行数据进行查看

数据处理

归一化处理

查看归一化处理后的数据

将时间序列转换为监督学习问题

打印数据前五行

 划分训练集和测试集

查看划分后的数据维度

搭建LSTM模型

 得到损失图

模型预测

画图展示

得到预测图像 

回归评价指标


需要完整源码的联系QQ:2625520691(知识成果,白嫖勿扰)

模型原理

        长短时记忆网络( Long short-term memory,LSTM )是一种循环神经网络 (Recurrent neural network, RNN)的特殊变体,具有“门”结构,通过门单元的逻辑控制决定数据是否更新或是选择丢弃,克服了 RNN 权重影响过大、容易产生梯度消失和爆炸的缺点,使网络可以更好、更快地收敛,能够有效提高预测精度。LSTM 拥有三个门, 分别为遗忘门、输入门、输出门,以此决定每一时刻信息记忆与遗忘。输入门决定有多少新的信息加入到细胞当中,遗忘门控制每一时刻信息是否会被遗忘,输出门决定每一时刻是否有信息输出。其基本结构如图所示。

公式如下:

(1)遗忘门

(2)输入门

(3)单元

(4)输出门

(5)最终输出

模型实现

导入所需要的库

  1. import matplotlib.pyplot as plt
  2. from pandas import read_csv
  3. from pandas import DataFrame
  4. from pandas import concat
  5. from sklearn.preprocessing import MinMaxScaler
  6. from tensorflow.keras.models import Sequential
  7. from tensorflow.keras.layers import LSTM,Dense,Dropout
  8. from numpy import concatenate
  9. from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
  10. from math import sqrt

设置随机数种子

  1. import tensorflow as tf
  2. tf.random.set_seed(2)

导入数据集

  1. qy_data=read_csv(r'C:\Users\HUAWEI\Desktop\abc.csv',parse_dates=['num'],index_col='num')
  2. qy_data.index.name='num' #选定索引列

打印前五行数据进行查看

数据处理

  1. # 获取DataFrame中的数据,形式为数组array形式
  2. values = qy_data.values
  3. # 确保所有数据为float类型
  4. values = values.astype('float32')

归一化处理

使用MinMaxScaler缩放器,将全部数据都缩放到[0,1]之间,加快收敛。

  1. scaler = MinMaxScaler(feature_range=(0, 1))
  2. scaled = scaler.fit_transform(values)

查看归一化处理后的数据

  

时间序列转换为监督学习问题

时间序列形式的数据转换为监督学习集的形式,例如:[[10],[11],[12],[13],[14]]转换为[[0,10],[10,11],[11,12],[12,13],[13,14]],即把前一个数作为输入,后一个数作为对应输出。

  1. def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
  2. n_vars = 1 if type(data) is list else data.shape[1]
  3. df = DataFrame(data)
  4. cols, names = list(), list()
  5. # input sequence (t-n, ... t-1)
  6. for i in range(n_in, 0, -1):
  7. cols.append(df.shift(i))
  8. names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]
  9. # forecast sequence (t, t+1, ... t+n)
  10. for i in range(0, n_out):
  11. cols.append(df.shift(-i))
  12. if i == 0:
  13. names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]
  14. else:
  15. names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]
  16. # put it all together
  17. agg = concat(cols, axis=1)
  18. agg.columns = names
  19. # drop rows with NaN values
  20. if dropnan:
  21. agg.dropna(inplace=True)
  22. return agg
  23. reframed = series_to_supervised(scaled, 2, 1)

打印数据前五行

  

 划分训练集和测试集

  1. # 划分训练集和测试集
  2. values = reframed.values
  3. trainNum = int(len(values) * 0.7)
  4. train = values[:trainNum,:]
  5. test = values[trainNum:, :]

查看划分后的数据维度

  1. print(train_X.shape, train_y.shape)
  2. print(test_X.shape, test_y.shape)

 

搭建LSTM模型

初始化LSTM模型,设置神经元核心的个数,迭代次数,优化器等等

  1. model = Sequential()
  2. model.add(LSTM(27, input_shape=(train_X.shape[1], train_X.shape[2])))
  3. model.add(Dropout(0.5))
  4. model.add(Dense(15,activation='relu'))#激活函数
  5. model.compile(loss='mae', optimizer='adam')
  6. history = model.fit(train_X, train_y, epochs=95, batch_size=2, validation_data=(test_X, test_y), verbose=2,shuffle=False)

 得到损失图

模型预测

  1. y_predict = model.predict(test_X)
  2. test_X = test_X.reshape((test_X.shape[0], test_X.shape[2]))

画图展示

  1. plt.figure(figsize=(10,8),dpi=150)
  2. plt.plot(inv_y,color='red',label='Original')
  3. plt.plot(inv_y_predict,color='green',label='Predict')
  4. plt.xlabel('the number of test data')
  5. plt.ylabel('Soil moisture')
  6. plt.legend()
  7. plt.show()

得到预测图像 

将测试集的y值和预测值绘制在同一张图表中

  

回归评价指标

  1. # calculate MSE 均方误差
  2. mse=mean_squared_error(inv_y,inv_y_predict)
  3. # calculate RMSE 均方根误差
  4. rmse = sqrt(mean_squared_error(inv_y, inv_y_predict))
  5. #calculate MAE 平均绝对误差
  6. mae=mean_absolute_error(inv_y,inv_y_predict)
  7. #calculate R square
  8. r_square=r2_score(inv_y,inv_y_predict)
  9. print('均方误差MSE: %.6f' % mse)
  10. print('均方根误差RMSE: %.6f' % rmse)
  11. print('平均绝对误差MAE: %.6f' % mae)
  12. print('R_square: %.6f' % r_square)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/134780
推荐阅读
相关标签
  

闽ICP备14008679号