当前位置:   article > 正文

Attention as an RNN

attention as an rnn

摘要

https://arxiv.org/pdf/2405.13956
Transformer的出现标志着序列建模领域的一个重大突破,它提供了一种高性能架构,能够充分利用GPU并行性。然而,Transformer在推理时计算成本高昂,限制了其应用,特别是在资源受限的环境中(例如,移动和嵌入式设备)。针对这个问题,我们(1)首先展示了注意力机制可以视为一种特殊的循环神经网络(RNN),能够高效地计算其多对一RNN输出。然后(2)我们展示了基于注意力的流行模型(如Transformer)可以视为RNN的变体。然而,与传统RNN(如LSTM)不同,这些模型无法高效地更新新的标记,这是序列建模中的一个重要属性。为了解决这个问题,我们(3)基于并行前缀扫描算法,介绍了一种新的高效计算注意力多对多RNN输出的方法。基于新的注意力形式化,我们(4)引入了Aaren,一个基于注意力的模块,它不仅可以(i)像Transformer一样并行训练,而且(ii)能够像传统RNN一样高效地更新新的标记,进行推理时仅需常数内存。在经验上,我们展示了Aaren在跨四个流行的序列问题设置(强化学习、事件预测、时间序列分类和时间序列预测任务)的38个数据集上实现了与Transformer相当的性能,同时更加节省时间和内存。

1、引言

序列建模领域的进步具有深远的影响,其应用范围广泛,包括强化学习(例如,机器人技术和自动驾驶)、时间序列分类(例如,金融欺诈检测和医疗诊断)以及时间序列预测(例如,天气和能源消耗预测)。在过去的几年里,基于Transformer的模型(Vaswani et al., 2017)在序列建模中是一个广泛研究的主题。这是因为Transformer强大的性能和利用GPU并行性的能力。因此,针对各种序列设置(如强化学习、时间序列和语音处理)已经撰写了众多关于Transformer的综述论文(Agarwal et al., 2023; Li et al., 2023; Lin et al., 2022; Jiang et al., 2024; Latif et al., 2023)。

随着电池供电设备等低资源领域的迅速增长,这些领域中部署的模型必须计算高效。然而,Transformer在这方面是昂贵的,因为它们的内存和计算成本随着序列长度的平方而增长。尽管可以使用诸如KV缓存(Pope et al., 2023)等技术来提高推理时的效率,但由于Transformer需要(1)与标记数量成线性关系的内存,以及(2)将之前所有标记的缓存传递给模型,因此它们在低资源领域仍然很昂贵。这些限制在具有长上下文(即大量标记)的设置中更为严重,这些设置在时间序列(例如,气候和经济)中很常见。

为了解决这个问题,我们首先审视注意力机制,这是导致Transformer二次计算复杂度的组件。(1)我们展示了注意力可以被视为一种特殊的循环神经网络(RNN),具有高效地计算其多对一RNN输出的能力。利用注意力的RNN形式化,(2)我们表明流行的基于注意力的模型(例如,Transformer和Perceiver)可以被视为RNN。然而,与LSTM(Hochreiter和Schmidhuber, 1997)和GRU(Cho et al., 2014)等传统RNN不同,这些基于注意力的模型无法随着新标记的出现进行高效的更新,这是在处理数据流形式的序列问题设置中非常重要的属性。

为了解决这个问题,(3)我们基于并行前缀扫描算法(Blelloch, 1990)引入了一种新的注意力形式化,它高效地计算了注意力的多对多RNN输出。在这个新的注意力形式化的基础上,(4)我们介绍了Aaren(Attention as a recurrent neural network),一个计算高效的模块,它不仅可以(i)像Transformer一样并行训练,而且(ii)可以随着新标记的出现进行高效的更新,推理时仅需要常数内存(就像传统RNN一样)。在经验上,我们展示了Aaren在跨越四个流行的序列数据设置的38个数据集上实现了与Transformer相当的性能,这些设置包括强化学习、事件预测、时间序列分类和时间序列预测任务,同时在时间和内存效率上更胜一筹。

2. 背景

2.1 循环神经网络

循环神经网络(RNNs)是专门用于序列建模的模型。简而言之,RNNs通过迭代计算隐藏状态来处理序列数据,如下所示:

h t = f θ ( h t − 1 , x t ) h_{t} = f_{\theta}(h_{t-1}, x_{t}) ht=fθ(ht1,xt)

