对于视觉 Transformer,将其 Self-Attention 中的 Softmax 操作替换为 ReLU/序列长度 (seqlen) 之后,性能的下降问题有所缓解。本文在 ImageNet-21K 上训练了从 Small 级别到 Large 级别的视觉 Transformer,证明了 ReLU-attention 可以在缩放性上接近或者匹配 Softmax-attention 的性能。Google出品,使用ReLU取代Softmax,ViT性能不退化
本文的研究结论是:对于视觉 Transformer,将其 Self-Attention 中的 Softmax 操作替换为 ReLU/序列长度 (seqlen) 之后,性能的下降问题有所缓解。本文在 ImageNet-21K 上训练了从 Small 级别到 Large 级别的视觉 Transformer,证明了 ReLU-attention 可以在缩放性上接近或者匹配 Softmax-attention 的性能。
1 在 ViT 中使用 ReLU 取代 Softmax
论文名称: Replacing softmax with ReLU in Vision Transformers (Arxiv 2023)
论文地址:https//arxiv.org/pdf/2309.08586.pdf
1.1 ReLU-attention 的新发现
Transformer 架构[1]在现代机器学习中无处不在。Attention 是 Transformer 的核心组件,包括一个 Softmax 操作,它在 token 上产生概率分布。Softmax 操作涉及到内部的计算所有输入的指数之和,它的计算代价相当昂贵,使得 Transformer 架构的并行化具有挑战性[2]。
本文作者探索了 Softmax 操作的 Point-wise 的替代方案,该操作不一定输出概率分布。本文的核心贡献是观察到:ReLU/序列长度(seqlen) ,可以在缩放性方面接近或匹配传统的 Softmax 操作。这一结果为并行化提供了新的机会,因为 ReLU-attention 相比传统的 Softmax-attention 可以使用更少的 gather 操作在序列长度维度实现并行化。
1.2 去掉 Softmax 的相关工作
替换 Softmax 的研究:
- ReLU 和 squared ReLU:[3][4]把 Softmax 替换成了 ReLU,[5]把 Softmax 替换成了 squared ReLU。但是这些方法不会除以序列长度,本文通过实验发现对于达到与 Softmax 相当的准确度很重要。
- [6]仍然需要对序列长度轴进行归一化,以确保注意力权重之和为1,这依然需要 gather。
去掉激活函数的研究:
1.3 ReLU-attention 方法
在进行 Self-attention 的操作时,首先计算注意力权重:
图1:Scaled point-wise attention 实验结果
Sequence length scaling
1.4 实验结果
作者在 ImageNet-21K 上训练了 30 Epochs,在 ImageNet-1K 上训练了 300 Epochs。作者使用了 ViT-22B[10]中提出的 qk-norm 技术,因为这个技术被验证在扩大视觉模型时有益于优化稳定性,但是作者发现在本文量级的模型这一技术没那么重要。
如下图2所示说明了 ReLU-attention 与 ImageNet-21K 训练的 Softmax-attention 的缩放趋势相匹配。x 轴表示实验所需的总 core hours。ReLU-attention 的优势是能够以比 Softmax-attention 以更少的 gather 操作对序列长度维度进行并行化。
图2:Softmax 操作替换为 ReLU/seqlen 的缩放性能与传统带有 qk-layernorm 的 Transformer 的缩放性能匹配
1.5 qk-norm 实验结果
本文主要实验使用了 qk-norm,其中 query 和 key 在计算注意力权重之前通过 LayerNorm 传递,作者发现有必要在扩大模型大小时防止不稳定性。如图3所示是 qk-layernorm 的实验结果。结果表明,qk-norm 对这些模型没有很大的影响。
图3:qk-norm 实验结果
1.6 添加 gate 的影响
[11]这个工作删除了 Softmax 之后,添加了一个门控单元,并且不按序列长度缩放。具体而言,在门控注意力单元中,通过额外的投影层产生输出,该输出在输出映射之前与注意力的结果做 Element-wise 的乘法。
如图4所示是添加 gate 的影响实验结果。作者研究了 gate 的存在是否消除了序列长度缩放的需要。总体而言,作者观察到无论有没有 gate 的存在,使用序列长度缩放都实现了最佳精度。注意到对于带有 ReLU 的 S/8 模型,添加 gate 操作将实验所需的 core hour 增加了大约 9.3%。
图4:添加 gate 的影响