当前位置:   article > 正文

时间序列预测—双向LSTM(Bi-LSTM)_bi-lstm代码

bi-lstm代码

        本文展示了使用双向LSTM(Bi-LSTM)进行时间序列预测的全过程,包含详细的注释。整个过程主要包括:数据导入、数据清洗、结构转化、建立Bi-LSTM模型、训练模型(包括动态调整学习率和earlystopping的设置)、预测、结果展示、误差评估等完整的时间序列预测流程。
  本文使用的数据集在本人上传的资源中,链接为mock_kaggle.csv

代码如下:

  1. import pandas as pd
  2. import numpy as np
  3. import math
  4. import keras
  5. from matplotlib import pyplot as plt
  6. from matplotlib.pylab import mpl
  7. import tensorflow as tf
  8. from sklearn.preprocessing import MinMaxScaler
  9. from keras import backend as K
  10. from keras.layers import LeakyReLU
  11. from sklearn.metrics import mean_squared_error # 均方误差
  12. from keras.callbacks import LearningRateScheduler
  13. from keras.callbacks import EarlyStopping
  14. from tensorflow.keras import Input, Model,Sequential
  15. from keras.layers import Bidirectional#, Concatenate
  1. mpl.rcParams['font.sans-serif'] = ['SimHei'] #显示中文
  2. mpl.rcParams['axes.unicode_minus']=False #显示负号

取数据

  1. data=pd.read_csv('mock_kaggle.csv',encoding ='gbk',parse_dates=['datetime'])
  2. Date=pd.to_datetime(data.datetime)
  3. data['date'] = Date.map(lambda x: x.strftime('%Y-%m-%d'))
  4. datanew=data.set_index(Date)
  5. series = pd.Series(datanew['股票'].values, index=datanew['date'])
series
  1. date
  2. 2014-01-01 4972
  3. 2014-01-02 4902
  4. 2014-01-03 4843
  5. 2014-01-04 4750
  6. 2014-01-05 4654
  7. ...
  8. 2016-07-27 3179
  9. 2016-07-28 3071
  10. 2016-07-29 4095
  11. 2016-07-30 3825
  12. 2016-07-31 3642
  13. Length: 937, dtype: int64

滞后扩充数据

  1. dataframe1 = pd.DataFrame()
  2. num_hour = 16
  3. for i in range(num_hour,0,-1):
  4. dataframe1['t-'+str(i)] = series.shift(i)
  5. dataframe1['t'] = series.values
  6. dataframe3=dataframe1.dropna()
  7. dataframe3.index=range(len(dataframe3))
dataframe3
t-16t-15t-14t-13t-12t-11t-10t-9t-8t-7t-6t-5t-4t-3t-2t-1t
04972.04902.04843.04750.04654.04509.04329.04104.04459.05043.05239.05118.04984.04904.04822.04728.04464
14902.04843.04750.04654.04509.04329.04104.04459.05043.05239.05118.04984.04904.04822.04728.04464.04265
24843.04750.04654.04509.04329.04104.04459.05043.05239.05118.04984.04904.04822.04728.04464.04265.04161
34750.04654.04509.04329.04104.04459.05043.05239.05118.04984.04904.04822.04728.04464.04265.04161.04091
44654.04509.04329.04104.04459.05043.05239.05118.04984.04904.04822.04728.04464.04265.04161.04091.03964
......................................................
9161939.01967.01670.01532.01343.01022.0813.01420.01359.01075.01015.0917.01550.01420.01358.02893.03179
9171967.01670.01532.01343.01022.0813.01420.01359.01075.01015.0917.01550.01420.01358.02893.03179.03071
9181670.01532.01343.01022.0813.01420.01359.01075.01015.0917.01550.01420.01358.02893.03179.03071.04095
9191532.01343.01022.0813.01420.01359.01075.01015.0917.01550.01420.01358.02893.03179.03071.04095.03825
9201343.01022.0813.01420.01359.01075.01015.0917.01550.01420.01358.02893.03179.03071.04095.03825.03642

显示详细信息

921 rows × 17 columns

二折划分数据并标准化

  1. pd.DataFrame(np.random.shuffle(dataframe3.values)) #shuffle
  2. pot=len(dataframe3)-12
  3. train=dataframe3
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/327747
推荐阅读
相关标签
  

闽ICP备14008679号