当前位置:   article > 正文

论文辅助笔记:TEMPO 之 dataset.py

论文辅助笔记:TEMPO 之 dataset.py

0 导入库

  1. import os
  2. import pandas as pd
  3. import torch
  4. from torch.utils.data import Dataset
  5. from .utils import StandardScaler, decompose
  6. from .features import time_features

1 Dataset_ETT_hour

1.1 构造函数

  1. class Dataset_ETT_hour(Dataset):
  2. def __init__(
  3. self,
  4. root_path,
  5. flag="train",
  6. size=None,
  7. features="S",
  8. data_path="ETTh1.csv",
  9. target="OT",
  10. scale=True,
  11. inverse=False,
  12. timeenc=0,
  13. freq="h",
  14. cols=None,
  15. period=24,
  16. ):
  17. if size == None:
  18. self.seq_len = 24 * 4 * 4
  19. self.pred_len = 24 * 4
  20. else:
  21. self.seq_len = size[0]
  22. self.pred_len = size[1]
  23. #输入sequence和输出sequence的长度
  24. assert flag in ["train", "test", "val"]
  25. type_map = {"train": 0, "val": 1, "test": 2}
  26. self.set_type = type_map[flag]
  27. '''
  28. 指定数据集的用途,可以是 "train"、"test" 或 "val",分别对应训练集、测试集和验证集
  29. '''
  30. self.features = features
  31. #指定数据集包含的特征类型,默认为 "S",表示单一特征
  32. self.target = target
  33. #指定预测的目标特征
  34. self.scale = scale
  35. #一个布尔值,用于确定数据是否需要归一化处理
  36. self.inverse = inverse
  37. #一个布尔值,用于决定是否进行逆变换
  38. self.timeenc = timeenc
  39. #用于确定是否对时间进行编码【原始模样 or -0.5~0.5区间】
  40. self.freq = freq
  41. #定义时间序列的频率,如 "h" 表示小时级别的频率
  42. self.period = period
  43. #定义时间序列的周期,默认为 24
  44. self.root_path = root_path
  45. self.data_path = data_path
  46. self.__read_data__()
  47. #用于读取并初始化数据集

1.2 __read_data__

  1. def __read_data__(self):
  2. self.scaler = StandardScaler()
  3. #初始化一个 StandardScaler 对象,用于数据的标准化处理
  4. df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
  5. #读取数据集文件,将其存储为 DataFrame 对象 df_raw
  6. border1s = [
  7. 0,
  8. 12 * 30 * 24 - self.seq_len,
  9. 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len,
  10. ]
  11. #定义了三个区间的起始位置,分别对应训练集、验证集和测试集
  12. border2s = [
  13. 12 * 30 * 24,
  14. 12 * 30 * 24 + 4 * 30 * 24,
  15. 12 * 30 * 24 + 8 * 30 * 24,
  16. ]
  17. #定义了每个区间的结束位置
  18. border1 = border1s[self.set_type]
  19. border2 = border2s[self.set_type]
  20. '''
  21. 通过 self.set_type 确定当前数据集类型
  22. 并从 border1s 和 border2s 中获取对应的起始和结束位置 border1 和 border2
  23. '''
  24. if self.features == "M" or self.features == "MS":
  25. cols_data = df_raw.columns[1:]
  26. df_data = df_raw[cols_data]
  27. elif self.features == "S":
  28. df_data = df_raw[[self.target]]
  29. '''
  30. 选择特征数据:
  31. 多特征 "M" 或 "MS":选择所有数据列,除去日期列。
  32. 单一特征 "S":只选择目标特征列(由 self.target 指定)。
  33. '''
  34. if self.scale:
  35. train_data = df_data[border1s[0] : border2s[0]]
  36. self.scaler.fit(train_data.values)
  37. data = self.scaler.transform(df_data.values)
  38. else:
  39. data = df_data.values
  40. '''
  41. 如果 self.scale 为 True,则执行数据归一化:
  42. train_data:选择训练集的数据,用于拟合 self.scaler。
  43. data:对整个 df_data 进行转换。
  44. '''
  45. df_stamp = df_raw[["date"]][border1:border2]
  46. df_stamp["date"] = pd.to_datetime(df_stamp.date)
  47. data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq)
  48. '''
  49. 时间特征处理:
  50. 提取日期列 df_stamp,并将其转换为时间特征:
  51. pd.to_datetime:将日期转换为 datetime 对象。
  52. time_features:用于生成时间特征。
  53. '''
  54. self.data_x = data[border1:border2]
  55. if self.inverse:
  56. self.data_y = df_data.values[border1:border2]
  57. else:
  58. self.data_y = data[border1:border2]
  59. self.data_stamp = data_stamp
  60. '''
  61. 将转换后的数据和时间特征赋值给 self.data_x、self.data_y 和 self.data_stamp:
  62. self.data_x 取 data 中的对应区间数据。
  63. self.data_y 根据 self.inverse 决定是从 data 还是 df_data 中获取。
  64. self.data_stamp 取生成的时间特征。
  65. '''

1.3 __getitem__

  1. def __getitem__(self, index):
  2.         s_begin = index
  3. #设置序列的起始点
  4.         s_end = s_begin + self.seq_len
  5. #计算序列的结束点
  6.         r_begin = s_end
  7. #设置预测序列的起始点
  8.         r_end = r_begin + self.pred_len
  9. #计算预测序列的结束点
  10.         seq_x = self.data_x[s_begin:s_end]
  11. #从 data_x 中提取序列部分
  12.         seq_y = self.data_y[r_begin:r_end]
  13. # 从 data_y 中提取预测部分[ground-truth]
  14.         x = torch.tensor(seq_x, dtype=torch.float).transpose(1, 0)  # [1, seq_len]
  15.         y = torch.tensor(seq_y, dtype=torch.float).transpose(1, 0)  # [1, pred_len]
  16.         (trend, seasonal, residual) = decompose(x, period=self.period)
  17. #对序列 x 进行时间序列分解,返回趋势、季节性和残差三部分
  18.         components = torch.cat((trend, seasonal, residual), dim=0)  # [3, seq_len]
  19. #将分解后的三部分按 0 维(纵向)拼接,形成一个包含三种特征的张量
  20.         return components, y

1.3__len__

  1.     def __len__(self):
  2.         return len(self.data_x) - self.seq_len - self.pred_len + 1

1.4  inverse_transform

将数据进行逆转换,还原到原始尺度

  1.     def inverse_transform(self, data):
  2.         return self.scaler.inverse_transform(data)

2 Dataset_ETT_minute

基本上和hour 的一样,几个地方不一样:

  • __init__
    • data_path="ETTm1.csv",
    • freq="t",
    • period: int = 60,
  • __read_data__
      1. border1s = [
      2. 0,
      3. 12 * 30 * 24 * 4 - self.seq_len,
      4. 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len,
      5. ]
      6. border2s = [
      7. 12 * 30 * 24 * 4,
      8. 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4,
      9. 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4,
      10. ]

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/541235
推荐阅读
相关标签
  

闽ICP备14008679号