其中, t t t 表示步骤索引, x x x 表示一个标记, h h h 表示隐藏状态, f θ f_{\theta} fθ 是一个由 θ \theta θ 参数化的神经网络。初始隐藏状态 h 0 h_{0} h0 的值通常是通过反向传播学习的。流行的RNNs如长短期记忆网络(LSTMs)(Hochreiter和Schmidhuber, 1997)和门控循环单元(GRUs)(Cho et al., 2014)将隐藏状态 h t h_{t} ht 的大小固定为一个常数,与步骤索引无关。因此,这些模型在测试时非常高效,每个标记只需要常数内存和时间,并且可以很容易地随着新标记的出现进行更新,这是序列建模中的一个重要属性。然而,由于设计上的迭代性,这些流行的RNNs也因其缺乏可并行性而面临可扩展性问题。因此,RNNs被基于注意力的可并行化模块——Transformer(Vaswani et al., 2017)——所取代,成为许多序列建模设置的标准。

2.2 注意力机制

注意力机制从一组给定的查询标记 X Q X_{Q} XQ中检索来自一组上下文标记 X C X_{C} XC的信息,其方式如下:

Attention ( Q , K , V ) = softmax ⁡ ( Q K T ) V \text{Attention}(Q, K, V) = \operatorname{softmax}\left(QK^{T}\right)V Attention(Q,K,V)=softmax(QKT)V

其中, Q = X Q W q Q = X_{Q}W_{q} Q=XQWq是查询矩阵, K = X C W k K = X_{C}W_{k} K=XCWk是键矩阵, V = X C W v V = X_{C}W_{v} V=XCWv是值矩阵。 W q , W k , W v ∈ R d × d W_{q}, W_{k}, W_{v} \in \mathbb{R}^{d \times d} Wq,Wk,WvRd×d是权重矩阵(学习参数)。 softmax ⁡ ( Q K T ) \operatorname{softmax}\left(QK^{T}\right) softmax(QKT)计算上下文标记的权重,用于加权平均。值得注意的是,与RNNs不同,注意力机制并不是为了迭代而设计的;相反,它旨在轻松利用GPU并行性。Transformer(Vaswani et al., 2017)使用了自注意力 1 {}^{1} 1,这是注意力机制的一个特例,其中查询标记与上下文标记相同。然而,自注意力机制相对于标记数量需要二次计算,并且不能有效地随着新标记的出现进行更新。因此,Transformers计算成本高昂,限制了它们在资源有限领域的应用。

3、方法

为了解决这个问题,我们提出了一个基于注意力的高效模块,该模块能够利用GPU并行性,同时又能有效地进行更新。我们首先在第3.1节中展示了注意力可以被视为一种具有特殊能力的RNN,即它能够高效地计算其多对一RNN(如图1a所示)的输出。利用注意力的RNN表示形式,我们进一步展示了流行的基于注意力的模型,如Transformer(如图1b所示)和Perceiver(如图1c所示),也可以被视为RNN。然而,与传统的RNN不同,这些模型无法有效地使用新标记进行自我更新,这限制了它们在数据以流的形式到达的序列问题设置中的潜力。

为了解决这个问题,我们在第3.2节中介绍了一种基于并行前缀扫描算法的高效方法来计算作为多对多RNN的注意力。基于这一点,我们在第3.3节中引入了Aaren([A]ttention [a]s a [re]current neural [n]etwork),一个计算高效的模块,它不仅可以(i)像Transformer一样并行训练,还可以(ii)在推理时有效地使用新标记进行更新,仅需要常数内存来进行推理(像传统RNN一样)。
在这里插入图片描述

3.1 注意力作为(多对一)RNN

对于查询向量 q q q的注意力可以被视为一个函数,该函数通过 N N N个上下文标记 x 1 : N x_{1:N} x1:N及其键和值 { ( k i , v i ) } i = 1 N \left\{\left(k_{i}, v_{i}\right)\right\}_{i=1}^{N} {(ki,vi)}i=1N映射到一个单一的输出 o N = Attention ( q , k 1 : N , v 1 : N ) o_{N}=\text{Attention}\left(q, k_{1:N}, v_{1:N}\right) oN=Attention(q,k1:N,v1:N)。给定 s i = dot ( q , k i ) s_{i}=\text{dot}\left(q, k_{i}\right) si=dot(q,ki),输出 o N o_{N} oN可以表示为:

o N = ∑ i = 1 N softmax ( s ) i v i = ∑ i = 1 N exp ⁡ ( s i ) v i ∑ i = 1 N exp ⁡ ( s i ) = a ^ N c ^ N o_{N}=\sum_{i=1}^{N} \text{softmax}(s)_{i} v_{i}=\frac{\sum_{i=1}^{N} \exp \left(s_{i}\right) v_{i}}{\sum_{i=1}^{N} \exp \left(s_{i}\right)}=\frac{\hat{a}_{N}}{\hat{c}_{N}} oN=i=1Nsoftmax(s)ivi=i=1Nexp(si)i=1Nexp(si)vi=c^Na^N

