当前位置:   article > 正文

Informer代码详解!

informer csdn

作者:超爱学习的小豆,文章摘自圆圆的算法笔记

  • 知乎链接:

    • https://zhuanlan.zhihu.com/p/646853438

1

前言

Informer是2021年时序领域的best论文,仔细阅读文章和代码会发现这篇文章的思路、论点和代码框架写的是真的好,让人看着赏心悦目,此后的时序预测算法也大多是在Informer的基础上进行的,包括输入输出的格式、特征编码的方式等。

2

论文出发点和创新点

8e54765ff2a39a628083a4721d64396d.png

图1 Informer整体框架图

出发点

Transformer中的自注意力计算是平方复杂度d9bb53cf330d7304d490144b281677e2.png

传统Transformer的Block输入输出的shape不变是柱状形式,J个Block带来的复杂度就是 f7c70906483dfd103a56610fdfcc0372.png,导致模型的输入无法变的过长,限制了时序模型的可扩展性

传统Transformer的Decoder阶段输出是step-by-step,一方面增加了耗时,另一方面也会给模型带来累计误差

创新点

提出了ProbSparse self-attention mechanism的注意力方法,在耗时和内存上都压缩到了 19e1cb6ef9519964e6b7ac3551f5c93e.png的复杂度,就是图1中的1

每个注意力Block之间都添加了“蒸馏”操作,通过将序列的shape减半来突出主要注意力,原始的柱状Transformer变成金字塔形的Transformer,使得模型可以接受更长的序列输入,并且可以降低内存和时间损耗,就是图1中的2。

设计了一个较为简单但是可以一次性输出预测值的Decoder,就是图1中的3和输出那一段黄色的部分。

3

论文细节

作者发现在原始的Transformer中的自注意力机制中,注意力分数呈现长尾分布,也就是只有少部分的点是和别的点直接有强相关性的,如图2

c0db61513259b9628954241ae45cdbaf.png

图2 原始Transformer中不同头中的注意力分数分布

所以在Transformer中如果可以在计算Attention过程中删除那些没有用的query,就可以降低计算量,相关工作也说明了删去部分无用的query不会带来精度的损失,如图3:

e4e6ef76d5067741ad16945f2d75b8b8.png

图3 不少工作说明减少部分点计算注意力不会有精度损失

从而作者要做的就是如何去定义没用的query(称为"Lazy query")以及如何找到这些query。

如何定义和寻找"Lazy query"

作者提出,更加重要的query的分布和均匀分布的差距应该是更大的,所以对每个query求出其概率分布后与均匀分布之间求出的KL散度(常用于衡量两个分布直接的相关性)就可以作为query的重要性,如图4。

8cd9a0cca1eb98b218c87ffeef22a362.png

图4

考虑到如果按照公式去计算各个query和均匀分布的KL散度,这没算注意力之前就要进行大量的计算,无疑是不可行的,作者对这部分计算公式进行了放缩、简化(有兴趣可以看原论文附录里面的证明,这里不细说,会在代码部分详细说下作者怎么做的),图5就是最后给每个query的一个"活跃"分数,这个分数越低则越可能是"Lazy"query(这个公式里很多计算是可以复用的,所以相对而言降低了计算量)。

aee02a7a6bbe400533306c7b09bc7db1.png

图5 query的活跃公式

4

代码解析

代码执行过程如下:

1. 随机采样部分的key,默认为a341894cb4de240c1abcee4f0e556648.png量纲(L为序列长度)

2. 根据采样的部分key,让每个query和key作用,计算每个query的活跃得分d35aaafef41ab30029ef156b22d897aa.png

3. 选择活跃得分较高的N个query,N默认为 7b3652746ea33415c5cd88c1a879b464.png量纲

4. 用这个N个query去和所有的key计算Attention

5. 其余的L-N个query不参与计算,直接取均值,保证输入和输出的shape不变

我们要定位到ProbAttention这个类

15ace402fd4c2b72ba4841fe8f2c34da.png

图6

如图6,最开始做了一些初始化操作,U_part和u分别对应我上面说的步骤中采样key和采样query的个数,可以看到有factor这个参数控制,采样的个数为factor * In L。

接着调用函数_prob_QK得到最活跃的query和全部key的点积结果(Attention的公式是先点击之后除以根号d做归一化,基础)以及这些query对应的索引(看图7,做了很详细的注释)

  1. scores_top, index = self._prob_QK(
  2. queries, keys, sample_k=U_part, n_top=u)

be94165798e9d0dd36cc1fe1dec5e078.png

图7 概率稀疏注意力计算过程

