赞
踩
之前在网上搜索了很多informer的解读文章,但大部分看到的都是文章翻译甚至机翻,看完之后依然对算法原理一头雾水。
Github论文源码
自己看过代码之后想要写一篇详细一点的解读,码在这里也方便自己以后回来翻看。
另外感谢@Error_hxy大佬的帮助!
由于Informer主要是在Transformer上的改进,这里不再赘述Transformer的细节,可以参见另外的博文:
深入理解Transformer及其源码解读
最火的几个全网络预训练模型梳理整合(从ELMO到BERT到XLNET)
那么Informer是做什么的呢?
主要针对长序列预测(Long Sequence Time-series Forecasting, LSTF)
目前Transformre具有较强的捕获长距离依赖的能力,但传统的Transformer依然存在以下不足,因此Informer做出了一些改进。
上面的三个改进猛地一看可能让人摸不着头脑
没关系
我们接着往下看
下面我们对Informer的结构图进行一下简单的解释。
如果看到这里还不太知道Informer是在干啥,那么我们从预处理开始一点一点看起。
上面的
X
t
,
Y
t
X^t,Y^t
Xt,Yt就是模型的输入和输出。
我们以代码中给出的某一个数据为例:
数据介绍:ETDataset(Electricity Transformer Dataset)电力变压器的负荷、油温数据。
ETDataset (github)
(小时维度的数据) 数据规格:17420 * 7
Batch_size:32
下面这张图是ETDataset的一个示例
这里并不是小时维度,而是15min的时间序列数据
那么输入模型的样本大概长什么样子?
X
e
n
c
:
32
×
96
×
7
X_{enc}:32\times96\times7
Xenc:32×96×7
这里的32是批次大小,一个批次有32个样本,一个样本代表96个时间点的数据,如上图
date=2016-07-01 00:00 是一个时间点0的数据
date=2016-07-01 01:00 是时间点1的数据。
那么批次中的样本1:时间点0到时间点95的96个维度为7的数据
批次中的样本2:时间点1到时间点96的96个维度为7的数据
批次中的样本3:时间点2到时间点97的96个维度为7的数据
……
直到取够32个样本,形成一个批次内的所有样本。
X
m
a
r
k
:
32
×
96
×
4
X_{mark}:32\times96\times4
Xmark:32×96×4
这里的4代表时间戳,例如我们用小时维度的数据,那么4分别代表
年、月、日、小时,
第一个时间点对应的时间戳就是[2016, 07, 01, 00],
第二个时间点对应的时间戳就是[2016, 07, 01, 01]
与上面的
X
e
n
c
X_{enc}
Xenc对应得到所有的样本对应的时间戳。
X
e
n
c
:
32
×
72
×
7
X_{enc}:32\times72\times7
Xenc:32×72×7
X
m
a
r
k
:
32
×
72
×
4
X_{mark}:32\times72\times4
Xmark:32×72×4
decoder的输入与encoder唯一不同的就是,每个样本对应时间序列的时间点数量并不是96,而是72。具体在进行截取样本时,从encoder输入的后半段开始取。
即:
encoder的第一个样本:时间点0到时间点95的96条维度为7的数据
那与之对应decoder的:时间点47到时间点95的48条维度为7的数据 + 时间点 95到时间点119的24个时间点的7维数据。
则最终48+24是72维度的数据。画成图大概这个亚子:
上面是encoder的输入
下面是decoder的输入
均只画出时间点数量的那个维度
输入:
x_enc/y_dec
:32 * 96/72 * 7
x_mark /y_mark:
32 * 96/72 *4
输出:
32 * 96/72 * 512
embedding的目的: 嵌入投影,在这里相当于将维度为7的一个时间点的数据投影成维度为512的数据。
整体公式:
等号右边显然可以分成三个部分:
1.输入:对应公式中的
α
u
i
t
\alpha u_i^t
αuit
具体操作为通过conv1D(width=3,stride=1) 映射为D维度(512)
(conv1D是一维卷积)
2.Position Stamp: 对应公式中的
P
E
(
L
x
×
(
t
−
1
)
+
i
)
PE_{(L_x\times(t-1)+i)}
PE(Lx×(t−1)+i)
与Transformer的postition embedding 一模一样
3.Global Time Stamp:对应公式中的 SE
具体操作为全连接层。
长序列预测问题中,需要获取全局信息,比如分层次的时间戳(week, month year),不可知的时间戳(holidays, events).
这里最细的粒度可以自行选择。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。