其中分子是 a ^ N = ∑ i = 1 N exp ⁡ ( s i ) v i \hat{a}_{N}=\sum_{i=1}^{N} \exp \left(s_{i}\right) v_{i} a^N=i=1Nexp(si)vi,分母是 c ^ N = ∑ i = 1 N exp ⁡ ( s i ) \hat{c}_{N}=\sum_{i=1}^{N} \exp \left(s_{i}\right) c^N=i=1Nexp(si)。将注意力视为RNN,我们可以迭代地计算分子和分母作为累积和 a ^ k = a ^ k − 1 + exp ⁡ ( s k ) v k \hat{a}_{k}=\hat{a}_{k-1}+\exp \left(s_{k}\right) v_{k} a^k=a^k1+exp(sk)vk c ^ k = c ^ k − 1 + exp ⁡ ( s k ) \hat{c}_{k}=\hat{c}_{k-1}+\exp \left(s_{k}\right) c^k=c^k1+exp(sk),其中 k = 1 , … , N k=1, \ldots, N k=1,,N。然而,在实践中,这是一个不稳定的实现,由于有限的精度表示和可能非常小或非常大的指数(即, exp ⁡ ( s ) \exp (s) exp(s))而遇到数值问题。为了缓解这个问题,我们使用累积最大值项 m k = max ⁡ i ∈ { 1 , … , k } s i m_{k}=\max _{i \in\{1, \ldots, k\}} s_{i} mk=maxi{1,,k}si来重写递归,而是计算

a k = ∑ i = 1 k exp ⁡ ( s i − m k ) v i a_{k}=\sum_{i=1}^{k} \exp \left(s_{i}-m_{k}\right) v_{i} ak=i=1kexp(simk)vi c k = ∑ i = 1 k exp ⁡ ( s i − m k ) c_{k}=\sum_{i=1}^{k} \exp \left(s_{i}-m_{k}\right) ck=i=1kexp(simk)

值得注意的是,最终结果是相同的 o N = a ^ N c ^ N = a N c N o_{N}=\frac{\hat{a}_{N}}{\hat{c}_{N}}=\frac{a_{N}}{c_{N}} oN=c^Na^N=cNaN。因此, a k a_{k} ak c k c_{k} ck m k m_{k} mk按照以下方式递归计算:

a k = a k − 1 exp ⁡ ( m k − 1 − m k ) + v k exp ⁡ ( s k − m k ) c k = c k − 1 exp ⁡ ( m k − 1 − m k ) + exp ⁡ ( s k − m k ) m k = max ⁡ ( m k − 1 , s k )

ak=ak1exp(mk1mk)+vkexp(skmk)ck=ck1exp(mk1mk)+exp(skmk)mk=max(mk1,sk)
akckmk=ak1exp(mk1mk)+vkexp(skmk)=ck1exp(mk1mk)+exp(skmk)=max(mk1,sk)

通过将 a k a_{k} ak c k c_{k} ck m k m_{k} mk a k − 1 a_{k-1} ak1 c k − 1 c_{k-1} ck1 m k − 1 m_{k-1} mk1的递归计算封装起来,我们引入了一个RNN单元,该单元迭代地计算注意力的输出(见图2)。注意力的RNN单元以 ( a k − 1 , c k − 1 , m k − 1 , q ) \left(a_{k-1}, c_{k-1}, m_{k-1}, q\right) (ak1,ck1,mk1,q)作为输入,并计算 ( a k , c k , m k , q ) \left(a_{k}, c_{k}, m_{k}, q\right) (ak,ck,mk,q)。注意,查询向量 q q q在RNN单元中被传递。注意力的RNN的初始隐藏状态是 ( a 0 , c 0 , m 0 , q ) = ( 0 , 0 , 0 , q ) \left(a_{0}, c_{0}, m_{0}, q\right) = (0,0,0, q) (a0,c0,m0,q)=(0,0,0,q)
在这里插入图片描述

计算注意力的方法。通过将注意力视为RNN,我们可以看到计算注意力的不同方式:(1) 逐个标记(即顺序地)以 O ( 1 ) O(1) O(1)内存递归地计算,或(2) 以传统方式(即并行地)需要线性 O ( N ) O(N) O(N)内存。由于注意力可以被视为RNN,传统计算注意力的方法也可以被视为高效地计算注意力的多对一RNN输出,即一个RNN的输出,该RNN将多个上下文标记作为输入,但只在RNN结束时输出一个标记(见图1a)。最后,除了完全顺序或完全并行之外,我们还可以将注意力计算为(3) 一个RNN,该RNN以块为单位处理标记,需要 O ( b ) O(b) O(b)内存,其中 b b b是块的大小。然而,这种方法超出了本文的范围。因此,块到块的RNN的描述被包含在附录A中。