回到Forward里来,接下来就是对上面计算的点积进行归一化

  1. scale = self.scale or 1. / sqrt(D)
  2. if scale is not None:
  3. scores_top = scores_top * scale # attention公式中除以分母归一化

之后进入到函数_get_initial_context中,计算每个query的均值,相当于最开始所有的query和所有的key之间的Attention分数都用均值初始化,之后再根据上面的分数和采样的query的索引将里面活跃的query对应的Attention分数进行替换

  1. # get the context
  2. context = self._get_initial_context(values, L_Q) # (batch_size, n_heads, seq_len_q, dim_qkv)

acb645d71e82fe71cae08814fb476411.png

图7 用均值初始化所有query和key的Attention分数

之后就是使用_update_context函数,把活跃的query部分的Attention分数替换掉(看图8,注释了)

  1. # update the context with selected top_k queries
  2. context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

9bea74cbd1d56d19388815a159a4567a.png

图8

如图9,Informer的Encoder由e_layers个EncoderLayer、e_layers-1个ConvLayer(下采样的)和一个LayerNorm组成,而每个EncoderLayer是由概率稀疏注意力组成的AttentionLayer组成,整体结构非常简单,代码写的很好看。

efd7033738a813a28d2a93a43f138a2b.png

图9 Informer的Encoder

概率稀疏注意力部分上面说过了,可以看看这里的ConvLayer,它由参数distil决定,如果这个参数设置为False,则整个过程就没有下采样。我们看下这个类的代码,本质上就是个最大池化的下采样,把seq_len那个维度减半了

fd6759246a20c8648574ba554258ebfd.png

图10 Informer中的下采样操作

图11就是Encoder的整体结构,对着代码看很清晰

b3d44ba4dda7471fa80d01b087d36161.png

图11

如图12,Informer中的Decoder由d_layers个DecoderLayer、一个LayerNorm和一个Linear组成(负责转换到pred_len维度),DecoderLayer中第一个AttentionLayer是概率稀疏自注意力,第二个AttentionLayer对应的就是交叉的全注意力(Encoder的输出作为key和value,上一层Decoder的输出作为query)。

417cd5a436243f86e33cc3cc29dd7297.png

图12

图13中给出的是Decoder的结构

c2f03bbc213639888d5806350593a9c2.png

图13 Informer中Decoder的结构

可以看到,Decoder的输入有两部分组成,图中第一部分对应的长度是label_len,代表的是Encoder中后label_lend数据;第二部分的长度是pred_len,可以看到Decoder输入的时候,这部分的数据是用0占位的,我们可以看图14中在如何产生Decoder的数据的:

0152ccf18f5ab485ff638b7f971c29e8.png

图14

我们看看整体数据经过Informer的过程,如图15,每个过程我加了shape的注释

5940686974437387c709ed9d509b9bb1.png

图15

可以发现Encoder和Decoder的输入都会经过一个embedding,这个embedding包含三个部分,输入特征序列的Token Embedding+保证时间序列有序性的position_embedding,以及每个时间点的temporal_embedding(比如今天是7月30号,这个日期包含信息一个月的第几天,周几,是否月末... 作者将这些信息进行了编码,后续的时序预测基本都用了类似的编码)。如图16,加了注释

b62c5481764a4a16ed9dbbd6a5647cfd.png

图16

5

个人使用感受

时间复杂度:Informer跑起来还是很快的,和后续几个论文中自称更快的模型对比来说,Informer的效率是最高的主要还是得益于它计算Attention的时候放弃了一堆query且有个下采样的过程

Informer我最开始使用的时候,效果死活很垃圾,之后我把中间的下采样删去了,效果一下子变好了。分析因为我们输入的seq_len不大,本身就是概率稀疏注意力,再不断的下采样,最后用到的query也没几个了,这肯定是有问题的。

Informer主要可以调的一些参数:factor、seq_len, d_model,其余效果都一般。

很多时序模型解决的都是降低时间复杂度和内存消耗上,本质上还是在用Transformer去做时序任务,用到自己的场景时,可能需要进行一些结构上的改变,我们可以保留Informer中比如概率稀疏注意力这种高效的提取时序信息的模块,增加更多有自己业务场景的结构,看大家自己的设计喽。

推荐阅读:

我的2022届互联网校招分享

我的2021总结

浅谈算法岗和开发岗的区别

互联网校招研发薪资汇总

2022届互联网求职现状,金9银10快变成铜9铁10!!

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步

2821a7f3e21e05b75baf6bd6d10545d6.jpeg

发送【蜗牛】获取一份《手把手AI项目》(AI蜗牛车著)

发送【1222】获取一份不错的leetcode刷题笔记

发送【AI四大名著】获取四本经典AI电子书

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/86298
推荐阅读
相关标签
  

闽ICP备14008679号