当前位置:   article > 正文

【阅读论文】TimesNet短期预测的基本流程梳理_m4数据集

m4数据集

目录

前言

一、run.py

1.args

2.train, test

二、TimesNet_M4.sh

三、exp_short_term_forecasting.py

四、train()

五、TimesNet.py

1.Model

2.TimesBlock

3.FFT_for_Period


前言

果然是初入机器学习的新手,对pycharm、pytorch的套路了解得太少,在学习之路上犯了不少错误,走了不少弯路,虽然现在依旧是个新人,但也还是来做个阶段性的总结,也算是成长的证明。还是以TimesNet为例,下面用基于m4数据集(quarterly类别)的short term forecasting程序来做说明。

一、run.py

主程序,主要是两个部分:args和train, test。

1.args

作用:导入基本参数。

主要代码:

  1. import argparse
  2. parser = argparse.ArgumentParser(description='TimesNet')
  3. parser.add_argument('--属性名', type=类型, default=默认值, help='说明')
  4. args = parser.parse_args()
  5. # parser.add_argument中还有一些用的少的参数required, action, nargs

2.train, test

作用:开始训练、测试模型。

主要代码:

  1. setting = '{}_{}_..._{}_{}'.format(args.属性, ..., args.属性)
  2. exp = Exp(args) # 把args传递给exp_short_term_forecasting
  3. exp.train(setting) # 用于给checkpoints命名
  4. exp.test(setting) # 用于给test_results的子文件夹命名
  5. torch.cuda.empty_cache() # 清空显存缓冲区

二、TimesNet_M4.sh

作用:便于预设参数,批量执行程序。

主要代码:

  1. export CUDA_VISIBLE_DEVICES=0
  2. # 使用的显卡序号,个人电脑的主卡多为“0”,服务器可以按需选择
  3. model_name=TimesNet # 模型的名字
  4. python -u run.py \
  5. --参数名 参数值 \ # 提前设置各种所需参数
  6. --model $model_name \ # model的名字已经在上面写了
  7. ... \
  8. --参数名 参数值

三、exp_short_term_forecasting.py

作用:短期预测的主要函数。

主要代码:

  1. class Exp_Short_Term_Forecast(Exp_Basic):
  2. # 基于Exp_Basic而新建的类
  3. def __init__(self, args):
  4. # 初始化
  5. def _build_model(self):
  6. # 选择TimesNet模型,基于pytorch的nn.Module写的
  7. def _get_data(self, flag):
  8. # 读取m4数据,基于pytorch的DataLoader写的
  9. def _select_optimizer(self):
  10. # 选择优化器,直接用pytorch的
  11. def _select_criterion(self, loss_name='MSE'):
  12. # 选择评价标准/结束标准,MSE是直接用pytorch的
  13. def train(self, setting):
  14. # 训练模型
  15. def vali(self, train_loader, vali_loader, criterion):
  16. # 验证模型
  17. def test(self, setting, test=0):
  18. # 测试模型

由于vali()仅用了一次,test()和train()相似度比较高,故下文只解释train()。

四、train()

主要代码:

  1. train_data, train_loader = self._get_data(flag='train')
  2. # 读取数据
  3. # 这里不是很理解,在Dataset_M4类中可以发现,他是先把10w条数据全部读进去,再根据seasonal_patterns进行筛选的,反正程序都是按照季节依次执行的,为啥不直接读取季节的csv文件呢?
  4. early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
  5. # 早停,为了避免过拟合
  6. for epoch in range(self.args.train_epochs):
  7. for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
  8. # batch_x, batch_y是从打乱顺序的m4数据中截取的16个长度为24的序列,其中batch_x是前16个,batch_y是后16个,因此batch_x和batch_y有8个是相同的,但截取序列时,只保证batch_x中的数据有意义(非空),batch_y随意,如Q19657中截取的24个数据
  9. # [[ 1380.0000],[ 1350.0000],[ 1330.0000],[ 1320.0000],[ 1300.0000],[ 1300.0000],
  10. # [ 1280.0000],[ 1280.0000],[ 1280.0000],[ 1280.0000],[ 1260.0000],[ 1260.0000],
  11. # [ 1260.0000],[ 1250.0000],[ 1240.0000],[ 1230.0000],[ 1230.0000],[ 0.0000],
  12. # [ 0.0000],[ 0.0000],[ 0.0000],[ 0.0000],[ 0.0000],[ 0.0000]]
  13. model_optim.zero_grad() # 将网络中的梯度置零
  14. # decoder input
  15. dec_inp = torch.xxxxxx
  16. # dec_inp为后8个变为0的batch_y
  17. outputs = self.model(batch_x, None, dec_inp, None) # 运行TimesNet模型
  18. # size为16x16,实际上只有batch_x有用
  19. outputs = outputs[:, -self.args.pred_len:, f_dim:] # outputs的最后8个是预测值
  20. loss_value, loss_sharpness, loss
  21. train_loss.append(loss.item())
  22. # 计算loss
  23. loss.backward()
  24. model_optim.step()
  25. # pytorch的一套组合拳
  26. # optimizer.zero_grad() 清空过往梯度;
  27. # loss.backward() 反向传播,计算当前梯度;
  28. # optimizer.step() 根据梯度更新网络参数
  29. best_model_path = path + '/' + 'checkpoint.pth'
  30. self.model.load_state_dict(torch.load(best_model_path))
  31. # 读取最佳节点来继续运行,免得发生意外,程序重跑
  32. return self.model

