赞
踩
--fp16
, 对应megatron/arguments.py
中的定义如下: group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
计算lm cross entropy时默认使用fp32,在开启--fp16
选项的前提下,可以通过指定--fp16-lm-cross-entropy
来使用fp16计算lm cross entropy
,对应megatron/arguments.py
中的定义如下: group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
args.fp32_residual_connection
:设置该选项后,会在计算残差连接时先转为fp32再进行计算,这里的残差连接在网络中对应Embedding模块。该选项只能在开启fp16或bf16时使用,对应的参数检查如下:
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
validate_args
函数用于检查参数的有效性,其中与fp16相关的实现如下:
def validate_args(args, defaults={}):
......
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
......
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
......
如果指定了fp16,这里的args.fp16
为True,对应的args.params_dtype
参数类型会从默认的torch.float改为torch.half
(同时会断言没有同时开启bf16)。
ParallelAttention中有self.query_key_value
、self.core_attention
和self.dense
等子模块,fp16对训练的影响会应用在子模块中。
class ParallelAttention(MegatronModule): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] and returns output of the same size. """ def __init__(self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding): ... self.query_key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 3 * projection_size, bias=args.add_bias_linear, gather_output=False, init_method=init_method, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, **_args_to_kwargs()) ... self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type) ... self.dense = tensor_parallel.RowParallelLinear( projection_size, args.hidden_size, bias=args.add_bias_linear, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True, **_args_to_kwargs())
对于self.query_key_value
和self.dense
模块,fp16的设置是通过参数中的**_args_to_kwargs()
进行传递的(其中的params_dtype即为args.params_dtype)。
def _args_to_kwargs():
    """Build the keyword arguments shared by the tensor-parallel linear layers.

    The returned dict is splatted as ``**_args_to_kwargs()`` into
    ``ColumnParallelLinear`` / ``RowParallelLinear``.  Its ``params_dtype``
    entry is the path by which ``--fp16`` (``args.params_dtype == torch.half``)
    reaches those layers; e.g. ``ColumnParallelLinear`` allocates its weight
    and bias with this dtype.

    Returns:
        dict: ``params_dtype``, ``use_cpu_initialization``,
        ``perform_initialization``, ``gradient_accumulation_fusion`` and
        ``sequence_parallel_enabled``, all read from the global args.
    """
    args = get_args()
    common_kwargs = {
        "params_dtype": args.params_dtype,
        "use_cpu_initialization": args.use_cpu_initialization,
        "perform_initialization": args.perform_initialization,
        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
        # The layer keyword is named differently from the args attribute.
        "sequence_parallel_enabled": args.sequence_parallel,
    }
    return common_kwargs
对于self.core_attention
部分,fp16的设置是在CoreAttention
的__init__
中self.fp16 = args.fp16
。
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
...
在ParallelAttention
模块本身中,fp16会影响推理部分(用于推理的key/value buffer的分配):
class ParallelAttention(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding): ... self.params_dtype = args.params_dtype ... def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) def forward(self, hidden_states, attention_mask, encoder_output=None, inference_params=None, rotary_pos_emb=None): ... if inference_params: if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_len inf_max_batch_size = inference_params.max_batch_size inference_key_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) inference_value_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) ...
ParallelAttention
模块在__init__
初始化时会设置参数类型self.params_dtype
(开启--fp16时为torch.half);_allocate_memory
中会用torch.empty
创建用于推理的大buffer,其类型即为self.params_dtype(fp16);当传入inference_params
时,forward函数中会调用_allocate_memory,分别为key和value分配这样的buffer。
当设置了fp16以后,CoreAttention
的forward计算的输入就是fp16类型;在__init__中设置fp16 flag(与bf16 flag一起作为前两个参数传给FusedScaleMaskSoftmax),主要用于计算中用到的FusedScaleMaskSoftmax
模块的输出结果类型转换。
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
...
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
...
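当FusedScaleMaskSoftmax
执行时,如果融合kernel支持当前的输入规模(is_kernel_available返回True),会直接调用融合算子forward_fused_softmax
;否则会退回到forward_torch_softmax
,用PyTorch原生算子计算。在这条路径上,如果输入是fp16/bf16(self.input_in_float16)且开启了softmax_in_fp32,会先把输入转为fp32计算softmax,再根据self.input_in_fp16
把输出转回fp16或bf16。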
class FusedScaleMaskSoftmax(nn.Module): ... def forward(self, input, mask): # [b, np, sq, sk] assert input.dim() == 4 if self.is_kernel_available(mask, *input.size()): return self.forward_fused_softmax(input, mask) else: return self.forward_torch_softmax(input, mask) def forward_torch_softmax(self, input, mask): if self.input_in_float16 and self.softmax_in_fp32: input = input.float() if self.scale is not None: input = input * self.scale mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: probs = probs.half() else: probs = probs.bfloat16() return probs
在ColumnParallelLinear
初始化时,创建weight和bias这两个Parameter时直接按params_dtype(即fp16)
来设置类型。
class ColumnParallelLinear(torch.nn.Module):
def __init__(self, ...,
params_dtype=torch.float32,
...,
):
...
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
...
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
...
以gpt2模型为例,在megatron/model/gpt_model.py
文件中的post_language_model_processing
函数中:如果没有指定fp16_lm_cross_entropy
,那么在计算cross entropy
时会把output
先转为float32
再计算loss;如果指定了,则会断言output的类型为torch.half,并直接用fp16计算loss。
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。