当前位置:   article > 正文

论文辅助笔记:TEMPO 之 utils.py

论文辅助笔记:TEMPO 之 utils.py

0 导入库

  1. from typing import Tuple
  2. import random
  3. import numpy as np
  4. import torch
  5. from statsmodels.tsa.seasonal import STL

1 EarlyStopping

  • 提供了一个早停机制,用于在模型训练过程中监控验证集上的损失
  • 如果损失停止改进,则停止训练

1.1 __init__

  1. class EarlyStopping:
  2. def __init__(self, patience=7, verbose=False, delta=0):
  3. self.patience = patience
  4. #早停的容忍度,如果连续 patience 次验证损失没有改善,则停止训练。
  5. self.verbose = verbose
  6. #决定是否输出详细信息
  7. self.counter = 0
  8. #记录连续未改善验证损失的次数
  9. self.best_score = None
  10. #用于存储目前为止最佳的验证损失分数
  11. self.early_stop = False
  12. #一个布尔值,指示是否应该停止训练
  13. self.val_loss_min = np.Inf
  14. #存储目前为止最小的验证损失
  15. self.delta = delta
  16. #一个阈值,用于决定损失的改善幅度

1.2 __call__ 在训练过程中监控验证损失

  1. def __call__(self, val_loss, model, path):
  2. score = -val_loss
  3. if self.best_score is None:
  4. self.best_score = score
  5. self.save_checkpoint(val_loss, model, path)
  6. #如果这是第一次调用 __call__,初始化 best_score 为 score 并保存模型。
  7. elif score < self.best_score + self.delta:
  8. self.counter += 1
  9. print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
  10. if self.counter >= self.patience:
  11. self.early_stop = True
  12. '''
  13. 如果 score < self.best_score + self.delta,则说明损失没有显著改善
  14. 增加 counter 并检查是否超过 patience,如果超过则停止训练
  15. '''
  16. else:
  17. self.best_score = score
  18. self.save_checkpoint(val_loss, model, path)
  19. self.counter = 0
  20. '''
  21. 如果 score > self.best_score + self.delta,更新 best_score 并保存模型
  22. 然后将 counter 重置为零
  23. '''

1.3 save_checkpoint 在验证损失降低时保存模型

  1. def save_checkpoint(self, val_loss, model, path):
  2. if self.verbose:
  3. print(
  4. f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
  5. )
  6. torch.save(model.state_dict(), path + "/" + "checkpoint.pth")
  7. #使用 torch.save() 保存模型的状态字典
  8. self.val_loss_min = val_loss

2 StandardScaler

实现数据标准化

2.1 __init__

  1. class StandardScaler:
  2.     def __init__(self):
  3.         self.mean = 0.0
  4.         self.std = 1.0

2.2  fit

计算并更新 self.meanself.std

  1. def fit(self, data):
  2.         self.mean = data.mean(0)
  3.         self.std = data.std(0)

 2.3  transform

   将数据转换为标准化形式

  1. def transform(self, data):
  2.         mean = (
  3.             torch.from_numpy(self.mean).type_as(data).to(data.device)
  4.             if torch.is_tensor(data)
  5.             else self.mean
  6.         )
  7.         std = (
  8.             torch.from_numpy(self.std).type_as(data).to(data.device)
  9.             if torch.is_tensor(data)
  10.             else self.std
  11.         )
  12. '''
  13. mean 和 std 的类型转换:
  14. 根据 data 是 torch.Tensor 还是 numpy 数组
  15. 将 self.mean 和 self.std 转换为相应类型,以确保类型匹配
  16. '''
  17.         return (data - mean) / std

 2.4 inverse_transform

标准化后的数据还原

  1.     def inverse_transform(self, data):
  2.         mean = (
  3.             torch.from_numpy(self.mean).type_as(data).to(data.device)
  4.             if torch.is_tensor(data)
  5.             else self.mean
  6.         )
  7.         std = (
  8.             torch.from_numpy(self.std).type_as(data).to(data.device)
  9.             if torch.is_tensor(data)
  10.             else self.std
  11.         )
  12.         '''
  13.         mean 和 std 的类型转换:
  14.             根据 data 是 torch.Tensor 还是 numpy 数组
  15.             将 self.mean 和 self.std 转换为相应类型,以确保类型匹配
  16.         '''
  17.         if data.shape[-1] != mean.shape[-1]:
  18.             mean = mean[-1:]
  19.             std = std[-1:]
  20.         return (data * std) + mean
  21. '''
  22. 通过 (data * std) + mean 将标准化后的数据还原为原始形式
  23. '''

3 decompose

使用STL,将时间序列分解为趋势、季节性和残差成分

  1. def decompose(
  2. x: torch.Tensor, period: int = 7
  3. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  4. #x:输入的一维时间序列,类型为 torch.Tensor,形状为 (1, seq_len)
  5. x = x.squeeze(0).cpu().numpy()
  6. '''
  7. 首先调用 squeeze(0) 将 x 的第一个维度去掉
  8. 然后通过 cpu().numpy() 将 x 转换为 numpy 数组,以便 STL 分解函数使用
  9. '''
  10. decomposed = STL(x, period=period).fit()
  11. '''
  12. 调用 STL(x, period=period).fit() 对 x 进行分解,并返回分解结果 decomposed
  13. 其中包含了 trend(趋势)、seasonal(季节性)和 resid(残差)成分
  14. '''
  15. trend = decomposed.trend.astype(np.float32)
  16. seasonal = decomposed.seasonal.astype(np.float32)
  17. residual = decomposed.resid.astype(np.float32)
  18. '''
  19. 将 decomposed 中的各个成分转换为 numpy 数组,并转为 float32 类型
  20. '''
  21. return (
  22. torch.from_numpy(trend).unsqueeze(0),
  23. torch.from_numpy(seasonal).unsqueeze(0),
  24. torch.from_numpy(residual).unsqueeze(0),
  25. )
  26. '''
  27. 将它们转换为 torch.Tensor
  28. 并使用 unsqueeze(0) 将其包装为 (1, seq_len) 的张量,以匹配输入张量的形状
  29. '''

4 set_seed

为 Python 中的各种随机生成器设置种子

  1. def set_seed(seed):
  2. random.seed(seed)
  3. np.random.seed(seed)
  4. torch.manual_seed(seed)
  5. torch.cuda.manual_seed_all(seed)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/541231
推荐阅读
相关标签
  

闽ICP备14008679号