五、TimesNet.py

主要是三个部分:Model, TimesBlock和FFT_for_Period。

1.Model

主要代码:

  1. class Model(nn.Module):
  2. def __init__(self, configs):
  3. def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
  4. # Normalization from Non-stationary Transformer
  5. # Quarterly中,一个train的batch_x为[16,16,1],一个test的batch_x为[1,16,1]
  6. means, stdev
  7. x_enc /= stdev
  8. # 对x_enc做Z-Score标准化
  9. # embedding
  10. enc_out = self.enc_embedding(x_enc, x_mark_enc)
  11. # 升维,增加参数,输入[16,16,1],输出[16,16,64]
  12. enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)
  13. # 每个序列都预测了8个值,输出[16,24,64]
  14. # TimesNet
  15. for i in range(self.layer):
  16. # layer=e_layers,Quarterly中e_layers=2,2层TimesBlock处理
  17. enc_out = self.layer_norm(self.model[i](enc_out))
  18. # 跳转至TimesBlock
  19. # 返回结果[16,24,64]
  20. # porject back
  21. dec_out = self.projection(enc_out)
  22. # 将每行64维的数据投影至1维,输出[16,24,1]
  23. # De-Normalization from Non-stationary Transformer
  24. dec_out, stdev, means
  25. # 将结果进行还原,输出[16,24,1]
  26. return dec_out
  27. def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
  28. # 由于模型是基于nn.Module写的,运行模型时默认先执行forward
  29. # 所以用forward选择对应的任务来执行。

2.TimesBlock

主要代码:

  1. class TimesBlock(nn.Module):
  2. def __init__(self, configs):
  3. def forward(self, x):
  4. B, T, N = x.size() # enc_out为[16,24,64]
  5. period_list, period_weight = FFT_for_Period(x, self.k)
  6. # 跳转至FFT_for_Period
  7. # 返回结果5个周期和对应频率的振幅
  8. for i in range(self.k):
  9. # padding
  10. # 根据seq_len+pred_len能否整除period,决定是否补零
  11. # 此处能整除,输出[16,24,64]
  12. # reshape
  13. out = out.reshape
  14. # 2D conv: from 1d Variation to 2d Variation
  15. out = self.conv(out)
  16. # 根据周期,将1维序列变为2维变量
  17. # reshape back
  18. out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
  19. # 将2维变量变回1维序列
  20. res.append(out[:, :(self.seq_len + self.pred_len), :])
  21. res = torch.stack(res, dim=-1)
  22. # 记录残差
  23. # adaptive aggregation
  24. period_weight, res
  25. # 结合振幅计算新的残差
  26. # residual connection
  27. res = res + x
  28. return res

3.FFT_for_Period

主要代码:

  1. def FFT_for_Period(x, k=2):
  2. xf = torch.fft.rfft(x, dim=1) # 傅里叶变换
  3. # find period by amplitudes
  4. frequency_list = abs(xf).mean(0).mean(-1) # 频域求平均振幅
  5. frequency_list[0] = 0 # 第一个频率很高,但没用
  6. _, top_list = torch.topk(frequency_list, k) # 获取最大的k个频率
  7. top_list = top_list.detach().cpu().numpy()
  8. # detach()阻断反传,但数据仍在现存里,cpu无法获取
  9. # cpu()将数据移至cpu,返回值是cpu上的Tensor
  10. # numpy()将cpu上的tensor转为numpy数据,为ndarray类型,返回值为numpy.array()
  11. period = x.shape[1] // top_list # 周期=T/f
  12. return period, abs(xf).mean(-1)[:, top_list] # 返回周期、k个频率的振幅

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/377590
推荐阅读
相关标签
  

闽ICP备14008679号