FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum article about our submission to the MLPerf 2.0 benchmark using FlashAttention.
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Tri Dao
Paper: https://tridao.me/publications/flash2/flash2.pdf
FlashAttention-2 hardware support:
Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100).
Turing GPUs can only use FlashAttention 1.x.
Datatypes: fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
All head dimensions up to 256; the backward pass for head dim > 192 requires A100/A800 or H100/H800.
CUDA 11.6 and above
PyTorch 1.12 and above
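The support rules above can be encoded as a quick check before trying the FlashAttention path. This is a minimal sketch: `flash_attn2_supported` is a hypothetical helper (not part of flash-attn), the mapping from GPU families to CUDA compute capabilities (Ampere 8.0/8.6, Ada 8.9, Hopper 9.0; A100/A800 = 8.0, H100/H800 = 9.0) is an added assumption, and it covers only the rules listed here.

```python
# Hypothetical helper: encodes only the FlashAttention-2 support list above.
# Assumed compute capabilities: Ampere 8.0/8.6, Ada 8.9, Hopper 9.0;
# A100/A800 are 8.0 and H100/H800 are 9.0.


def flash_attn2_supported(compute_capability, dtype, head_dim, backward=False):
    """Return True if (GPU, dtype, head_dim) falls inside the listed support."""
    major, minor = compute_capability
    if major not in (8, 9):  # Ampere/Ada (8.x) or Hopper (9.x); Turing (7.5) -> FlashAttention 1.x
        return False
    if dtype not in ("fp16", "bf16"):  # only half-precision inputs are listed
        return False
    if head_dim > 256:  # head dimensions up to 256
        return False
    if backward and head_dim > 192 and (major, minor) not in ((8, 0), (9, 0)):
        return False  # backward with head dim > 192: A100/A800 or H100/H800 only
    return True


for cc, dtype, hd, bwd in [((8, 6), "fp16", 128, False), ((7, 5), "fp16", 128, False),
                           ((8, 6), "bf16", 256, True), ((9, 0), "bf16", 256, True)]:
    print(cc, dtype, hd, "backward" if bwd else "forward", flash_attn2_supported(cc, dtype, hd, bwd))
```

On a live machine, the first argument would come from `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple for the current device.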
Test environment:
transformers 4.33.1
torch 2.0.1+cu118
torchaudio 2.0.2+cu118
torchvision 0.15.2+cu118
accelerate 0.22.0
sentencepiece 0.1.99
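The torch build above can be checked against the "CUDA 11.6 and above" and "PyTorch 1.12 and above" minimums. This sketch assumes pip-style torch version strings, where a local tag such as `+cu118` encodes the CUDA toolkit (11.8); `parse_torch_version` and `meets_requirements` are hypothetical helper names.

```python
import re

MIN_TORCH = (1, 12)  # "PyTorch 1.12 and above"
MIN_CUDA = (11, 6)   # "CUDA 11.6 and above"


def parse_torch_version(version):
    """Split a pip torch version such as "2.0.1+cu118" into ((2, 0, 1), (11, 8)).

    The CUDA part is None when there is no "+cuXYZ" tag (e.g. "+cpu" builds)."""
    base, _, local = version.partition("+")
    release = tuple(int(part) for part in re.findall(r"\d+", base)[:3])
    cuda = None
    match = re.fullmatch(r"cu(\d+)", local)
    if match:
        digits = match.group(1)  # "118" -> major "11", minor "8"
        cuda = (int(digits[:-1]), int(digits[-1]))
    return release, cuda


def meets_requirements(version):
    """True if both the torch release and its CUDA tag meet the minimums above."""
    release, cuda = parse_torch_version(version)
    return release[:2] >= MIN_TORCH and cuda is not None and cuda >= MIN_CUDA


for v in ["2.0.1+cu118", "1.13.1+cu117", "1.11.0+cu113", "2.0.1+cpu"]:
    print(v, parse_torch_version(v), meets_requirements(v))
```

In a running environment the same inputs are available as `torch.__version__` and `torch.version.cuda`. Note that the modified `CoreAttention.forward` additionally requires torch major version 2 or higher before taking either the FlashAttention or the SDPA branch, which `2.0.1+cu118` satisfies.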
Install flash-attention:
pip install flash-attn --no-build-isolation
or compile from source inside a clone of the flash-attention repository:
python setup.py install
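To confirm that either install route produced an importable package, a small check like the following can be run. It is a sketch: `flash_attn_version` and `causal_is_bottom_right` are hypothetical helper names, while `flash_attn.__version__` is the package's own version attribute. The 2.1 threshold reflects a change in flash-attn 2.1: when `seqlen_q != seqlen_k`, `causal=True` aligns the mask to the bottom-right corner instead of the top-left, which is what the KV-cache decoding step in the modified `CoreAttention` needs.

```python
def flash_attn_version():
    """Return flash_attn.__version__, or None if the package cannot be imported
    (not installed, or a build whose CUDA extension fails to load)."""
    try:
        import flash_attn
    except ImportError:
        return None
    return getattr(flash_attn, "__version__", None)


def causal_is_bottom_right(version):
    """flash-attn >= 2.1 aligns causal=True to the bottom-right corner when
    seqlen_q != seqlen_k (needed for decoding with a KV cache)."""
    major, minor = (int(x) for x in version.split(".")[:2])
    return (major, minor) >= (2, 1)


installed = flash_attn_version()
if installed is None:
    print("flash-attn is not importable")
else:
    print("flash-attn", installed, "bottom-right causal mask:", causal_is_bottom_right(installed))
```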
Load the model through the local modeling_chatglm.py rather than transformers' AutoModel, because CoreAttention in that modeling file has to be modified:
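from transformers import AutoTokenizer
from modeling_chatglm import ChatGLMModel, ChatGLMForConditionalGeneration
model_path = "/path/to/chatglm2-6b"  # hypothetical placeholder: local checkpoint dir that also holds modeling_chatglm.py
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# flash_attn_func only accepts fp16/bf16 tensors, so cast the weights to fp16 before moving them to the GPU
model = ChatGLMForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()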
...
FLASH_ATTN_FLAG = True
print("inference by Flash attention src:", FLASH_ATTN_FLAG)
...
class CoreAttention(torch.nn.Module):
    ...
    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            if FLASH_ATTN_FLAG:
                from flash_attn import flash_attn_func
                # [sq, b, np, hn] -> [b, sq, np, hn], the (batch, seqlen, nheads, headdim) layout flash_attn_func expects
                query_layer, key_layer, value_layer = [k.permute(1, 0, 2, 3) for k in [query_layer, key_layer, value_layer]]
                dropout_p = 0.0  # no attention dropout at inference time
                # softmax_scale is left at its default, 1/sqrt(head_dim). attention_mask is not used here,
                # so this branch assumes unpadded inputs (e.g. batch size 1).
                # With a KV cache seqlen_q != seqlen_k; causal=True is only correct there with
                # flash-attn >= 2.1, which aligns the causal mask to the bottom-right corner.
                context_layer = flash_attn_func(query_layer, key_layer, value_layer, dropout_p, causal=True)
                # [b, sq, np, hn] -> [sq, b, np, hn]
                context_layer = context_layer.permute(1, 0, 2, 3)
            else:
                # chatglm2-6b official code
                query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
                if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=True)
                else:
                    if attention_mask is not None:
                        attention_mask = ~attention_mask
                    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask)
                context_layer = context_layer.permute(2, 0, 1, 3)
            # [sq, b, np, hn] -> [sq, b, hidden_size_per_partition]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        ...
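The two branches differ only in tensor layout and in how the causal mask is applied. The following CPU sketch checks that reading under stated assumptions: `reference_flash_attn` is a hypothetical naive stand-in that mimics `flash_attn_func`'s (batch, seqlen, nheads, headdim) layout and bottom-right causal alignment (the real kernel needs a CUDA GPU and fp16/bf16), and `sdpa_path` / `flash_path` are hypothetical wrappers mirroring the permutes and final reshape in `CoreAttention.forward`, with `hidden_size_per_partition` taken as nheads × headdim.

```python
import math

import torch
import torch.nn.functional as F


def reference_flash_attn(q, k, v, causal=True):
    """Naive attention in (batch, seqlen, nheads, headdim) layout.
    Bottom-right causal mask: query i sees key j iff j <= i + (seqlen_k - seqlen_q)."""
    seqlen_q, seqlen_k, head_dim = q.shape[1], k.shape[1], q.shape[-1]
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
    if causal:
        i = torch.arange(seqlen_q).unsqueeze(1)
        j = torch.arange(seqlen_k).unsqueeze(0)
        scores = scores.masked_fill(j > i + (seqlen_k - seqlen_q), float("-inf"))
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)


def sdpa_path(q, k, v):
    """Official branch: [sq, b, np, hn] -> [b, np, sq, hn] -> SDPA -> [sq, b, np*hn]."""
    q, k, v = [t.permute(1, 2, 0, 3) for t in (q, k, v)]
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True).permute(2, 0, 1, 3)
    return out.reshape(*out.shape[:-2], -1)


def flash_path(q, k, v):
    """Modified branch: [sq, b, np, hn] -> [b, sq, np, hn] -> attention -> [sq, b, np*hn]."""
    q, k, v = [t.permute(1, 0, 2, 3) for t in (q, k, v)]
    out = reference_flash_attn(q, k, v, causal=True).permute(1, 0, 2, 3)
    return out.reshape(*out.shape[:-2], -1)


torch.manual_seed(0)
sq, b, n_heads, head_dim = 6, 2, 4, 8
q, k, v = (torch.randn(sq, b, n_heads, head_dim) for _ in range(3))
full = flash_path(q, k, v)
print("same output as SDPA branch:", torch.allclose(sdpa_path(q, k, v), full, atol=1e-5))

# One decoding step: only the newest query, attending to every cached key.
last = flash_path(q[-1:], k, v)
print("decoding step matches last row:", torch.allclose(last, full[-1:], atol=1e-5))
```

The second check is the KV-cache case. PyTorch's `scaled_dot_product_attention` aligns `is_causal` to the top-left corner for non-square masks, which would let a single new query see only the first key; that is why the official branch uses `is_causal=True` only when query and key lengths match and otherwise falls back to `attention_mask`.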
Inference comparison on ChatGLM2-6B:
Method | Input tokens | Speed (higher is better) | GPU memory (MB) |
pytorch | 1800 | 33.8 | 15472 |
pytorch2.0 | 1800 | 36.5 | 14200 |
flash attention2 | 1800 | 36.7 | 14200 |
pytorch | 7000 | 18 | 37322 |
pytorch2.0 | 7000 | 29.9 | 17030 |
flash attention2 | 7000 | 34.2 | 17102 |
pytorch | 20000 | OOM | OOM |
pytorch2.0 | 20000 | 13.5 | 24122 |
flash attention2 | 20000 | 18.6 | 24194 |
pytorch | 32396 | OOM | OOM |
pytorch2.0 | 32396 | 8 | 30448 |
flash attention2 | 32396 | 14.1 | 30520 |
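The PyTorch 2.0 and FlashAttention-2 rows can be reduced to ratios. A small sketch, using only the numbers in the table above, that computes FlashAttention-2's speedup and extra memory relative to the PyTorch 2.0 path; `speedups` and `memory_deltas` are hypothetical helper names.

```python
# Numbers copied from the table: input tokens -> (PyTorch 2.0, FlashAttention-2)
SPEED = {1800: (36.5, 36.7), 7000: (29.9, 34.2), 20000: (13.5, 18.6), 32396: (8.0, 14.1)}
MEMORY_MB = {1800: (14200, 14200), 7000: (17030, 17102), 20000: (24122, 24194), 32396: (30448, 30520)}


def speedups():
    """FlashAttention-2 speed divided by PyTorch 2.0 speed, per input length."""
    return {tokens: flash / sdpa for tokens, (sdpa, flash) in SPEED.items()}


def memory_deltas():
    """Extra GPU memory (MB) used by FlashAttention-2 relative to PyTorch 2.0."""
    return {tokens: flash - sdpa for tokens, (sdpa, flash) in MEMORY_MB.items()}


deltas = memory_deltas()
for tokens, ratio in speedups().items():
    print(f"{tokens:>6} tokens: {ratio:.2f}x speed, {deltas[tokens]:+d} MB memory")
```

On these numbers the two paths are essentially tied at 1,800 tokens, and the FlashAttention-2 advantage grows with input length to about 1.76x at 32,396 tokens, while memory stays within 72 MB of the PyTorch 2.0 path.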