当前位置:   article > 正文

[开源] 基于 SARIMA 的时间序列预测模型 Python 代码_Python 实现 SARIMA 模型

python实现sarima模型

整理了基于 SARIMA 的时间序列预测模型 Python 代码,免费分享给大家,记得点赞哦!

  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # # 导入环境中的相关包
  4. import itertools
  5. import numpy as np #
  6. import pandas as pd #
  7. import matplotlib.pyplot as plt
  8. from matplotlib.ticker import MultipleLocator
  9. import warnings
  10. from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
  11. from statsmodels.stats.diagnostic import acorr_ljungbox
  12. from statsmodels.tsa.statespace.sarimax import SARIMAX
  13. from sklearn.metrics import r2_score,mean_absolute_error,mean_squared_error
  14. from statsmodels.tsa.stattools import adfuller
  15. import math
  16. import seaborn as sns
  17. import statsmodels.api as sm
  18. import tensorflow as tf
  19. from pmdarima import auto_arima
  20. #显示中文
  21. #忽略警告
  22. warnings.filterwarnings('ignore')
  23. plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
  24. plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
  25. plt.rcParams['image.interpolation'] = 'nearest'
  26. plt.rcParams['image.cmap'] = 'gray'
  27. # 调用GPU加速
  28. gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
  29. for gpu in gpus:
  30. tf.config.experimental.set_memory_growth(gpu, True)
  31. df = pd.read_csv("shao - 单.csv",usecols=[1]) #读取数据
  32. df.head()
  33. plt.figure(figsize=(15, 3))
  34. plt.title('风速')
  35. plt.xlabel('时间')
  36. plt.ylabel('最大风速')
  37. plt.plot(df, 'b', label='AQI')
  38. plt.legend()
  39. plt.show()
  40. #定义稳定性检验函数
  41. def adf_val(ts, ts_title):
  42. adf, pvalue, usedlag, nobs, critical_values, icbest = adfuller(ts)
  43. name = ['adf', 'pvalue', 'usedlag',
  44. 'nobs', 'critical_values', 'icbest']
  45. values = [adf, pvalue, usedlag, nobs,
  46. critical_values, icbest]
  47. print(list(zip(name, values)))
  48. return adf, pvalue, critical_values,
  49. # 返回adf值、adf的p值、三种状态的检验值
  50. #白噪声检验也称为纯随机性检验,当数据是纯随机数据时,再对数据进行分析就没有任何意义了,所以拿到数据后最好对数据进行一个纯随机性检验。
  51. def acorr_val(ts):
  52. '''
  53. # 白噪声(随机性)检验
  54. ts: 时间序列数据,Series类型
  55. 返回白噪声检验的P值
  56. '''
  57. lbvalue, pvalue = acorr_ljungbox(ts, lags=1) # 白噪声检验结果
  58. return lbvalue, pvalue
  59. def tsplot(y, lags=None, figsize=(14, 8)):
  60. fig = plt.figure(figsize=figsize)
  61. layout = (2, 2)
  62. ts_ax = plt.subplot2grid(layout, (0, 0))
  63. hist_ax = plt.subplot2grid(layout, (0, 1))
  64. acf_ax = plt.subplot2grid(layout, (1, 0))
  65. pacf_ax = plt.subplot2grid(layout, (1, 1))
  66. y.plot(ax=ts_ax)
  67. ts_ax.set_title('A Given Training Series')
  68. y.plot(ax=hist_ax, kind='hist', bins=25)
  69. hist_ax.set_title('Histogram')
  70. #自相关(Autocorrelation): 对一个时间序列,现在值与其过去值的相关性。如果相关性为正,则说明现有趋势将继续保持。
  71. plot_acf(y, lags=lags, ax=acf_ax)
  72. #可以度量现在值与过去值更纯正的相关性
  73. plot_pacf(y, lags=lags, ax=pacf_ax)
  74. [ax.set_xlim(0) for ax in [acf_ax, pacf_ax]]
  75. sns.despine()
  76. fig.tight_layout()
  77. fig.show()
  78. return ts_ax, acf_ax, pacf_ax
  79. ts_data = df.astype('float32')
  80. #adf结果为-10.4, 小于三个level的统计值。pvalue也是接近于0 的,所以是平稳的
  81. adf, pvalue1, critical_values = adf_val(ts_data, 'raw time series')
  82. print('adf',adf)
  83. print('pvalue1',pvalue1)
  84. print('critical_values',critical_values)
  85. #若p值远小于0.01,认为该时间序列是平稳的
  86. aco=acorr_val(ts_data)
  87. print('aco',aco)
  88. ##自相关和偏自相关
  89. tsplot(ts_data, lags=20)
  90. train_data, test_data = df[0:int(len(df)*0.8)], df[int(len(df)*0.8):]
  91. #画出训练集和测试集的原数据(open 价格)
  92. plt.figure(dpi=100, figsize=(20,5))
  93. plt.title('Air Quality Index of Nanning City', size=40)
  94. plt.xlabel('time/day',size=30)
  95. plt.ylabel('AQI',size=30)
  96. plt.plot(train_data, 'b', label='Training Data',linewidth=3)
  97. plt.plot(test_data, 'g', label='Testing Data',linewidth=3)
  98. font = {'serif': 'Times New Roman','size': 30}
  99. plt.rc('font', **font)
  100. plt.legend()
  101. plt.show()
  102. #取划分的数据
  103. train_ar = train_data.values
  104. test_ar = test_data.values
  105. auto_arima(train_data, seasonal=True, m=12,max_p=7, max_d=2,max_q=7, max_P=4, max_D=4,max_Q=4).summary()
  106. def best_sarima_model(train_data,p,q,P,Q,d=1,D=1,s=12):
  107. best_model_aic = np.Inf
  108. best_model_bic = np.Inf
  109. best_model_hqic = np.Inf
  110. best_model_order = (0,0,0)
  111. models = []
  112. for p_ in p:
  113. for q_ in q:
  114. for P_ in P:
  115. for Q_ in Q:
  116. try:
  117. no_of_lower_metrics = 0
  118. model = SARIMAX(endog=train_data,order=(p_,d,q_), seasonal_order=(P_,D,Q_,s),
  119. enforce_invertibility=False).fit()
  120. models.append(model)
  121. if model.aic <= best_model_aic: no_of_lower_metrics+=1
  122. if model.bic <= best_model_bic: no_of_lower_metrics+=1
  123. if model.hqic <= best_model_hqic:no_of_lower_metrics+=1
  124. if no_of_lower_metrics >= 2:
  125. best_model_aic = np.round(model.aic,0)
  126. best_model_bic = np.round(model.bic,0)
  127. best_model_hqic = np.round(model.hqic,0)
  128. best_model_order = (p_,d,q_,P_,D,Q_,s)
  129. current_best_model = model
  130. models.append(model)
  131. print("Best model: SARIMA" + str(best_model_order) +
  132. " AIC:{} BIC:{} HQIC:{}".format(best_model_aic,best_model_bic,best_model_hqic)+
  133. " resid:{}".format(np.round(np.exp(current_best_model.resid).mean(),3)))
  134. except:
  135. pass
  136. print('\n')
  137. print(current_best_model.summary())
  138. return current_best_model, models
  139. best_model, models = best_sarima_model(train_data=train_ar,p=range(3),q=range(3),P=range(3),Q=range(3))
  140. p = range(0, 3)
  141. d = range(0, 1)
  142. q = range(0, 3)
  143. pdq = list(itertools.product(p, d, q))
  144. seasonal_pdq = [(x[0], x[1], x[2], 6) for x in list(itertools.product(p, d, q))]
  145. min_aic = 999999999
  146. for param in pdq:
  147. for param_seasonal in seasonal_pdq:
  148. try:
  149. mod = sm.tsa.statespace.SARIMAX(train_ar,
  150. order=param,
  151. seasonal_order=param_seasonal,
  152. enforce_stationarity=False,
  153. enforce_invertibility=False)
  154. results = mod.fit()
  155. print('ARIMA{}x{}12 - AIC:{}'.format(param, param_seasonal, results.aic))
  156. if results.aic < min_aic:
  157. min_aic = results.aic
  158. min_aic_model = results
  159. except:
  160. continue
  161. min_aic_model.summary()
  162. # # 构建训练数据
  163. history = [x for x in train_ar]
  164. print(type(history))
  165. predictions = list()
  166. #训练ARIMA模型
  167. for t in range(len(test_ar)):
  168. model = sm.tsa.SARIMAX(history,order=(2,1,1), seasonal_order=(0,0,1,12),enforce_invertibility=False)
  169. model_fit = model.fit()
  170. output = model_fit.forecast()#模型预测
  171. yhat = output[0]
  172. predictions.append(yhat)
  173. obs = test_ar[t]
  174. history.append(obs)
  175. print('predicted=%f, expected=%f' % (yhat, obs))
  176. testScore = math.sqrt(mean_squared_error(test_ar, predictions))
  177. print('RMSE %.3f ' %(testScore))
  178. testScore = r2_score(test_ar, predictions)
  179. print('R2 %.3f' %(testScore))
  180. testScore = mean_absolute_error(test_ar, predictions)
  181. print('MAE %.3f ' %(testScore))
  182. #只显示预测部分,不显示训练数据部分
  183. plt.figure(figsize=(12,7))
  184. plt.plot(test_data.index, predictions, color='b', marker='o', linestyle='dashed',label='Predicted')
  185. plt.plot(test_data.index, test_data, color='red', label='Actual')
  186. plt.title('SARIMA')
  187. plt.xlabel('time')
  188. plt.ylabel('AQI')
  189. plt.legend()
  190. plt.show()

更多时间序列预测代码:时间序列预测算法全集合--深度学习

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/511335
推荐阅读
相关标签
  

闽ICP备14008679号