赞
踩
TimeGPT
理解时间序列预测的复杂需求,融合了 cross_validation
方法,旨在简化时间序列模型的验证过程。这个功能使从业者能够对历史数据严格测试他们的预测模型,评估它们的有效性,同时调整它们以获得最佳性能。本教程将指导您完成在 TimeGPT 类中进行交叉验证的微妙过程,确保您的时间序列预测模型不仅构建良好,而且经过验证是值得信赖和精确的。
# 导入colab_badge模块,用于生成Colab徽章
from nixtlats.utils import colab_badge
colab_badge('docs/tutorials/9_cross_validation')
# 导入必要的库
import numpy as np
from dotenv import load_dotenv
# 加载dotenv模块,用于从.env文件中加载环境变量
load_dotenv()
True
# 导入pandas库
import pandas as pd
# 导入TimeGPT类
from nixtlats import TimeGPT
# 创建TimeGPT对象,并传入token参数
# 如果没有传入token参数,则默认使用环境变量中的TIMEGPT_TOKEN
timegpt = TimeGPT(
token='my_token_provided_by_nixtla'
)
# 创建一个TimeGPT对象,用于生成时间相关的文本。
timegpt = TimeGPT()
TimeGPT
类中的cross_validation
方法是一种高级功能,用于对时间序列预测模型进行系统验证。该方法需要一个包含按时间排序的数据的数据帧,并采用滚动窗口方案来精确评估模型在不同时间段的性能,从而确保模型的可靠性和稳定性。
关键参数包括freq
,它表示数据的频率,如果未指定,则会自动推断。id_col
、time_col
和target_col
参数分别指定每个系列的标识符、时间步长和目标值的列。该方法通过参数进行自定义,例如n_windows
表示评估模型的独立时间窗口的数量,step_size
确定这些窗口之间的间隔。如果未指定step_size
,则默认为预测的时间范围h
。
该过程还允许通过finetune_steps
进行模型细化,指定在新数据上进行模型微调的迭代次数。通过clean_ex_first
参数可以管理数据预处理,决定是否在预测之前清理外生信号。此外,该方法还支持通过date_features
参数从时间数据进行增强特征工程,该参数可以自动生成关键的与日期相关的特征,也可以接受自定义函数进行定制特征创建。date_features_to_one_hot
参数进一步支持将分类日期特征转换为适合机器学习模型的格式。
在执行过程中,cross_validation
在每个窗口中评估模型的预测准确性,提供了模型性能随时间变化和过度拟合的稳健视图。这种详细评估确保生成的预测不仅准确,而且在不同的时间背景下保持一致。
# 读取数据集 pm_df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/peyton_manning.csv') # 使用timegpt库的cross_validation函数对数据进行交叉验证 # 参数说明: # - pm_df: 待验证的数据集 # - h: 预测的时间步数 # - n_windows: 窗口数量,用于划分训练集和验证集 # - time_col: 时间列的列名 # - target_col: 目标列的列名 # - freq: 时间频率,这里设定为每天 timegpt_cv_df = timegpt.cross_validation( pm_df, h=7, n_windows=5, time_col='timestamp', target_col='value', freq='D', ) # 打印交叉验证结果的前几行 timegpt_cv_df.head()
INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs...
timestamp | cutoff | value | TimeGPT | |
---|---|---|---|---|
0 | 2015-12-17 | 2015-12-16 | 7.591862 | 7.939553 |
1 | 2015-12-18 | 2015-12-16 | 7.528869 | 7.887512 |
2 | 2015-12-19 | 2015-12-16 | 7.171657 | 7.766617 |
3 | 2015-12-20 | 2015-12-16 | 7.891331 | 7.931502 |
4 | 2015-12-21 | 2015-12-16 | 8.360071 | 8.312632 |
# 导入IPython.display模块中的display函数
from IPython.display import display
# 从timegpt_cv_df数据框中获取唯一的cutoff值,并赋值给变量cutoffs cutoffs = timegpt_cv_df['cutoff'].unique() # 遍历cutoffs中的每个cutoff值 for cutoff in cutoffs: # 使用timegpt.plot函数绘制图形,并将结果赋值给变量fig # 绘图所需的数据为pm_df的最后100行和timegpt_cv_df中cutoff等于当前遍历值的行,删除列'cutoff'和'value' # 指定时间列为'timestamp',目标列为'value' fig = timegpt.plot( pm_df.tail(100), timegpt_cv_df.query('cutoff == @cutoff').drop(columns=['cutoff', 'value']), time_col='timestamp', target_col='value' ) # 显示图形 display(fig)
为了评估TimeGPT
在分布预测方面的性能,您可以使用level
参数生成预测区间。
# 导入所需模块和函数 # 使用timegpt.cross_validation函数进行时间序列交叉验证 # 参数pm_df为待验证的时间序列数据 # 参数h为预测的时间步长,这里设置为7 # 参数n_windows为窗口数量,这里设置为5 # 参数time_col为时间列的列名,这里设置为'timestamp' # 参数target_col为目标列的列名,这里设置为'value' # 参数freq为时间序列的频率,这里设置为'D',表示按天 # 参数level为置信水平,这里设置为[80, 90],表示计算80%和90%的置信区间 # 返回值timegpt_cv_df为交叉验证结果的数据框 timegpt_cv_df = timegpt.cross_validation( pm_df, h=7, n_windows=5, time_col='timestamp', target_col='value', freq='D', level=[80, 90], ) # 输出交叉验证结果的前几行数据 timegpt_cv_df.head()
INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Restricting input... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Restricting input... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Restricting input... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Restricting input... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Restricting input... INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs...
timestamp | cutoff | value | TimeGPT | TimeGPT-lo-90 | TimeGPT-lo-80 | TimeGPT-hi-80 | TimeGPT-hi-90 | |
---|---|---|---|---|---|---|---|---|
0 | 2015-12-17 | 2015-12-16 | 7.591862 | 7.939553 | 7.564151 | 7.675945 | 8.203161 | 8.314956 |
1 | 2015-12-18 | 2015-12-16 | 7.528869 | 7.887512 | 7.567342 | 7.598298 | 8.176726 | 8.207681 |
2 | 2015-12-19 | 2015-12-16 | 7.171657 | 7.766617 | 7.146560 | 7.266829 | 8.266404 | 8.386674 |
3 | 2015-12-20 | 2015-12-16 | 7.891331 | 7.931502 | 7.493021 | 7.657075 | 8.205929 | 8.369982 |
4 | 2015-12-21 | 2015-12-16 | 8.360071 | 8.312632 | 7.017335 | 7.446677 | 9.178586 | 9.607928 |
# 获取时间截断点的唯一值 cutoffs = timegpt_cv_df['cutoff'].unique() # 遍历每个截断点 for cutoff in cutoffs: # 绘制图表 fig = timegpt.plot( # 绘制最近100个数据点 pm_df.tail(100), # 查询截断点等于当前截断点的数据,并删除'cutoff'和'value'列 timegpt_cv_df.query('cutoff == @cutoff').drop(columns=['cutoff', 'value']), # 设置时间列为'timestamp' time_col='timestamp', # 设置目标列为'value' target_col='value', # 设置置信水平为[80, 90] level=[80, 90], # 设置模型为'TimeGPT' models=['TimeGPT'] ) # 显示图表 display(fig)
您还可以包括date_features
以查看它们对预测准确性的影响。
# 对于给定的时间序列数据,进行时间序列交叉验证 # 使用timegpt.cross_validation函数进行交叉验证 # 参数说明: # - pm_df: 待验证的时间序列数据 # - h: 预测的时间步长 # - n_windows: 窗口的数量,将时间序列数据划分为多个窗口进行交叉验证 # - time_col: 时间列的名称,用于指定时间序列数据中的时间信息 # - target_col: 目标列的名称,用于指定时间序列数据中的目标变量 # - freq: 时间序列数据的频率,以天为单位 # - level: 置信水平,用于计算预测区间 # - date_features: 日期特征,用于提取时间序列数据中的日期信息 # 返回值为交叉验证结果的数据框 timegpt_cv_df = timegpt.cross_validation( pm_df, h=7, n_windows=5, time_col='timestamp', target_col='value', freq='D', level=[80, 90], date_features=['month'], ) # 输出交叉验证结果的前几行数据 timegpt_cv_df.head()
INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Using the following exogenous variables: month_1, month_2, month_3, month_4, month_5, month_6, month_7, month_8, month_9, month_10, month_11, month_12 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Using the following exogenous variables: month_1, month_2, month_3, month_4, month_5, month_6, month_7, month_8, month_9, month_10, month_11, month_12 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Using the following exogenous variables: month_1, month_2, month_3, month_4, month_5, month_6, month_7, month_8, month_9, month_10, month_11, month_12 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Using the following exogenous variables: month_1, month_2, month_3, month_4, month_5, month_6, month_7, month_8, month_9, month_10, month_11, month_12 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Using the following exogenous variables: month_1, month_2, month_3, month_4, month_5, month_6, month_7, month_8, month_9, month_10, month_11, month_12 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs...
timestamp | cutoff | value | TimeGPT | TimeGPT-lo-90 | TimeGPT-lo-80 | TimeGPT-hi-80 | TimeGPT-hi-90 | |
---|---|---|---|---|---|---|---|---|
0 | 2015-12-17 | 2015-12-16 | 7.591862 | 7.945311 | 7.542366 | 7.647852 | 8.242769 | 8.348255 |
1 | 2015-12-18 | 2015-12-16 | 7.528869 | 7.892559 | 7.271274 | 7.481059 | 8.304058 | 8.513843 |
2 | 2015-12-19 | 2015-12-16 | 7.171657 | 7.771581 | 7.113544 | 7.281711 | 8.261451 | 8.429619 |
3 | 2015-12-20 | 2015-12-16 | 7.891331 | 7.939502 | 6.988198 | 7.345371 | 8.533633 | 8.890807 |
4 | 2015-12-21 | 2015-12-16 | 8.360071 | 8.320170 | 7.140163 | 7.658314 | 8.982027 | 9.500178 |
# 获取时间戳的唯一值 cutoffs = timegpt_cv_df['cutoff'].unique() # 遍历每个唯一的时间戳 for cutoff in cutoffs: # 使用timegpt.plot函数绘制图形 # 参数1:使用pm_df的最后100行数据作为输入数据 # 参数2:使用timegpt_cv_df中cutoff等于当前遍历的时间戳的数据,删除cutoff和value列作为输入数据 # 参数3:指定时间戳列为timestamp # 参数4:指定目标值列为value # 参数5:指定80和90为置信水平 # 参数6:指定使用TimeGPT模型 fig = timegpt.plot( pm_df.tail(100), timegpt_cv_df.query('cutoff == @cutoff').drop(columns=['cutoff', 'value']), time_col='timestamp', target_col='value', level=[80, 90], models=['TimeGPT'] ) # 显示图形 display(fig)
此外,您可以传递外生变量以更好地向TimeGPT
提供关于数据的信息。您只需在目标列之后简单地添加外生回归变量即可。
# 读取电力数据集Y_df,数据来自'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity.csv'
Y_df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity.csv')
# 读取外部变量数据集X_df,数据来自'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/exogenous-vars-electricity.csv'
X_df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/exogenous-vars-electricity.csv')
# 将Y_df和X_df数据集进行合并,合并后的数据集为df
df = Y_df.merge(X_df)
现在让我们使用这些信息对TimeGPT
进行交叉验证。
# 导入TimeGPT模型
timegpt = TimeGPT(max_retries=2, retry_interval=5) # 创建TimeGPT对象,设置最大重试次数为2,重试间隔为5秒
# 导入的库已经存在,不需要添加import语句 # 对数据进行交叉验证,将数据按照unique_id分组,每组取最后的100*48个数据进行交叉验证 # h=48表示预测未来48个时间点的值,n_windows=2表示将数据分为两个窗口进行交叉验证 # level=[80, 90]表示计算80%和90%置信区间 timegpt_cv_df_x = timegpt.cross_validation( df.groupby('unique_id').tail(100 * 48), h=48, n_windows=2, level=[80, 90] ) # 查询unique_id为"BE"的数据的cutoff值,并将其存储在cutoffs中 cutoffs = timegpt_cv_df_x.query('unique_id == "BE"')['cutoff'].unique() # 遍历cutoffs中的每个cutoff值,对unique_id为"BE"的数据进行预测并绘制图表 for cutoff in cutoffs: # 绘制unique_id为"BE"的数据的最后24*7个时间点的真实值和预测值,并将其存储在fig中 # timegpt_cv_df_x.query('cutoff == @cutoff & unique_id == "BE"')表示查询cutoff值为当前遍历到的cutoff值,unique_id为"BE"的数据 # drop(columns=['cutoff', 'y'])表示删除查询结果中的cutoff和y两列 # models=['TimeGPT']表示使用TimeGPT模型进行预测 # level=[80, 90]表示计算80%和90%置信区间 fig = timegpt.plot( df.query('unique_id == "BE"').tail(24 * 7), timegpt_cv_df_x.query('cutoff == @cutoff & unique_id == "BE"').drop(columns=['cutoff', 'y']), models=['TimeGPT'], level=[80, 90], ) # 显示图表 display(fig)
INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Inferred freq: H INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Inferred freq: H WARNING:nixtlats.timegpt:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon. INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Validating inputs... INFO:nixtlats.timegpt:Preprocessing dataframes... INFO:nixtlats.timegpt:Inferred freq: H WARNING:nixtlats.timegpt:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon. INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6 INFO:nixtlats.timegpt:Calling Forecast Endpoint... INFO:nixtlats.timegpt:Validating inputs...
此外,您可以使用model
参数为不同的TimeGPT
实例生成交叉验证。
# 对数据进行交叉验证 timegpt_cv_df_x_long_horizon = timegpt.cross_validation( df.groupby('unique_id').tail(100 * 48), # 对数据进行分组,每个组取最后的100 * 48个数据 h=48, # 预测的时间步长为48 n_windows=2, # 使用2个窗口进行交叉验证 level=[80, 90], # 设置置信水平为80%和90% model='timegpt-1-long-horizon', # 使用timegpt-1-long-horizon模型 ) # 将列名中的'TimeGPT'替换为'TimeGPT-LongHorizon' timegpt_cv_df_x_long_horizon.columns = timegpt_cv_df_x_long_horizon.columns.str.replace('TimeGPT', 'TimeGPT-LongHorizon') # 将timegpt_cv_df_x_long_horizon与timegpt_cv_df_x进行合并 timegpt_cv_df_x_models = timegpt_cv_df_x_long_horizon.merge(timegpt_cv_df_x) # 获取unique_id为"BE"的数据的cutoff值 cutoffs = timegpt_cv_df_x_models.query('unique_id == "BE"')['cutoff'].unique() # 对每个cutoff值进行循环 for cutoff in cutoffs: # 绘制图形 fig = timegpt.plot( df.query('unique_id == "BE"').tail(24 * 7), # 获取unique_id为"BE"的最后24 * 7个数据 timegpt_cv_df_x_models.query('cutoff == @cutoff & unique_id == "BE"').drop(columns=['cutoff', 'y']), # 获取cutoff和unique_id为"BE"的数据,并删除'cutoff'和'y'列 models=['TimeGPT', 'TimeGPT-LongHorizon'], # 绘制'TimeGPT'和'TimeGPT-LongHorizon'模型的图形 level=[80, 90], # 设置置信水平为80%和90% ) # 显示图形 display(fig)
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Preprocessing dataframes...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6
INFO:nixtlats.timegpt:Calling Forecast Endpoint...
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Preprocessing dataframes...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6
INFO:nixtlats.timegpt:Calling Forecast Endpoint...
INFO:nixtlats.timegpt:Validating inputs...
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。