将现有的基于注意力的模型视为RNN。通过将注意力视为RNN,现有的基于注意力的模型也可以被视为RNN的变体。例如,Transformer的自注意力是RNN(图1b),其上下文标记作为初始隐藏状态。Perceiver的交叉注意力是RNN(图1c),其上下文依赖的潜在变量作为初始隐藏状态。通过利用这些模型的注意力机制的RNN形式,这些现有模型可以高效地计算其输出内存。

将注意力视为现有模型的RNN的挑战。然而,当将现有的基于注意力的模型(如Transformer)视为RNN时,这些模型缺乏传统RNN(如LSTM和GRU)所共有的重要属性。特别地,LSTM和GRU能够以 O ( 1 ) O(1) O(1)常数内存和计算量有效地使用新标记来更新自己,这是序列建模中一个重要的特性,因为在序列建模中数据是以流的形式接收的。相比之下,Transformer的RNN视图(见图1b)将通过为新的标记添加一个新的RNN并以该新标记作为初始状态来处理新标记。这个新的RNN需要处理所有先前的标记,这需要 O ( N ) O(N) O(N)的线性计算量,其中 N N N是标记的数量。在Perceiver中,由于架构的原因,潜变量(图1c中的 L i L_{i} Li)是输入依赖的,这意味着当接收到新标记时它们的值会改变。由于Perceiver的RNN的初始隐藏状态(即潜变量)会改变,因此Perceiver需要从头开始重新计算其RNN,这需要 O ( N L ) O(NL) O(NL)的线性计算量,其中 N N N是标记的数量, L L L是潜变量的数量。

在这里插入图片描述

3.2 注意力作为(多对多)RNN

为了克服这些限制,我们提出了一个基于注意力的模型,该模型能够利用RNN公式的能力来进行高效的更新。为了实现这一点,我们首先引入了一种高效并行的注意力计算方法,即将注意力作为多对多RNN进行计算,即并行计算 { o i = Attention ( q , x 1 : i ) } i = 1 N \left\{o_{i}=\text{Attention}\left(q, x_{1: i}\right)\right\}_{i=1}^{N} {oi=Attention(q,x1:i)}i=1N。为了实现这一点,我们利用了并行前缀扫描算法(Blelloch, 1990)(见算法1),这是一种通过结合运算符 ⊕ \oplus N N N个顺序数据点计算 N N N个前缀计算的并行计算方法。该算法能够高效地根据 { x k } k = 1 N \left\{x_{k}\right\}_{k=1}^{N} {xk}k=1N计算 { ⨁ i = 1 k x i } k = 1 N \left\{\bigoplus_{i=1}^{k} x_{i}\right\}_{k=1}^{N} {i=1kxi}k=1N

回忆一下,注意力 Attention ( q , x 1 : k ) = o k = a k c k \text{Attention}\left(\text{q}, \text{x}_{1: k}\right) = o_k = \frac{a_k}{c_k} Attention(q,x1:k)=ok=ckak,其中 a k = ∑ i = 1 k exp ⁡ ( s i − m k ) v i a_k = \sum_{i=1}^{k} \exp \left(s_i - m_k\right) v_i ak=i=1kexp(simk)vi c k = ∑ i = 1 k exp ⁡ ( s i − m k ) c_k = \sum_{i=1}^{k} \exp \left(s_i - m_k\right) ck=i=1kexp(simk),且 m k = max ⁡ i ∈ { 1 , … , k } s i m_k = \max_{i \in \{1, \ldots, k\}} s_i mk=maxi{1,,k}si。为了高效地计算 { Attention ( q , x 1 : k ) } k = 1 N \left\{\text{Attention}\left(\text{q}, \text{x}_{1: k}\right)\right\}_{k=1}^{N} {Attention(q,x1:k)}k=1N,我们可以使用并行扫描算法来计算 { a k } k = 1 N \left\{a_k\right\}_{k=1}^{N} {ak}k=1N { c k } k = 1 N \left\{c_k\right\}_{k=1}^{N} {ck}k=1N { m k } k = 1 N \left\{m_k\right\}_{k=1}^{N} {mk}k=1N,然后将 a k a_k ak c k c_k ck 结合起来计算注意力 Attention ( q , x 1 : k ) \text{Attention}\left(\text{q}, \text{x}_{1: k}\right) Attention(q,x1:k)

