赞
踩
TTT出来有一段时间了,让我确定要写TTT解读的,是源于我司LLM论文100篇课程群里的一学员辰子说,“校长 最近的TTT考不考虑讲一下”,以及微博上也有朋友问:“老师,TTT 怎么说?”
故当时想着:解读完mamba2之后,则解读open-television、我司7方面review微调gemma2,再接下来是TTT、nature审稿微调、序列并行、Flash Attention3、vLLM(Efficient memory management for large language model serving with pagedattention)..
如今虽然mamba2的解读还没完全修订完,但“open-television、我司7方面review微调gemma2”都解读的差不多了,故今天开写TTT,对应论文为:Learning to (Learn at Test Time): RNNs with Expressive Hidden States
比我科研水平高的人大有人在,但比我写的大模型技术博客还要更通俗易懂的则寥寥无几,究其原因,还是在于比他人花更多的时间和心思
核心动力还是来源于写博十多年下来积累的习惯:尽我最大努力,让最广大的读者可以最大程度的最快理解
虽说相对于transformer的二次复杂度,RNN有着线性的复杂度
当然,如果序列长度比较短,那么无论是二次还是线性,其实区别不大,只有在序列长度比较长,即长上下文时,RNN这个线性的复杂度才有比较高的价值或意义
那到底多长之后,会使得RNN这个线性复杂度具有极高的价值或意义呢,如下图所示,这个上下文长度是在8k之后
可问题是一旦上下文足够长后,现有的RNN如Mamba在实际利用这些额外信息时会遇到困难(On the other hand,once context is long enough, existing RNNs such as Mamba struggle to actually take advantage ofthe extra information being conditioned on)
为何RNN在面对长下文时处理起来会比较困难呢,原因在于与自注意力不同,RNN层必须将上下文压缩成固定大小的隐藏状态,其作为一种压缩启发式方法,更新规则需要发现成千上万甚至数百万个token之间的潜在结构和关系
如下图所示,先根据输入和前一时刻的隐藏状态计算出最新的隐藏状态,在此之后,便可以根据最新的隐藏状态预测出了
至于RNN的详细介绍,详见此文:如何从RNN起步,一步一步通俗理解LSTM
其实,所有序列建模层都可以从将历史上下文存储到隐藏状态的角度来看,比如RNN层——如LSTM [33]、RWKV [56]和Mamba [26]层——在时间上将上下文压缩到固定大小的状态中
如下图所示(如你所见,一个通用的序列建模层表示为一个根据更新规则转换的隐藏状态。所有序列建模层都可以看作是该图中三个组件的不同实例:初始状态、更新规则和输出规则)
我 | 是 | 中 | 国 | ? | |
我 | q1 k1 | q2 k1 | q3 k1 | q4 k1 | |
爱 | q1 k2 | q2 k2 | q3 k2 | q4 k2 | |
中 | q1 k3 | q2 k3 | q3 k3 | q4 k3 | |
国 | q1 k4 | q2 k4 | q3 k4 | q4 k4 |
所以现在的问题变成了,我们既希望将成千上万甚至数百万的token压缩到一个隐藏状态中,且该隐藏状态还能有效的捕捉它们的底层结构和关系(we need to compress thousands or potentially millions of tokens into a hidden state that can effectively capture their underlying structures and relationship),从而兼顾高效与质量
过去的一年半,LLM火爆全球
总之,种种迹象表明,可以使用自监督学习来压缩历史上下文到一个隐藏状态,通过将上下文变成一个无标签的数据集并将隐藏状态变成一个模型(Our key idea is to use self-supervised learning to compress the historic context x1, . . . , xt into a hidden state st , by making the context an unlabeled dataset and the state a model)
类似的,可以将隐藏状态现在等同于 ——模型的权重,这可以是一个线性模型,或一个小的神经网络(the hidden state st is now equivalent to Wt , the weights of a model f , which can be a linear model, a small neural network, or anything else)
而TTT通过建模为模型的权重来更新隐藏状态的过程中,涉及到:输出规则和更新规则
我 | |||||
我 | 是 | ||||
我 | 是 | 中 | |||
= 对{我 + 是 + 中 + 国}的压缩 | 我 | 是 | 中 | 国 | |
其中, = {或是“人”、或是“的”} | 我 | 是 | 中 | 国 | 基于,可得的预测 = 人 |
与其他RNN层和自注意力机制一样,算法将输入序列 x1, . . . , xT映射到输出序列 z1, . . . , zT可以通过使用上述隐藏状态、更新规则和输出规则编程到序列建模层的前向传递中
即使在测试时,新层仍然训练一组不同的权重序列W1, . . . , WT——对于每个输入序列,因此,称之为测试时训练(TTT)层
且
最后,你将在下文依次看到更多的公式(公式4-7还不理解 没事,继续看下文即可)
输出规则(外循环) | 公式1 | 公式3 | 公式4 | 公式5 | ||
更新规则(内循环) | 公式2 | 公式6 | 公式7 |
考虑到TTT的最终目标是使公式1——在语言建模上表现良好,故可以基于人类先验知识的自监督任务,采用一种更端到端的方法——直接优化自监督任务以实现下一个词预测的最终目标
具体来说,在外循环中学习自监督任务
弄这么多不同view的好处是什么呢?
目前的TTT有个比较明显的问题是,其更新规则——公式2不能被并行化,因为在两个地方依赖于
而由于后者内部包含了大部分计算,所以我们要重点对后者——内部,做并行化
首先,梯度下降GD有许多变体,GD的一般更新规则可以表示为(记为公式6)
其中,是下降方向,这个公式的价值和意义在于,一旦计算出,便可以基于上述公式的后半部分且通过cumsum获得所有的
而对于我们所想要的更新规则是在线梯度下降,使用
进一步,这个其实控制着速度与质量之间的权衡,如下图所示(在TTT的所有实验中,作者们选择)
总之,有两个潜在的渠道可以将信息从 传播到,其中:对梯度算子做累加和(cumsum and the gradient operator)
然,上面面介绍的并行化是必要的,但对于wall-clock time的效率而言还不够(说白了,就是速度还不够快),故接下来,咱们来探讨下对偶形式
现代加速器专门用于矩阵-矩阵乘法,称为 matmuls。例如,NVIDIA A100 GPU 包含高度优化的单元,称为 TensorCores,它们只能执行一种操作——将两个大小为 16 ×16 的矩阵相乘。 如果没有足够的矩阵乘法,TensorCores 就会闲置,A100 的大部分潜力将无法实现
不幸的是,即使使用小批量开发的 TTT 层仍然有很少的矩阵乘法
考虑 ℓ的最简单情况,其中,仅针对第一个大小为 b的 TTT 小批量
此外,考虑作为线性模型,复制公式3,在时间的损失为:
为帮助大家更好的理解,回顾一下之前的公式1、公式2、公式3,以及公式6
输出规则(外循环) 公式1 公式3 公式4 公式5 更新规则(内循环) 公式2 公式6
如上一节所讨论的,可以并行化以下计算
对于 t= 1, . . . , b,然而不能通过一个单一的matmul计算所有的b个
相反,需要 b个外积来逐个计算它们。更糟糕的是,对于每个,是 d × d,这会比在大 时产生更大的内存占用和I/O成本
为了解决这两个问题,他们做了一个简单的观察:实际上不需要具体化,只要我们能在小批量结束时计算 和输出token
现在用上面简化的TTT-Linear案例来演示这些计算
其中掩码是具有零值的下三角掩码(类似于注意力掩码,但用零代替无穷大),并且项可
以从的计算中重用。 现在 ∆也可以方便地用矩阵乘法计算。将 ∆代入公式7,我们得到
以上,称这个过程为对偶形式,与之前的原始形式相对比,其中 G和 W是显式物化的,如前所述,这两种形式在输出上是等价的
回顾一下
其实,在这些引发的实例化中,具有线性模型和批量GD的TTT层其实等价于线性注意力[Transformers are rnns: Fast autoregressive transformers with linear attention]——一种广为人知的RNN层
简而言之,线性注意力 [41] 只是没有softmax的自注意力。 回顾自注意力的定义:
没有 softmax , 其便变成了
这是线性注意力的最简单形式
与其他RNN层类似,它可以写成递归形式,其中是隐藏状态
而由于可以通过 cumsum在每个计算,因此线性注意力相对于 T也具有线性复杂度
定理1 考虑TTT层,其中作为内循环模型,批量梯度下降,作为更新规则,并且
然后,给定相同的输入序列,公式5中定义的输出规则产生相同的输出序列作为线性注意力
且为方便大家理解,特此再列一下上文介绍过的公式4、公式5、公式6
输出规则(外循环) 公式1 公式3 公式4 公式5 更新规则(内循环) 公式2 公式6
其中
该定理的证明如下
根据公式4中 的定义,有:
且根据公式6中的批量GD定义,可知
将 代入公式5中的输出规则,可得到输出token
这是线性注意力的定义
定理2 考虑使用Nadaraya-Watson估计器[7, 12]定义的TTT层:
其中是第1.2.1节讨论的标签view
- 从公式3中的简单重建任务开始,添加了一些外循环参数,使这个任务可以学习
- 为产生从的损坏
一种设计是使其成为低秩投影,其中是一个可学习的矩阵,总之,被称为训练view
由于并不是中的所有信息都值得记住,因此重建标签可以是另一个低秩投影而不是,这里被称为标签view,其中也是可学习的
并且
是一个带有带宽超参数的核函数和,然后给定相同的输入序列, 公式5中定义的输出规则产生与自注意力相同的输出序列
为了统一这两种构造,可定义一种新的抽象,称为学习者,它唯一地引发了TTT层
类似于标准机器学习包中的定义 [54],所有学习者都需要实现两个方法:训练和预测。 现在将induced的TTT层的隐藏状态重新定义为学习者的内部存储,并将更新和输出规则重新定义为训练和预测方法
在这种新的TTT层定义下,定理1中的参数学习器和定理2中的非参数学习器都可以包括在内,比如下图便总结了在所有序列建模层的更广泛范围内TTT层的这一通用定义
这种通用定义对参数学习器有一个额外的好处:在参数学习器的内部存储中,除了 之外,还可以有更多的对象,例如优化器状态,这也将包含在诱导的TTT层的隐藏状态中。 这种扩展允许TTT层在未来的工作中使用更复杂的优化器,例如Adam [42]
提出了两种TTT层的变体——TTT-Linear和TTT-MLP,它们仅在的实例化上有所不同
具体来说,隐藏维度是输入维度的4×,然后是GELU激活[31]
且为了在TTT期间获得更好的稳定性,总是包含层归一化(LN)和残差连接。即,,其中可以是或
TTT初始化 在所有序列之间共享,尽管后续权重对于每个输入序列是不同的。可以将作为外循环的一部分来学习,而不是将其设置为0
由于外循环参数总是用θ而不是W表示,将别名分配给它。 在实践中,与重建视图相比,增加的参数量可以忽略不计,因为它的输入和输出都是低维的。根据经验,观察到学习显著提高了训练的稳定性
先回顾下
1.2.2 对于更新规则——使用小批量TTT进行并行化:小批量梯度下降
目前的TTT有个比较明显的问题是,其更新规则——公式2不能被并行化,因为在两个地方依赖于
- 一个地方是在减号前
- 一个地方在内部
而由于后者内部包含了大部分计算,所以我们要重点对后者——内部,做并行化
首先,梯度下降GD有许多变体,GD的一般更新规则可以表示为(记为公式6)
其中,是下降方向,这个公式的价值和意义在于,一旦计算出,便可以基于上述公式的后半部分且通过cumsum获得所有的
而对于我们所想要的更新规则是在线梯度下降,使用
- 为了并行化,可以对所有这些变量相对于进行计算,在这种变体情况下,其中被称为批量梯度下降
学习率通常是梯度下降中最重要的超参数,因此尝试在外循环中学习内循环学习率,如之前介绍过的公式6
为了增加灵活性,可以使成为输入token的函数(因此在时间上有所不同)
具体来说,设计,其中可学习向量是一个外循环参数,是sigmoid函数,标量是基础学习率,对于TTT-Linear设为1,对于TTT-MLP设为0.1。当然,也可以解释为的一个门
将任何RNN层集成到更大的架构中的最干净方法是直接替换Transformer中的自注意力机制,在这种情况下称为骨干
然而,现有的RNN如Mamba [26] 和Gri n [18] 都使用与Transformer不同的骨干。最显著的是,它们的骨干在RNN层之前包含时间卷积,这可能有助于收集跨时间的局部信息
且在实验了Mamba骨干后,发现它也提高了TTT层的困惑度,因此TTT将其纳入TTT的方法中,详见下图
如上图所示
实际实现时,由于Transformer和Mamba使用不同的骨干网络,而TTT-Linear和TTT-MLP总是默认使用Mamba骨干网络,除非另有说明
接下来,通过与两个基线——Transformer和现代RNN Mamba进行比较来评估TTT-Linear和TTT-MLP
且主要代码库基于EasyLM [25],这是一个用于在JAX中训练和服务LLM的开源项目。另,所有
实验都可以使用这里《https://github.com/test-time-training/ttt-lm-pytorch》提供的公开代码和数据集进行复现
在数据集的选择上,根据Mamba论文 [26],在Pile [24] 上进行标准实验,使用2k和8k的上下文长度,这是一个用于训练开源LLM的流行文档数据集 [9]
然而,Pile包含的长度超过8k的序列很少 [19]。 为了评估长上下文的能力,我们还在1k到32k范围
内以2×递增的上下文长度进行实验,使用一个名为Books3的Pile子集,该子集已被广泛用于训练长上下文的LLM [49, 3]
且为了确保评估公平,作者们在可能的情况下严格遵循Mamba论文中的评估协议:
如下图所示「当一个图同时包含Transformer骨干和Mamba骨干时,则分别用(T)和 (M)表示」
总之,TTT-Linear在2k上下文中表现与Mamba相当,而在8k上下文中表现更好
将TTT层从Mamba骨干网络切换到Transformer骨干网络有两个影响
我们假设,当序列建模层具有较不具表现力的隐藏状态时,Mamba骨干中的时间卷积会更有帮
助(We hypothesize that the temporal convolutions in the Mamba backbone help more when the sequence modeling layer has a less expressive hidden state.)
所以虽说线性模型的表现力不如MLP,但线性模型可以从mamba的卷积中受益更多,才使得具备mamba骨干网络的TTT-Linear可以比肩TTT-MLP(The linear model is less expressive than the MLP, therefore benefits more from the convolutions.We will revisit this hypothesis in the next subsection),即线性 + mamba ≈ MLP,即mamba的卷积部分可以弥补线性模型的劣势
如下图所示
Mamba是众多结构化状态空间模型之一 [27, 21, 57, 18]。 这些模型中的隐藏状态是一个向量,类似于LSTM
在TTT-Linear或TTT-MLP中,隐藏状态是一个矩阵或两个矩阵,因此更大
如下图所示,我们发现TTT层可以利用其更大的隐藏状态在长上下文中压缩更多信息,其中TTT-MLP优于TTT-Linear,而TTT-Linear又优于Mamba
与TTT-Linear类似,RWKV [55, 56]、xLSTM [5]和门控线性注意力(GLA)[79]也具有矩阵隐藏状
态,这些状态继承自线性注意力 [41]。 现代RNN如GLA使用块状并行来提高硬件效率,因此块内
的标记可以通过matmul而不是cumsum来处理
然而,块状并行并没有改变模型的表达能力,因为所有时间依赖性仍然等同于 cumsum
相比之下,小批量TTT允许跨小批量的更复杂的时间依赖性。每个隐藏状态 依赖于其小批量内的先前 仍然通过 cumsum,但也通过梯度算子依赖于先前小批量中的
如之前这图所示
小批量TTT在表达能力和硬件效率之间实现了权衡,因为较小的批量大小b会以更高的延迟为代价带来更好的困惑度。 这种权衡是TTT的一个独特且重要的特性
如下表所示,中间批量大小b= 16显著优于 b=T完全cumsum
如上表所示,这里所有模型都有 125M 参数,并按照小节 3.1 中的配方进行训练
在机器学习中,测试时学习的概念有着悠久的历史。 这种概念最早的版本之一被称为局部学习(Bottou 和 Vapnik [10]):对于每个测试输入,在做出预测之前,先对其邻居进行训练。 这种程序已被有效地应用于从SVM [81]到现代LLM [29]的各种模型
测试时学习的另一个早期版本被称为传导学习 [22]。 Vladimir Vapnik [74]提出的传导原则是“...得到你真正需要的答案,而不是更一般的答案。”
传导学习的实际应用使用测试数据来为SVM的边界添加约束 [39, 17]。 然而,传导学习通常需要多个测试实例才能在经验上有效,不像许多测试时训练的实例化,只需要一次一个测试实例(图像、视频或自然语言序列)
在计算机视觉中,测试时学习的理念已经应用于面部检测 [38]、物体检测 [53]、图像超分辨率 [65]和 3D 重建 [50] 等应用领域数十年
最近,同样的理念也被应用于自然语言处理领域,在那里它被称为动态评估 [44, 45]。 基本方法是直接在测试序列上微调语言模型,这通常以prompt的形式出现
接下来,我们详细讨论两个相关的工作方向:测试时训练和快速权重
测试时训练 (TTT) 的核心思想是每个测试实例定义其自身的学习问题,其中该测试实例本身是泛化的目标 [69]
具体来说
由于测试实例没有标签,学习问题只能通过自监督任务来制定。 先前的工作表明,使用重建的TTT显著提高了性能,尤其是在异常值上[23]
当在以流的形式到达的视频帧上进行测试并且TTT是自回归的时,改进变得更加显著[76],因为是在过去的帧上训练的。 自回归连接使得[76]与TTT的论文最相关
从概念上讲,TTT与先前工作的最大区别在于TTT的重建任务是在外循环中学习的,而不是通过人为先验手工制作的。 TTT的后续工作探索了诸如机器人操作[28]和运动[68]等应用,这些应用通常需要对自监督任务进行不同的设计
快速权重的一般思想是仅在最相关的数据上更新“快速”模型的参数,而不是像传统做法那样在所有数据上更新“慢速”模型 [71]
这个想法自20世纪80年代以来就存在了 [32]。 最相关的数据可以是测试实例本身,因此TTT可以被视为快速权重的一种特殊情况
先前关于快速权重的工作通常避免形成一个明确的学习问题来优化数据上的某些目标。 例如,Hebbian学习和Hopfield网络的更新规则 [35] 只是简单地将 xxT(或其某些变体) [4] 添加到每个输入 x的快速权重中。 相比之下,TTT接受了明确制定学习问题的理念,其中测试实例是泛化的目标。 TTT的更新规则也是一个明确的优化步骤
快速权重程序员(FWPs)的想法是用一个“慢”模型来更新快速权重 [62]
几十年来,研究人员一直在争论,学习如何学习,也称为元学习或双层优化,应该是智能的关键组成部分 [61, 6, 70, 47]。 在之前的工作中,例如 [2]、[20] 和 [52],内循环每次从整个数据集而不是序列中学习,因此外循环需要一组数据集或任务。 简而言之,外循环是“比常规训练高一级”。由于很难收集数百万个数据集,这个外循环很难扩展
相比之下,对于TTT,每个序列本身就是一个数据集,并定义了自己的泛化问题。内循环比常规训练“低一级”,因此TTT的外循环只是监督学习经典问题的另一种解决方案,而不是像跨数据集泛化那样的新问题设置
如下表所示,TTT的外循环与常规训练“处于同一级别”,这使得TTT的外循环更容易扩展
总之,如上表所示,TTT的论文将监督学习重新表述为学习如何学习,具有两个嵌套循环
外循环的高亮行与常规训练中的相同,外循环的参数成为内循环的超参数。 直观地说,内循环,即TTT,是常规训练的“下一级”
最后,如TTT论文所说
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。