After the PyTorch 2.0 release, model training and inference could supposedly be accelerated with a single line of code, but when I first tried it the gains were not obvious. Then the gpt-fast project came out and showed that real inference speedups are achievable, so apparently I had just been using it the wrong way. Recently, using gpt-fast as a reference, I implemented accelerated inference for the ChatGLM model. The basic idea is to let torch.compile build a computation graph for the inference path. The focus of this post is to show how the model code and inference logic were ported, and to compare the resulting speedups. This approach is of course still behind vLLM and TensorRT-LLM, but that is not the point here; when I find time I will write another post comparing their efficiency.
This kind of work is genuinely not easy. You need to be very familiar with the model structure and its inputs/outputs, and also with the migration principles used in gpt-fast, to port a model quickly. The core rules are: no tensor slicing on dynamically-shaped tensors, and the KV cache must be a fixed-length buffer that is filled and updated in place during decoding, created outside the model and passed in as an argument — otherwise compilation brings no speedup. Another point is the attention implementation: with torch's scaled_dot_product_attention, computing attention over a fixed, maximum-length (padded and masked) key/value buffer gives the same values as the usual approach of growing the key/value length step by step, which is the precondition for rewriting away the tensor slicing in the attention computation (I verified that the results are indeed identical). Details to watch out for: the dimension ordering of the KV cache, and the fact that the full_attention_mask fed to the model differs between the first forward pass and the later passes once the KV cache is populated.
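As a quick sanity check of that last claim, here is a minimal sketch (my own, not from gpt-fast or the ChatGLM code; all shapes are made up) comparing scaled_dot_product_attention over only the keys/values seen so far against the same call over a fixed maximum-length, zero-padded key/value buffer plus a boolean mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, D, cur_len, max_len = 1, 2, 16, 5, 8   # hypothetical sizes, [B, H, S, D] layout

q = torch.randn(B, H, 1, D)                  # one decoding step
k = torch.randn(B, H, cur_len, D)            # "dynamic" keys seen so far
v = torch.randn(B, H, cur_len, D)

# Dynamic version: attend only over the tokens generated so far.
out_dyn = F.scaled_dot_product_attention(q, k, v)

# Fixed-length version: pad K/V to max_len and mask out the unused tail.
k_fix = torch.zeros(B, H, max_len, D); k_fix[:, :, :cur_len] = k
v_fix = torch.zeros(B, H, max_len, D); v_fix[:, :, :cur_len] = v
mask = torch.zeros(B, 1, 1, max_len, dtype=torch.bool)
mask[..., :cur_len] = True                   # True = attend, False = ignore
out_fix = F.scaled_dot_product_attention(q, k_fix, v_fix, attn_mask=mask)

print(torch.allclose(out_dyn, out_fix, atol=1e-6))  # expected: True
```

Note that with a boolean attn_mask, SDPA treats True as "attend"; this is why the model code below flips its mask (`~attention_mask`) before calling it, since the mask built in the generate loop uses True for "masked out".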
Overall structure
```python
class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = 128
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:
        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents
```
Note the new forward arguments: input_pos, the positions at which the current tokens are written into the KV cache, and kv_caches. The basic building blocks of the model are unchanged; I only trimmed some of the preprocessing logic and branching so that torch.compile can build the computation graph.
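For orientation, this is roughly how those per-layer caches are created on the caller side; a minimal sketch whose shapes mirror the generate_own_batch function shown further below (the max_seq_length, max_batch_size and 28-layer values are assumptions for ChatGLM2-6B, not taken verbatim from the post):

```python
import torch

# One fixed-size (k, v) pair per layer, created outside the model and passed into
# forward(), so the graph captured by torch.compile never sees tensors whose shapes
# change between decoding steps.
max_seq_length, max_batch_size, num_layers = 1024, 1, 28   # assumed values
cache_shape = (max_seq_length, max_batch_size, 2, 128)     # [S, B, multi-query groups, head dim]
kv_caches = [
    (torch.zeros(cache_shape, dtype=torch.bfloat16, device="cuda"),
     torch.zeros(cache_shape, dtype=torch.bfloat16, device="cuda"))
    for _ in range(num_layers)
]
```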
KVCache module
```python
class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: S, k_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        k_val = k_val.transpose(0, 2).contiguous()
        v_val = v_val.transpose(0, 2).contiguous()
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        k_out = k_out.transpose(0, 2).contiguous()
        v_out = v_out.transpose(0, 2).contiguous()
        return k_out, v_out
```
The tensor shapes are annotated in the comments. The module simply holds the fixed-size key and value buffers and provides an update method that writes the new keys/values at input_pos and returns the full-length caches.
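A small usage sketch of the class above (hypothetical shapes chosen to match the [S, B, H, D] convention in the comments; float32 instead of bfloat16 only to keep the example CPU-friendly):

```python
import torch

# Prefill-style update: write 4 positions at once, get back the full-length caches.
max_batch, max_seq = 2, 16
cache = KVCache(max_batch_size=max_batch, max_seq_length=max_seq, dtype=torch.float32)

S, B, H, D = 4, max_batch, 2, 128          # H=2 multi-query groups, D=128 head dim
input_pos = torch.arange(S)                # cache slots 0..3
k_val = torch.randn(S, B, H, D)
v_val = torch.randn(S, B, H, D)

k_full, v_full = cache.update(input_pos, k_val, v_val)
print(k_full.shape)                        # torch.Size([16, 2, 2, 128]) -> [max_seq, B, H, D]
```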
The other modules are not described one by one here; pay particular attention to how the KV cache is updated inside SelfAttention.
The full model code is as follows:
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from math import gcd
from functools import reduce
import math


def find_multiple(n: int, *args: int) -> int:
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    if n % k == 0:
        return n
    return n + k - (n % k)


class CoreAttention(torch.nn.Module):
    def __init__(self, config, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        coeff = self.layer_number
        self.norm_factor *= coeff
        self.coeff = coeff
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
        # [sq, b, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        if attention_mask is None:
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             is_causal=True)
        else:
            # the incoming mask uses True for "masked out"; SDPA expects True for "attend"
            attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             attention_mask)

        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
        return context_layer


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: S, k_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        k_val = k_val.transpose(0, 2).contiguous()
        v_val = v_val.transpose(0, 2).contiguous()
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        k_out = k_out.transpose(0, 2).contiguous()
        v_out = v_out.transpose(0, 2).contiguous()
        return k_out, v_out


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(input_dtype)


class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = 128
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:
        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents


class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.config = config
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        # 32
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size

        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        self.qkv_hidden_size = (
            self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         # device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device,
                               # **_config_to_kwargs(config)
                               )

    def forward(
            self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )

        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1]
            + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # update the KV cache in place at input_pos, then read back the full-length cache
        cache_k, cache_v = kv_cache
        cache_k[input_pos] = key_layer
        cache_v[input_pos] = value_layer
        key_layer = cache_k.clone()
        value_layer = cache_v.clone()
        kv_cache = (key_layer, value_layer)

        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask=attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache


class TransformerBlock(nn.Module):
    def __init__(self, config, layer_number, device) -> None:
        super().__init__()
        self.hidden_dropout = config.hidden_dropout
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                dtype=config.torch_dtype)
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None):
        layernorm_output = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            rotary_pos_emb,
            input_pos,
            attention_mask=attention_mask,
            kv_cache=kv_cache
        )
        residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        mlp_output = self.mlp(layernorm_output)
        residual = layernorm_input
        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output
        return output, kv_cache


if __name__ == '__main__':
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = "1"
    from transformers import AutoConfig

    model_path = "./chatglm2-6b-merge"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = TransformerGLM(config, device=None)
    for name, _ in model.named_parameters():
        print(name)
```
For inference, the generate method of the transformers model is rewritten. A single generation splits into the first decoding forward pass and the remaining decoding forward passes. The two are implemented separately below; only the greedy search strategy is implemented:
```python
@torch.no_grad()
def first_decode_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids=input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=False,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches


@torch.no_grad()
def decode_one_token_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=True,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches
```
Both functions return the next token produced in that decoding step together with the updated kv_caches. One important caveat: do not wrap these two functions into a class and then apply torch.compile to the methods; the model still produces correct outputs that way, but inference speed does not improve, i.e. torch.compile effectively never kicks in.
The overall generate logic covers the stop token, the model's initial inputs, kv_caches initialization, how the attention_mask and position_ids inputs change between the two stages, and the padding added for batched inference.
```python
def generate_own_batch(model,
                       inputs,
                       sampling_kwargs,
                       eos_token,
                       max_seq_length, max_batch_size):
    device = inputs['input_ids'].device
    cache_shape = (max_seq_length, max_batch_size, 2, 128)
    dtype = torch.bfloat16
    kv_caches = [(torch.zeros(cache_shape, dtype=dtype).to(device), torch.zeros(cache_shape, dtype=dtype).to(device))
                 for _ in range(model.config.num_layers)]

    input_ids = inputs['input_ids']
    ori_input_ids = input_ids.clone()
    position_ids = inputs['position_ids']

    input_pos = []
    for _ in range(max_batch_size):
        pos = list(range(0, input_ids.shape[1]))
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=input_ids.device)

    # input_pos = torch.arange(0, input_ids.shape[1], device=input_ids.device)
    next_token, kv_caches = first_decode_batch(model, input_ids, position_ids, input_pos, None, kv_caches)

    full_attention_mask = torch.ones(max_batch_size, 1, 1, max_seq_length).to(device).bool()
    full_attention_mask[:, :, :, input_pos] = False

    # padding positions stay True (i.e. masked out)
    for i in range(full_attention_mask.shape[0]):
        for j in range(input_ids.shape[1]):
            if input_ids[i, j] == 0:
                full_attention_mask[i, :, :, j] = True

    input_ids = torch.cat((input_ids, next_token.clone()), dim=1)
    num_new_tokens = sampling_kwargs["max_length"]
    T = input_ids.size()[1]

    position_ids = position_ids[:, -1:]

    input_pos = []
    for _ in range(max_batch_size):
        pos = [T]
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=next_token.device, dtype=torch.long)

    # position_ids = torch.tensor([[T - 1]], device=next_token.device, dtype=torch.long)
    # input_pos = torch.tensor([T], device=input_ids.device, dtype=torch.long)

    for i in range(num_new_tokens):
        input_pos += 1
        # Actually better for Inductor to codegen attention here
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            full_attention_mask[:, :, :, input_pos] = False
            next_token, kv_caches = decode_one_token_batch(model, next_token, position_ids, input_pos,
                                                           full_attention_mask, kv_caches)

        input_ids = torch.cat((input_ids, next_token.clone()), dim=1)

        if (input_ids == eos_token).sum(dim=1).all():
            break

        position_ids += 1
        # token = next_token.tolist()
        # token = next_token.tolist()[0]
        # generated_tokens.append(token)
    return input_ids, ori_input_ids
```
Core inference logic:
```python
model = TransformerGLM(config=config, device=None)
checkpoint_dir = Path(model_path)
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
converted_state_dict = model.state_dict()

gen_kwargs = {"max_length": 200, "num_beams": 1,
              "do_sample": False, "top_p": 0.8,
              "temperature": 0.95
              }
device = "cuda:0"
model_path = "./chatglm2-6b-merge"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
eos_token = tokenizer.eos_token_id
# ......
# compile for speed
global first_decode_batch, decode_one_token_batch
decode_one_token_batch = torch.compile(decode_one_token_batch, mode="reduce-overhead", fullgraph=True)
first_decode_batch = torch.compile(first_decode_batch, dynamic=True, fullgraph=True)

# ......
generate_own_batch(model, inputs, gen_kwargs, eos_token, max_seq_length, max_batch_size)
```
The key point is simply to wrap the two decode-stage functions with torch.compile; during actual decoding the compiled versions are then used and provide the speedup.
Below is a comparison of inference speed and output quality for the 6B GLM model in three settings: ori (original), compile, and compile+int8, with bs=1 and max_seq_length=1000. The input prompts are:
[ "你好", "你是谁呀?", "你能做什么呀?", "你真厉害", "真棒呀", "再见了", "给我推荐一部电影", "你知道明天天气怎么样吗?" ]
Results with the original transformers inference (ori):
Results with compile:
Results with compile+int8:
With the same model and the same data at bs=1, the original model runs at 31.7 tokens/s, the compiled model at 68.1 tokens/s, and compile+int8 at 110.9 tokens/s, so the speedup is quite significant.
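For reproducing throughput numbers like these, something along the lines of the following helper works; this is a hypothetical sketch rather than the script actually used, and the warm-up iterations matter so that torch.compile's compilation time is not counted:

```python
import time
import torch

def measure_tokens_per_second(step_fn, n_warmup=5, n_iters=50):
    # step_fn() should run one full decoding step (e.g. one call to decode_one_token_batch).
    for _ in range(n_warmup):          # warm-up triggers compilation / autotuning
        step_fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    torch.cuda.synchronize()           # ensure all GPU work finished before stopping the clock
    return n_iters / (time.perf_counter() - start)
```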
For experiments on our business domain I will only give the conclusion and not show the data: in that workload each inference generates mostly fewer than 20 tokens, and the results are as follows:
That is all for this post. One unresolved issue remains with this ported model and inference code on our server side: our web service is built on multi-process asynchronous workers, and when this gpt-fast style setup is integrated into the service, int8 does not take effect and the bs=1 speedup is noticeably smaller than what I measured offline. I never fully figured out why; it may be that other processes competing for server resources make the torch.compile acceleration ineffective or weaker.