为了实现这一点,我们提出了以下结合运算符 ⊕ \oplus ,它作用于形式为 ( m A , u A , w A ) \left(m_A, u_A, w_A\right) (mA,uA,wA) 的三元组,其中 A A A 是一组索引, m A = max ⁡ i ∈ A s i m_A = \max_{i \in A} s_i mA=maxiAsi u A = ∑ i ∈ A exp ⁡ ( s i − m A ) u_A = \sum_{i \in A} \exp \left(s_i - m_A\right) uA=iAexp(simA),以及 w A = ∑ i ∈ A exp ⁡ ( s i − m A ) v i w_A = \sum_{i \in A} \exp \left(s_i - m_A\right) v_i wA=iAexp(simA)vi。并行扫描算法以 { ( m { i } , u { i } , w { i } ) } i = 1 N = { ( s i , 1 , v i ) } i = 1 N \left\{\left(m_{\{i\}}, u_{\{i\}}, w_{\{i\}}\right)\right\}_{i=1}^{N} = \left\{\left(s_i, 1, v_i\right)\right\}_{i=1}^{N} {(m{i},u{i},w{i})}i=1N={(si,1,vi)}i=1N 作为输入。算法递归地应用运算符 ⊕ \oplus ,具体工作方式如下:

( m A , u A , w A ) ⊕ ( m B , u B , w B ) = ( m A ∪ B , u A ∪ B , w A ∪ B ) \left(m_A, u_A, w_A\right) \oplus \left(m_B, u_B, w_B\right) = \left(m_{A \cup B}, u_{A \cup B}, w_{A \cup B}\right) (mA,uA,wA)(mB,uB,wB)=(mAB,uAB,wAB)

其中 m A ∪ B = max ⁡ ( m A , m B ) \mathrm{m}_{A \cup B} = \max \left(\mathrm{m}_{A}, \mathrm{m}_{B}\right) mAB=max(mA,mB) u A ∪ B = u A exp ⁡ ( m A − m A ∪ B ) + u B exp ⁡ ( m B − m A ∪ B ) \mathrm{u}_{A \cup B} = \mathrm{u}_{A} \exp \left(\mathrm{m}_{A} - \mathrm{m}_{A \cup B}\right) + \mathrm{u}_{B} \exp \left(\mathrm{m}_{B} - \mathrm{m}_{A \cup B}\right) uAB=uAexp(mAmAB)+uBexp(mBmAB),以及 w A ∪ B = w A exp ⁡ ( m A − m A ∪ B ) + w B exp ⁡ ( m B − m A ∪ B ) \mathrm{w}_{A \cup B} = \mathrm{w}_{A} \exp \left(\mathrm{m}_{A} - \mathrm{m}_{A \cup B}\right) + \mathrm{w}_{B} \exp \left(\mathrm{m}_{B} - \mathrm{m}_{A \cup B}\right) wAB=wAexp(mAmAB)+wBexp(mBmAB)。在递归地应用运算符后,算法输出 { ( m { 1 , … , k } , u { 1 , … , k } , w { 1 , … , k } ) } k = 1 N = { ( m k , ∑ i = 1 k exp ⁡ ( s i − m k ) , ∑ i = 1 k exp ⁡ ( s i − m k ) v i ) } k = 1 N \left\{\left(\mathrm{m}_{\{1, \ldots, k\}}, \mathrm{u}_{\{1, \ldots, k\}}, \mathrm{w}_{\{1, \ldots, k\}}\right)\right\}_{k=1}^{N} = \left\{\left(m_{k}, \sum_{i=1}^{k} \exp \left(s_{i} - m_{k}\right), \sum_{i=1}^{k} \exp \left(s_{i} - m_{k}\right) v_{i}\right)\right\}_{k=1}^{N} {(m{1,,k},u{1,,k},w{1,,k})}k=1N={(mk,i=1kexp(simk),i=1kexp(simk)vi)}k=1N。这也可以写作 { ( m k , c k , a k ) } k = 1 N \left\{\left(m_{k}, c_{k}, a_{k}\right)\right\}_{k=1}^{N} {(mk,ck,ak)}k=1N。将输出元组的最后两个值组合起来,我们得到注意力 Attention ( q , x 1 : k ) = o k = a k c k \text{Attention}\left(\text{q}, \text{x}_{1: k}\right) = o_{k} = \frac{a_{k}}{c_{k}} Attention(q,x1:k)=ok=ckak,从而形成了一种高效并行的注意力计算方法,作为多对多RNN(如图3所示)。

