赞
踩
MegatronLM
的第三篇论文【Reducing Activation Recomputation in Large Transformer Models】是2022年出的。在大模型训练过程中显存占用过大往往成为瓶颈,一般会通过recomputation重计算的方式降低显存占用,但会带来额外的计算代价。这篇论文提出了两种方法,分别是sequece parallel
和selective activation recomputation
,这两种方法和Tensor并行是可以相结合的,可以有效减少不必要的计算量。
下图中绿色部分表示不同模型中需要用于保存activation需要的显存大小,蓝色部分表示不同模型中需要用于保存parameter和optimizer state需要的显存大小。红色线表示A100的显存大小80G。
以Transformer结构为例估算Activation Memory
大小,这里的Activation
定义是指前向和反向梯度计算中创建的所有tensor。按这个定义来说,计算不包含模型参数大小和优化器中状态大小,但是包含dropout op用到的mask tensor。
一个Transformer块中由一个Attention块和一个MLP块组成,中间通过两个LayerNorm层进行连接。在Transformer中用到的参数表示如下:
Attention模块的计算公式如下:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
Attention(Q,K,V)=softmax(QKT√dk)V
对于Attention
块来说,输入的element个数为sbh
个,每个element以16-bit的浮点数(也就是2 bytes)来进行存储的话,对应输入的element大小为2sbh bytes
,后续计算默认都是按bytes
为单位进行计算。
Attention
块中包含一个self-attention
块、一个linear
线性映射层和attention dropout
层。对于linear
线性映射层来说需要保存输入的Activation
大小为2sbh
, 对于attention dropout
层需要mask的大小为sbh
(对于一个元素的mask只用1个bytes即可),对于self-attention块的Activation Memory
的计算有以下几块:
sbh
个,总大小是 2sbh
bytes。2sbh
bytes, 总共大小为 4sbh
bytes。如下图以b=1, s=2, h=6
为例,输入
X
X
X元素个数为1 * s * h = 12
个,计算完后
Q
Q
Q 和
K
K
K 的矩阵中元素个数各有 1 * s * h = 12
个,总元素大小为2 * 2 * b * s * h = 48
bytes。softmax的输出总的元素大小为
2
a
s
2
b
2as^2b
2as2b bytes, 分别计算每个Head头的
Q
n
×
K
n
Q_n \times K_n
Qn×Kn 的乘积。计算公式如下, 图中计算以b=1, s=2, h=6, a=2
为例:
在softmax后还有dropout的mask层大小,mask矩阵的大小与softmax的输出一样,元素个数都是 a s 2 b as^2b as2b 个,但mask单个元素的大小只用1 bytes即可,总的大小为 a s 2 b as^2b as2b bytes
softmax的输出也会用于反向的计算,需要缓存下来,对应大小也是 2 a s 2 b 2as^2b 2as2b
V
V
V 矩阵的大小之前没有统计,和
Q
Q
Q、
K
K
K矩阵一样,大小也是2sbh
bytes
综上,Attention Block总的大小为 11sbh + 5as^2b
bytes。
MLP的Activation大小计算:MLP中有两层线性layer,分别存储输入矩阵大小为
2
s
b
h
2sbh
2sbh bytes和
8
s
b
h
8sbh
8sbh bytes;GeLU的反向也需要对输入进行缓存,大小为
8
s
b
h
8sbh
8sbh bytes; dropout层需要 sbh
bytes; 总大小为 19sbh
。
LayerNorm的Activation大小计算:每个LayerNorm层的输入需要
2
s
b
h
2sbh
2sbh 大小,有两个LayerNorm层,总大小为 4sbh
bytes.
最终transformer网络中一层(含Attention/MLP/LayerNorm)的Activation总的大小为:
A
c
t
i
v
a
t
i
o
n
M
e
m
o
r
y
P
e
r
L
a
y
e
r
=
s
b
h
(
34
+
5
a
s
h
)
ActivationMemoryPerLayer=sbh(34+5ash)
注意: 这里公式(1)计算的Activation总和是在没有应用模型并行策略的前提下进行的。
如下图,在Tensor模型并行中只在Attention和MLP两个地方进行了并行计算,对于Attention(Q/K/V)和MLP(Linear Layer)的输入并没有并行操作。图中 f f f 和 f ‾ \overline{f} f 互为共轭(conjugate), f f f 在前向时不做操作,反向时执行all-reduce; f ‾ \overline{f} f 在前向时执行all-reduce, 反向时不做操作。
参虑上Tensor并行的话(Tensor并行度为
t
t
t),并行部分有MLP的Linear部分(
18
s
b
h
18sbh
18sbh bytes)和Attention的QKV部分(
6
s
b
h
6sbh
6sbh bytes), ActivationMemoryPerLayer
相比公式(1)中的值降为:
A
c
t
i
v
a
t
i
o
n
M
e
m
o
r
y
P
e
r
L
a
y
e
r
=
s
b
h
(
10
+
24
t
+
5
a
s
h
t
)
ActivationMemoryPerLayer=sbh(10+24t+5asht)
在Tensor模型并行基础上提出了Sequence Parallel
,对于非Tensor模型并行的部分在sequence维度都是相互独立的,所以可以在sequence维度上进行拆分(即sequence parallel
)。拆分后如下图,
f
f
f 和
f
‾
\overline{f}
f 替换为
g
g
g 和
g
‾
\overline{g}
g,
g
g
g 和
g
‾
\overline{g}
g 也是共轭的,
g
g
g 在前向是all-gather通信,反向是reduce-scatter通信;
g
‾
\overline{g}
g在前向是reduce-scatter, 反向是all-gather通信。
接下来以MLP为例,详细说明拆分步骤。MLP层由两个Linear层组成,对应的计算公式如下, 其中 X X X 的大小为 s × b × h s \times b \times h s×b×h ; A A A 和 B B B 是Linear的权重weight矩阵,大小为 h × 4 h h \times 4h h×4h 和 4 h × h 4h \times h 4h×h。
Y
=
L
a
y
e
r
N
o
r
m
(
X
)
Z
=
G
e
L
U
(
Y
A
)
W
=
Z
B
V
=
D
r
o
p
o
u
t
(
W
)
Y=LayerNorm(X)Z=GeLU(YA)W=ZBV=Dropout(W)
如下图,切分时说明如下:
对应的计算公式如下:
[
Y
1
s
,
Y
2
s
]
=
L
a
y
e
r
N
o
r
m
(
[
X
1
s
,
X
2
s
]
)
Y
=
g
(
Y
1
s
,
Y
2
s
)
[
Z
1
h
,
Z
2
h
]
=
[
G
e
L
U
(
Y
A
1
c
)
,
G
e
L
U
(
Y
A
2
c
)
]
W
1
=
Z
1
h
B
1
r
W
2
=
Z
2
h
B
2
r
[
W
1
s
,
W
2
s
]
=
g
‾
(
W
1
,
W
2
)
[
V
1
s
,
V
2
s
]
=
[
D
r
o
p
o
u
t
(
W
1
s
)
,
D
r
o
p
o
u
t
(
W
2
s
)
]
[Ys1,Ys2]=LayerNorm([Xs1,Xs2])Y=g(Ys1,Ys2)[Zh1,Zh2]=[GeLU(YAc1),GeLU(YAc2)]W1=Zh1Br1W2=Zh2Br2[Ws1,Ws2]=¯g(W1,W2)[Vs1,Vs2]=[Dropout(Ws1),Dropout(Ws2)]
Tensor并行在一次前向和后向总共有4次的all-reduce
操作,在Sequence并行一次前向和后向总共有4次all-gather
和4次reduce-scatter
操作。ring all-reduce
执行过程中有两步,先是一个reduce-scatter
然后跟着一个all-gather
,Sequence并行相比没有引入更多的通信代价。一个使用reduce-scatter
和all-gather
实现all-reduce
的Python代码示例如下:
import torch import torch.distributed as dist # 初始化进程组 dist.init_process_group(backend='gloo') # 获取进程组中的进程数和当前进程的排名 world_size = dist.get_world_size() rank = dist.get_rank() # 定义输入和输出张量 x = torch.tensor([1, 2, 3, 4]) result = torch.zeros_like(x) # 使用 reduce_scatter 将每个进程的输入张量的部分和归约到每个进程的输出张量 dist.reduce_scatter(input_list=[x], output=result) # 使用 all_gather 将每个进程的输出张量收集到所有进程中 output_list = [torch.zeros_like(result) for _ in range(world_size)] dist.all_gather(output_list, result) # 在每个进程上打印结果 print(f"Process {rank}: {output_list}") # 清理资源 dist.destroy_process_group()
通过使用sequence parallel
和tensor parallel
以后,ActivationMemoryPerLayer
相比公式(2)的值再次减少,相比公式(1)相当于对所有的ActivationMemory进行Tensor并行, 即
A
c
t
i
v
a
t
i
o
n
M
e
m
o
r
y
P
e
r
L
a
y
e
r
t
\frac{ActivationMemoryPerLayer}{t}
tActivationMemoryPerLayer:
A
c
t
i
v
a
t
i
o
n
M
e
m
o
r
y
P
e
r
L
a
y
e
r
=
s
b
h
(
10
t
+
24
t
+
5
a
s
h
t
)
=
s
b
h
t
(
34
+
5
a
s
h
)
ActivationMemoryPerLayer=sbh(10t+24t+5asht)=sbht(34+5ash)
加上Pipeline Parallel
后,对具有
L
L
L 层的layer的transformer来说,Pipeline Parallel
并行度为
p
p
p, 对应会分为
L
p
\frac{L}{p}
pL 组(即stage个数)。以PipeDream中的1F1B
调度为例,要完成初始化的话,第1个stage必须处理完
p
p
p 个micro-batch,让其他stage至少有1个micro-batch在处理,也就是要缓存
p
p
p 个micro-batch的activation。由于每个stage都有
L
p
\frac{L}{p}
pL 个Layer,一共需要
p
×
L
p
=
L
p \times \frac{L}{p} = L
p×pL=L 个layer的activation信息,对应总的计算如下:
T
o
t
a
l
A
c
t
i
v
a
t
i
o
n
M
e
m
o
r
y
=
s
b
h
L
t
(
34
+
5
a
s
h
)
TotalActivationMemory=sbhLt(34+5ash)
当然这里的公式(5)的ActivationMemory的计算没有加上EmbeddingLayer
、最后的LayerNorm
和输出的OutputLayer
。加上这三部分的结果会略大于公式(5), 但以22B参数模型来说只增加0.01%的大小,这部分可忽略,证明请参考原论文。未计算部分如下图红色部分:
在后向过程中通过重计算方式重新计算前向结果来节省显存大小,这种方式文中称为full activation recomputation
,以transformer为例会增加30%~40%的计算量。Selective
的方式主要思路是选择 FLOPs
计算量小,且activation占用大的算子进行重计算,这里的 FLOPs
的衡量标准是GEMM的计算量大小。以公式(5)为例,针对大模型来说
5
a
s
/
h
>
34
5as/h \gt 34
5as/h>34, 如果重计算这部分layer的话可以减少快一半的activation大小。对于GPT-3来说,这种方式可以减少70%的activation显存大小,同时只增加了2.7%的
F
L
O
P
s
FLOPs
FLOPs 计算量。采用Selective Activation Recomputation
后,公式(5)的结果可以减少为:
T
o
t
a
l
r
e
q
u
i
r
e
d
m
e
m
o
r
y
=
34
s
b
h
L
t
Total required memory=34sbhLt
以下是不同方法组合下Activation Memory占用情况:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。