赞
踩
目录
三、exp_short_term_forecasting.py
果然是初入机器学习的新手,对pycharm、pytorch的套路了解得太少,在学习之路上犯了不少错误,走了不少弯路,虽然现在依旧是个新人,但也还是来做个阶段性的总结,也算是成长的证明。还是以TimesNet为例,下面用基于m4数据集(quarterly类别)的short term forecasting程序来做说明。
主程序,主要是两个部分:args和train, test。
作用:导入基本参数。
主要代码:
- import argparse
- parser = argparse.ArgumentParser(description='TimesNet')
- parser.add_argument('--属性名', type=类型, default=默认值, help='说明')
- args = parser.parse_args()
-
- # parser.add_argument中还有一些用的少的参数required, action, nargs
作用:开始训练、测试模型。
主要代码:
- setting = '{}_{}_..._{}_{}'.format(args.属性, ..., args.属性)
- exp = Exp(args) # 把args传递给exp_short_term_forecasting
- exp.train(setting) # 用于给checkpoints命名
- exp.test(setting) # 用于给test_results的子文件夹命名
- torch.cuda.empty_cache() # 清空显存缓冲区
作用:便于预设参数,批量执行程序。
主要代码:
- export CUDA_VISIBLE_DEVICES=0
- # 使用的显卡序号,个人电脑的主卡多为“0”,服务器可以按需选择
-
- model_name=TimesNet # 模型的名字
-
- python -u run.py \
- --参数名 参数值 \ # 提前设置各种所需参数
- --model $model_name \ # model的名字已经在上面写了
- ... \
- --参数名 参数值
作用:短期预测的主要函数。
主要代码:
- class Exp_Short_Term_Forecast(Exp_Basic):
- # 基于Exp_Basic而新建的类
-
- def __init__(self, args):
- # 初始化
-
- def _build_model(self):
- # 选择TimesNet模型,基于pytorch的nn.Module写的
-
- def _get_data(self, flag):
- # 读取m4数据,基于pytorch的DataLoader写的
-
- def _select_optimizer(self):
- # 选择优化器,直接用pytorch的
-
- def _select_criterion(self, loss_name='MSE'):
- # 选择评价标准/结束标准,MSE是直接用pytorch的
-
- def train(self, setting):
- # 训练模型
-
- def vali(self, train_loader, vali_loader, criterion):
- # 验证模型
-
- def test(self, setting, test=0):
- # 测试模型
由于vali()仅用了一次,test()和train()相似度比较高,故下文只解释train()。
主要代码:
- train_data, train_loader = self._get_data(flag='train')
- # 读取数据
- # 这里不是很理解,在Dataset_M4类中可以发现,他是先把10w条数据全部读进去,再根据seasonal_patterns进行筛选的,反正程序都是按照季节依次执行的,为啥不直接读取季节的csv文件呢?
-
- early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
- # 早停,为了避免过拟合
-
- for epoch in range(self.args.train_epochs):
- for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
- # batch_x, batch_y是从打乱顺序的m4数据中截取的16个长度为24的序列,其中batch_x是前16个,batch_y是后16个,因此batch_x和batch_y有8个是相同的,但截取序列时,只保证batch_x中的数据有意义(非空),batch_y随意,如Q19657中截取的24个数据
- # [[ 1380.0000],[ 1350.0000],[ 1330.0000],[ 1320.0000],[ 1300.0000],[ 1300.0000],
- # [ 1280.0000],[ 1280.0000],[ 1280.0000],[ 1280.0000],[ 1260.0000],[ 1260.0000],
- # [ 1260.0000],[ 1250.0000],[ 1240.0000],[ 1230.0000],[ 1230.0000],[ 0.0000],
- # [ 0.0000],[ 0.0000],[ 0.0000],[ 0.0000],[ 0.0000],[ 0.0000]]
-
- model_optim.zero_grad() # 将网络中的梯度置零
-
- # decoder input
- dec_inp = torch.xxxxxx
- # dec_inp为后8个变为0的batch_y
-
- outputs = self.model(batch_x, None, dec_inp, None) # 运行TimesNet模型
- # size为16x16,实际上只有batch_x有用
-
- outputs = outputs[:, -self.args.pred_len:, f_dim:] # outputs的最后8个是预测值
- loss_value, loss_sharpness, loss
- train_loss.append(loss.item())
- # 计算loss
-
- loss.backward()
- model_optim.step()
- # pytorch的一套组合拳
- # optimizer.zero_grad() 清空过往梯度;
- # loss.backward() 反向传播,计算当前梯度;
- # optimizer.step() 根据梯度更新网络参数
-
- best_model_path = path + '/' + 'checkpoint.pth'
- self.model.load_state_dict(torch.load(best_model_path))
- # 读取最佳节点来继续运行,免得发生意外,程序重跑
-
- return self.model
主要是三个部分:Model, TimesBlock和FFT_for_Period。
主要代码:
- class Model(nn.Module):
- def __init__(self, configs):
-
- def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
- # Normalization from Non-stationary Transformer
- # Quarterly中,一个train的batch_x为[16,16,1],一个test的batch_x为[1,16,1]
-
- means, stdev
- x_enc /= stdev
- # 对x_enc做Z-Score标准化
-
- # embedding
- enc_out = self.enc_embedding(x_enc, x_mark_enc)
- # 升维,增加参数,输入[16,16,1],输出[16,16,64]
-
- enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)
- # 每个序列都预测了8个值,输出[16,24,64]
-
- # TimesNet
- for i in range(self.layer):
- # layer=e_layers,Quarterly中e_layers=2,2层TimesBlock处理
-
- enc_out = self.layer_norm(self.model[i](enc_out))
- # 跳转至TimesBlock
- # 返回结果[16,24,64]
-
- # porject back
- dec_out = self.projection(enc_out)
- # 将每行64维的数据投影至1维,输出[16,24,1]
-
- # De-Normalization from Non-stationary Transformer
- dec_out, stdev, means
- # 将结果进行还原,输出[16,24,1]
-
- return dec_out
-
- def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
- # 由于模型是基于nn.Module写的,运行模型时默认先执行forward
- # 所以用forward选择对应的任务来执行。
主要代码:
- class TimesBlock(nn.Module):
- def __init__(self, configs):
-
- def forward(self, x):
- B, T, N = x.size() # enc_out为[16,24,64]
- period_list, period_weight = FFT_for_Period(x, self.k)
- # 跳转至FFT_for_Period
- # 返回结果5个周期和对应频率的振幅
-
- for i in range(self.k):
- # padding
- # 根据seq_len+pred_len能否整除period,决定是否补零
- # 此处能整除,输出[16,24,64]
-
- # reshape
- out = out.reshape
- # 2D conv: from 1d Variation to 2d Variation
- out = self.conv(out)
- # 根据周期,将1维序列变为2维变量
-
- # reshape back
- out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
- # 将2维变量变回1维序列
-
- res.append(out[:, :(self.seq_len + self.pred_len), :])
- res = torch.stack(res, dim=-1)
- # 记录残差
-
- # adaptive aggregation
- period_weight, res
- # 结合振幅计算新的残差
-
- # residual connection
- res = res + x
- return res
主要代码:
- def FFT_for_Period(x, k=2):
- xf = torch.fft.rfft(x, dim=1) # 傅里叶变换
-
- # find period by amplitudes
- frequency_list = abs(xf).mean(0).mean(-1) # 频域求平均振幅
- frequency_list[0] = 0 # 第一个频率很高,但没用
-
- _, top_list = torch.topk(frequency_list, k) # 获取最大的k个频率
- top_list = top_list.detach().cpu().numpy()
- # detach()阻断反传,但数据仍在现存里,cpu无法获取
- # cpu()将数据移至cpu,返回值是cpu上的Tensor
- # numpy()将cpu上的tensor转为numpy数据,为ndarray类型,返回值为numpy.array()
-
- period = x.shape[1] // top_list # 周期=T/f
- return period, abs(xf).mean(-1)[:, top_list] # 返回周期、k个频率的振幅
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。