3.3 Aaren:

注意力作为循环神经网络
利用注意力机制的并行化多对多形式,我们提出了Aaren([A]ttention [a]s a [re]current neural [n]etwork)。Aaren的接口与Transformer相同,将N个输入映射到N个输出,其中第i个输出是第1个到第i个输入的聚合。因此,Aaren也是(1)自然可堆叠的,以及(2)能够为每个序列标记计算单独的损失项。然而,与使用因果自注意力的Transformer不同,Aaren使用上述方法将注意力计算为多对多RNN,使其更加高效。Aaren的功能如下:

h 1 ( 0 ) , … , h N ( 0 ) ← x 1 , … , x N [ h 1 ( j + 1 ) , … , h N ( j + 1 ) ] ← Aaren ( q ( j ) , [ h 1 ( j ) , … , h N ( j ) ] )

h1(0),,hN(0)x1,,xN[h1(j+1),,hN(j+1)]Aaren(q(j),[h1(j),,hN(j)])
h1(0),,hN(0)[h1(j+1),,hN(j+1)]x1,,xNAaren(q(j),[h1(j),,hN(j)])

与Transformer中查询是输入标记之一的情况不同,Aaren的查询标记 q q q是通过反向传播在训练过程中学习的。在图4中,我们给出了一个堆叠的Aaren模型示例,输入上下文标记为 x 1 : 3 x_{1:3} x1:3,输出为 y 1 : 3 y_{1:3} y1:3。值得注意的是,由于Aaren利用了注意力的RNN形式,Aaren的堆叠也是RNN的堆叠。因此,Aaren还能够高效地使用新标记进行更新,即 y k y_k yk的迭代计算仅需要常数计算量,因为它仅依赖于 h k − 1 h_{k-1} hk1 x k x_k xk。与基于Transformer的模型不同,基于Transformer的模型(1)需要线性内存(当使用 KV \text{KV} KV缓存时)和(2)需要存储所有先前的标记,包括中间Transformer层中的标记,而基于Aaren的模型(1)仅需要常数内存和(2)不需要存储所有先前的标记,这使得Aaren在计算效率上显著优于Transformer。

在这里插入图片描述

4 实验

实验的目的是在(1)性能和(2)所需资源(时间和内存)方面将Aaren与Transformer进行比较。为了进行全面的比较,我们在四个问题设置下进行评估:强化学习、事件预测、时间序列预测和时间序列分类。

数据集。我们总共在38个数据集上评估了Aaren和Transformer,其中大部分是现实世界的数据集。这些数据集按问题设置划分如下:12个强化学习数据集,8个事件预测数据集,8个时间序列预测数据集,以及10个时间序列分类数据集。对于每个数据集,模型都使用5个随机种子进行评估。由于空间限制,我们请读者参考附录C中对各个数据集的描述。

模型。为了直接将Aaren与Transformer进行比较,我们在特定领域的Transformer模型中用Aaren替换了Transformer。对于强化学习,我们在Decision Transformer(Chen et al., 2021)上进行了比较。对于事件预测,我们在Transformer Hawkes Process(Zuo et al., 2020; Bae et al., 2023)上进行了比较。对于时间序列预测,我们遵循Liu et al. (2022)的方法,在具有输入归一化的Transformer上进行了比较。对于时间序列分类,我们在遵循Wu et al. (2023)的库的普通因果Transformer上进行了比较。

实现细节。实验使用流行仓库中的数据集和Transformer实现进行。更具体地说,强化学习实验中的Decision Transformer是在Barhate (2022)的代码上运行的。时间序列预测和时间序列分类实验是在Wu et al. (2023)的时间序列库仓库上运行的。事件预测实验是在Bae et al. (2023)的代码上运行的。由于Transformers和Aarens具有相同的接口,并且都是基于注意力机制的方法,因此它们共享相同的超参数集。为了公平起见,Transformers和Aarens都使用了相同的超参数。由于空间限制,具体的超参数细节在附录E中提供。

4.1 强化学习

在这些实验中,我们将Aarens与Transformer在强化学习(RL)上进行比较。在RL中,模型在训练期间的目标是学习环境交互中通过试错获得的反馈/奖励来学习策略。因此,RL在交互式设置(如机器人学、推荐引擎和交通控制)中非常流行。对于这些实验,我们考虑了Decision Transformers(Chen et al., 2021),这是一种流行的方法,用于在环境交互的数据集上以离线方式训练策略。在测试时,Decision Transformers以在线方式进行评估。
在这里插入图片描述

