当前位置:   article > 正文

Megatron-LM源码系列(五): FP16使用_coreattention

coreattention

1. FP16参数指定

  • 训练模型要使用fp16时,训练启动参数中指定--fp16, 对应megatron/arguments.py中的定义如下:
    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
  • 1
  • 2
  • 在计算lm-cross-entropy时默认是使用fp32来计算的,在开启--fp16选项的前提下可以通过指定--fp16-lm-cross-entropy来使用fp16计算lm-loss-entropy,对应megatron/arguments.py中的定义如下:
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation'
                       'for lm head to fp16.')
  • 1
  • 2
  • 3
  • 在megatron中跟fp16还有关系的一个参数是args.fp32_residual_connection,这里设置了的话会在计算残差连接的时候转为fp32再进行计算,这里残差连接在网络中对应是Embedding模块。
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
  • 1
  • 2
  • 3
  • validate_args函数用于check参数有效性,fp16相关实现如下:
def validate_args(args, defaults={}):
    ......
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    ......
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    ......
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

如果指定了fp16,这里的args.fp16为True,对应的args.params_dtype参数类型torch.half

2. ParallelAttention模块中fp16计算

2.1 训练部分

ParallelAttention中有self.query_key_valueself.core_attentionself.dense等子模块,fp16对训练的影响会应用在子模块中。

class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
            ...
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                bias=args.add_bias_linear,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())
        ...
        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        ...                                     
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

对于self.query_key_valueself.dense模块,fp16的设置能过参数中的**_args_to_kwargs()进行传递。

def _args_to_kwargs():
    args = get_args()
    common_kwargs = {
        "params_dtype": args.params_dtype,
        "use_cpu_initialization": args.use_cpu_initialization,
        "perform_initialization": args.perform_initialization,
        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
        "sequence_parallel_enabled": args.sequence_parallel,
    }
    return common_kwargs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

对于self.core_attention部分,fp16的设置是在CoreAttention__init__self.fp16 = args.fp16

class CoreAttention(MegatronModule):

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
    ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2.2 推理部分

ParallelAttention模块本身中fp16会影响推理部分

class ParallelAttention(MegatronModule):
    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.params_dtype = args.params_dtype
        ...

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device()) 

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        ...
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
        ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 当指定了fp16以后,在ParallelAttention模型__init__初始化时会设置参数类型self.params_dtype为fp16
  • 在提前分配memory时_allocate_memory中会用torch.empty创建用于推理的大buffer,类型是fp16
  • 在指定推理参数inference_params时,forward函数中会调用_allocate_memory

3. CoreAttention模块中fp16计算

当设了fp16以后,在CoreAttention的forward计算的input就是fp16类型,在init中设置fp16 flag主要是用于计算中用到的FusedScaleMaskSoftmax模块的输出结果类型转换。

class CoreAttention(MegatronModule):

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
    ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

FusedScaleMaskSoftmax执行时,kernel支持fp16时会直接调用fusion算子forward_fused_softmax;对于不支持的规模时,会调用forward_torch_softmax进行模拟,输出的类型就根据self.input_in_float16来进行cast转换。

class FusedScaleMaskSoftmax(nn.Module):
    ...
    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

4. ColumnParallelLinear模块中fp16计算

ColumnParallelLinear初始化时创建Parameter中的类型直接按params_dtype(即fp16)来设。

class ColumnParallelLinear(torch.nn.Module):
def __init__(self, ...,
                 params_dtype=torch.float32,
                 ...,
                 ):
            ...
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=params_dtype))
            ...
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype))
            ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

5. lm-cross-entropy计算

以gpt2模型为例,在megatron/model/gpt_model.py文件中的post_language_model_processing函数, 如果指定了fp16_lm_cross_entropy,那么在计算cross entropy时会把output先转为float32再进行计算loss。

        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
  • 1
  • 2
  • 3
  • 4
  • 5

参考

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号