当前位置:   article > 正文

论文辅助笔记:Tempo 之 model.py

论文辅助笔记:Tempo 之 model.py

 0 导入库

  1. import math
  2. from dataclasses import dataclass, asdict
  3. import torch
  4. import torch.nn as nn
  5. from src.modules.transformer import Block
  6. from src.modules.prompt import Prompt
  7. from src.modules.utils import (
  8. FlattenHead,
  9. PoolingHead,
  10. RevIN,
  11. )

1TEMPOConfig

1.1 构造函数

  1. class TEMPOConfig:
  2. """
  3. Configuration of a `TEMPO` model.
  4. Args:
  5. num_series: 时间序列的数量, N
  6. input_len: 输入时间序列的长度, L
  7. pred_len: 预测时间序列的长度, Y
  8. block_size: 块的最大长度(openai gpt2 固定)
  9. n_layer: Transformer 层的数量
  10. n_head: 多头注意力机制中的头数量
  11. n_embd: 嵌入维度的数量
  12. patch_size: 块的大小,用于将输入时间序列分割成多个小块
  13. patch_stride: 块的步幅,用于指定块之间的重叠程度
  14. revin: 是否使用 RevIN(归一化和逆变换)
  15. affine: 在 RevIN 中是否使用仿射变换
  16. embd_pdrop:嵌入层的 dropout 率
  17. resid_pdrop: 残差连接的 dropout 率
  18. attn_pdrop: 注意力层的 dropout 率
  19. head_type: 输出层的类型,可以是 FlattenHead 或 PoolingHead
  20. head_pdtop: 输出层的 dropout 率
  21. individual: 是否为每个组件使用独立的输出层
  22. lora: 是否使用 LoRA(低秩近似)
  23. lora_config: LoRA 的配置
  24. model_type: 模型类型,默认为 gpt2
  25. interpret: 是否输出组件以便解释
  26. """
  27. num_series: int
  28. input_len: int
  29. pred_len: int
  30. patch_size: int
  31. patch_stride: int
  32. block_size: int = None
  33. n_layer: int = None
  34. n_head: int = None
  35. n_embd: int = None
  36. revin: bool = True
  37. affine: bool = True
  38. embd_pdrop: float = 0.1
  39. resid_pdrop: float = 0.1
  40. attn_pdrop: float = 0.1
  41. head_type: str = "flatten"
  42. head_pdtop: float = 0.1
  43. individual: bool = False
  44. lora: bool = False
  45. lora_config: dict = None
  46. prompt_config: dict = None
  47. #Prompt 模块的配置
  48. model_type: str = "gpt2"
  49. interpret: bool = False

1.2  todict

TEMPOConfig 类实例转换为一个字典

  1. def todict(self):
  2. return asdict(self)
  3. '''
  4. asdict 是 Python 的 dataclasses 模块提供的一个函数,用于将数据类实例转换为字典。
  5. 这个方法将当前实例的所有属性转换为字典键值对,并返回这个字典。
  6. '''

1.3 __contains__

重载了 Python 的 __contains__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样使用 in 操作符来检查属性是否存在。

  1. def __contains__(self, key):
  2. return key in self.todict()

1.4 __getitem__

重载了 __getitem__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样通过键来获取属性值

  1. def __getitem__(self, key):
  2. return getattr(self, key)

1.5__setitem__

重载了 __setitem__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样通过键来设置属性值

  1. def __setitem__(self, key, value):
  2. setattr(self, key, value)

1.6 update

通过一个字典 config 更新 TEMPOConfig 实例的属性

  1. def update(self, config: dict):
  2. for k, v in config.items():
  3. setattr(self, k, v)

2 TEMPO

  1. class TEMPO(nn.Module):
  2. """
  3. Notation:
  4. B: 批次大小
  5. N: 时间序列的数量
  6. E: 嵌入维度
  7. P: 块的数量
  8. PS: patch的大小
  9. L: 输入时间序列的长度
  10. Y: 预测时间序列的长度
  11. """
  12. models = ("gpt2",)
  13. #支持的模型类型列表
  14. head_types = ("flatten", "pooling")
  15. #支持的输出层类型
  16. params = {
  17. "gpt2": dict(block_size=1024, n_head=12, n_embd=768),
  18. }
  19. '''
  20. 模型的参数,例如 "gpt2" 模型的块大小、注意力头数和嵌入维度等
  21. '''

2.1 __init__

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/541233
推荐阅读
相关标签
  

闽ICP备14008679号