我们在D4RL基准测试(Fu et al., 2020)中流行的运动机器人环境上进行评估:HalfCheetah、Ant、Hopper和Walker。对于每个运动环境,我们比较在三种不同种类的数据集上训练的模型:Medium、Medium-Replay和Medium-Expert。这些数据集中的每一个都包含由策略生成的100万个时间步长。总的来说,我们在12个RL数据集和四个环境之间比较了模型的性能。有关各个任务的更多详细信息,请参阅附录C.1。表1中的结果表明,Aarens在所有12个数据集和四个环境中都与Transformer达到了竞争性的性能。然而,与Transformer不同,Aarens由于也是一种RNN,能够以恒定计算量高效地处理新的环境交互,使其更适合强化学习。

4.2 事件预测

在这些实验中,我们将Aarens与Transformer在事件预测(EF)上进行比较。在EF中,模型被给定一系列时间上不规则间隔的离散事件,并建模下一个事件时间和其标记(即,事件标签/类别)的概率分布。EF在许多现实世界的设置中都很流行,如金融(如交易)、医疗保健(如患者观察)和电子商务(如购买),在这些设置中,用户或系统操作发生在不规则间隔的时间点上。为了进行比较,我们将Transformers Hawkes Process(Zuo et al., 2020)中的Transformer替换为Aarens。遵循Bae et al. (2023),我们使用对数正态分布的混合来建模下一个事件时间的概率分布。对于这一设置,我们考虑了8个流行的基准数据集用于下一个事件预测(Zhang et al., 2020; Zuo et al., 2020; Bae et al., 2023):MIMIC、Wiki、Reddit、Mooc、StackOverflow、Sin、Uber和Taxi。其中7个是真实世界的数据集,而只有Sin是合成数据集。在8个数据集中,有3个(Sin、Uber和Taxi)不包含标记/标签。有关各个数据集的详细信息,请参阅附录C.2。表2中的结果表明,Aarens在所有数据集上都与Transformer的性能相当。Aarens能够高效处理新输入的能力在事件预测设置中特别有用,其中事件以不规则流的形式到达。
在这里插入图片描述

4.3 时间序列预测

在这些实验中,我们将Aarens与Transformer在时间序列预测(TSF)上进行了比较。在TSF中,模型被给定一系列时间上连续的信号观测值。模型的目标是预测序列的未来T个值。TSF模型广泛应用于各种领域,包括与气候(如天气)、能源(如供需)和经济(如股票价格)相关的领域。为了进行比较,我们考虑了一个遵循Liu et al. (2022)的具有因果掩码的Transformer,并使用了输入归一化。对于这一设置,我们考虑了之前工作中使用的8个真实世界数据集:Weather、Exchange、Traffic、ECL、ETTh1、ETTh2、ETTm1和ETTm2。有关各个数据集的详细信息,请参阅附录C.3。根据Wu et al. (2023),给定输入长度为96,模型在T ∈ {96, 192, 336, 720}的情况下进行评估。由于空间限制,表3仅包含了T=192的结果。完整结果请参考附录D中的表5。表3中的结果表明,Aarens在所有数据集上的性能都与Transformer相当。然而,与Transformer不同,Aarens能够高效地处理时间序列数据,使其更适合与时间序列相关的领域。

在这里插入图片描述

4.4 时间序列分类

在这些实验中,我们将Aarens与Transformer在时间序列分类(TSC)上进行了比较。在TSC中,模型的目标是预测时间序列的标签。这种设置在许多重要应用中都很常见,如模式识别(例如,心电图)、异常检测(例如,银行欺诈)或故障预测(例如,电网波动)(Dinger et al., 2022)。对于这一设置,我们考虑了来自UEA时间序列分类存档(Bagnall et al., 2018)的10个真实世界流行数据集:EthanolConcentration、FaceDetection、Handwriting、Heartbeat、JapaneseVowels、PEMS-SF、SelfRegulationSCP1、SelfRegulationSCP2、ArabicDigits和UWaveGesture。有关各个数据集的详细信息,请参阅附录C.4。在表4中,我们可以看到Aarens在所有数据集上的性能都与Transformer相当。

4.5 分析

在这些实验中,我们从资源需求的角度比较了Aarens与Transformers。为此,我们使用了Barhate(2022)的代码。对于Transformers,我们使用了KV缓存来提高其效率。
在这里插入图片描述

内存复杂度:在图5(左)中,我们比较了Aarens和Transformers(使用KV缓存)在推理时的内存使用情况。我们发现,Transformers的内存使用随着KV缓存技术的使用而线性增长。相比之下,Aarens无论标记数量如何,都只使用恒定内存,这使得它显著更有效率。

