赞
踩
- import math
- from dataclasses import dataclass, asdict
-
- import torch
- import torch.nn as nn
-
- from src.modules.transformer import Block
- from src.modules.prompt import Prompt
- from src.modules.utils import (
- FlattenHead,
- PoolingHead,
- RevIN,
- )
-
-
- class TEMPOConfig:
- """
- Configuration of a `TEMPO` model.
- Args:
- num_series: 时间序列的数量, N
- input_len: 输入时间序列的长度, L
- pred_len: 预测时间序列的长度, Y
- block_size: 块的最大长度(openai gpt2 固定)
- n_layer: Transformer 层的数量
- n_head: 多头注意力机制中的头数量
- n_embd: 嵌入维度的数量
- patch_size: 块的大小,用于将输入时间序列分割成多个小块
- patch_stride: 块的步幅,用于指定块之间的重叠程度
- revin: 是否使用 RevIN(归一化和逆变换)
- affine: 在 RevIN 中是否使用仿射变换
- embd_pdrop:嵌入层的 dropout 率
- resid_pdrop: 残差连接的 dropout 率
- attn_pdrop: 注意力层的 dropout 率
- head_type: 输出层的类型,可以是 FlattenHead 或 PoolingHead
- head_pdtop: 输出层的 dropout 率
- individual: 是否为每个组件使用独立的输出层
- lora: 是否使用 LoRA(低秩近似)
- lora_config: LoRA 的配置
- model_type: 模型类型,默认为 gpt2
- interpret: 是否输出组件以便解释
- """
-
- num_series: int
- input_len: int
- pred_len: int
- patch_size: int
- patch_stride: int
- block_size: int = None
- n_layer: int = None
- n_head: int = None
- n_embd: int = None
- revin: bool = True
- affine: bool = True
- embd_pdrop: float = 0.1
- resid_pdrop: float = 0.1
- attn_pdrop: float = 0.1
- head_type: str = "flatten"
- head_pdtop: float = 0.1
- individual: bool = False
- lora: bool = False
- lora_config: dict = None
- prompt_config: dict = None
- #Prompt 模块的配置
- model_type: str = "gpt2"
- interpret: bool = False
todict
TEMPOConfig
类实例转换为一个字典
- def todict(self):
- return asdict(self)
-
- '''
- asdict 是 Python 的 dataclasses 模块提供的一个函数,用于将数据类实例转换为字典。
- 这个方法将当前实例的所有属性转换为字典键值对,并返回这个字典。
- '''
__contains__
重载了 Python 的 __contains__
魔术方法,使得 TEMPOConfig
实例可以像字典一样使用 in
操作符来检查属性是否存在。
- def __contains__(self, key):
- return key in self.todict()
__getitem__
重载了 __getitem__
魔术方法,使得 TEMPOConfig
实例可以像字典一样通过键来获取属性值
- def __getitem__(self, key):
- return getattr(self, key)
__setitem__
重载了 __setitem__
魔术方法,使得 TEMPOConfig
实例可以像字典一样通过键来设置属性值
- def __setitem__(self, key, value):
- setattr(self, key, value)
update
通过一个字典 config
更新 TEMPOConfig
实例的属性
- def update(self, config: dict):
- for k, v in config.items():
- setattr(self, k, v)
- class TEMPO(nn.Module):
- """
- Notation:
- B: 批次大小
- N: 时间序列的数量
- E: 嵌入维度
- P: 块的数量
- PS: patch的大小
- L: 输入时间序列的长度
- Y: 预测时间序列的长度
- """
-
- models = ("gpt2",)
- #支持的模型类型列表
-
- head_types = ("flatten", "pooling")
- #支持的输出层类型
-
- params = {
- "gpt2": dict(block_size=1024, n_head=12, n_embd=768),
- }
- '''
- 模型的参数,例如 "gpt2" 模型的块大小、注意力头数和嵌入维度等
- '''
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。