时间复杂度:在图5(右)中,我们比较了Aarens和Transformers(使用KV缓存)按顺序处理一系列标记所需的累积时间。对于Transformers,累积计算量随标记数量的增加呈二次增长,即 O ( 1 + 2 + … + N ) = O ( N 2 ) O(1+2+\ldots+N)=O(N^{2}) O(1+2++N)=O(N2)。相比之下,对于Aarens,累积计算量是线性的。在图中,我们看到了模型所需累积时间的类似结果。具体来说,Transformers所需的累积时间呈二次增长,而Aarens的累积时间呈线性增长。

参数数量:由于学习初始隐藏状态 q q q,Aarens模块比Transformer模块需要稍多的参数。然而,这种差异是微不足道的,因为 q q q只是一个向量。在可比模型中通过经验测量,我们发现Transformers使用了 3 , 152 , 384 3,152,384 3,152,384个参数。相比之下,等效的Aarens使用了 3 , 152 , 896 3,152,896 3,152,896个参数,仅代表约 0.016 % 0.016\% 0.016%的参数增加——这是为了获得内存和时间复杂度方面的显著优势所做出的微小权衡。

5 相关工作

与Aaren最接近的是RWKV(Peng et al., 2023)、RetNet(Sun et al., 2023)和Linear Transformer(Katharopoulos et al., 2020)等注意力机制的近似模型。这些模型提出了标准基于softmax的注意力的线性化形式,使它们可以被表述为RNN。然而,在这样做的时候,这些模型也编码了一个基于时间戳的指数因子来偏置标记,这限制了它们的潜在应用。相比之下,Aaren利用softmax注意力的精确重新表述作为RNN,允许模型本身计算每个标记的权重。

Feng et al. (2023) 展示了注意力可以递归地计算,并使用它来压缩基于集合的输入。Rabe和Staats (2022) 引入了注意力的递归表述,展示了自注意力可以高效地计算。Katharopoulos et al. (2020) 展示了带有因果掩码的Transformer可以被视为RNN。相比之下,我们(1)展示了更一般的结果,即任何注意力模型都可以被视为RNN。此外,我们(2)引入了基于并行前缀和的新注意力形式Aaren,它在效率上比Transformer更高,同时取得了与之竞争的结果。

计算前缀扫描/和的问题已经被广泛研究,并提出了各种高效的并行化算法来计算它们。由于Aaren只需要前缀扫描的输出,因此可以使用任何有效的算法来计算它。在这项工作中,我们概述了Hillis和Steele (1986) 的方法。该方法在并行计算中时间效率很高,需要 log ⁡ 2 ( N ) \log _{2}(N) log2(N) 个顺序步骤和 O ( N log ⁡ ( N ) ) \mathcal{O}(N \log (N)) O(Nlog(N)) 的总体计算。相比之下,Ladner和Fischer (1980) 的方法使用了更多的顺序步骤(具体是 2 log ⁡ 2 ( N ) − 2 2 \log _{2}(N)-2 2log2(N)2),但仅执行 O ( N ) \mathcal{O}(N) O(N) 的总体计算。对于更深入的并行前缀和算法的介绍,我们建议读者参考Blelloch (1990) 的工作。
在这项工作中,我们将Transformer应用于一系列应用的子集。为了全面了解Transformer的应用,我们推荐读者参考Islam等人(2023)的综述。对于本文所考虑的特定设置下应用的不同Transformer模型的概述,我们推荐读者参考以下综述:(1)Li等人(2023)关于强化学习中Transformer的综述;(2)Wen等人(2022)关于事件预测、时间序列预测、时间序列分类等领域中Transformer的综述。

6 结论

在这项工作中,我们展示了注意力机制可以表述为RNN,而传统的注意力计算方法是通过并行化方式计算其多对一RNN输出的。基于RNN的表述,我们展示了现有的基于注意力的模型可以被表述为RNNs。然而,与传统的RNNs(如LSTM和GRU)不同,这些方法无法有效地使用新标记进行更新。为了解决这个问题,我们基于并行前缀扫描算法引入了一种新的并行化方法来计算注意力的多对多RNN输出。基于新的注意力表述,我们介绍了Aaren,一个新的模块,它不仅可以(i)像Transformer一样并行训练,还可以(ii)在推理时高效更新,因此只需要恒定内存(如RNNs)。实验上,我们展示了Aaren在跨越四个序列数据设置的38个数据集上实现了与Transformer竞争的性能:强化学习、事件预测、时间序列分类和时间序列预测。最后,我们实验性地展示了Aaren在时间和内存效率上显著优于Transformer。

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/672928
推荐阅读
相关标签
  

闽ICP备14008679号