当前位置:   article > 正文













  class TransformerGLM(nn.Module):
      """GLM decoder stack: embedding -> transformer blocks -> RMSNorm -> LM head.

      Hidden states travel through the blocks with the sequence axis first
      (``inputs_embeds.transpose(0, 1)``); the logits are transposed back
      before they are returned.
      """

      def __init__(self, config, device) -> None:
          """Create the sub-modules described by ``config``.

          Args:
              config: Model configuration. This constructor reads
                  ``vocab_size``, ``hidden_size``, ``kv_channels``,
                  ``original_rope``, ``torch_dtype``, ``num_layers``,
                  ``layernorm_epsilon`` and ``seq_length``.
              device: Passed on to the rotary embedding, the transformer
                  blocks and the final norm. The embedding and the output
                  projection are created without an explicit device.
          """
          super().__init__()
          self.config = config
          self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
          # Width of one attention head. The attention blocks derive the same
          # quantity from ``config.kv_channels``, so read it from the config
          # rather than hard-coding 128.
          rotary_dim = config.kv_channels
          self.rotary_pos_emb = RotaryEmbedding(
              rotary_dim // 2,
              original_impl=config.original_rope,
              device=device,
              dtype=config.torch_dtype,
          )
          self.layers = nn.ModuleList(
              TransformerBlock(config, i, device) for i in range(config.num_layers)
          )
          self.final_layernorm = RMSNorm(
              config.hidden_size,
              eps=config.layernorm_epsilon,
              device=device,
              dtype=config.torch_dtype,
          )
          self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          self.seq_length = config.seq_length

      def forward(
          self,
          input_ids,
          position_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.BoolTensor] = None,
          input_pos=None,
          is_input_mask=False,
          kv_caches=None,
      ) -> Tuple[Tensor, tuple]:
          """Run the stack and return ``(lm_logits, presents)``.

          Args:
              input_ids: Token ids; assumed ``[batch, seq]`` (the embedding
                  output is transposed to put the sequence axis first).
              position_ids: Indices into the rotary table. ``None`` indexes the
                  table with ``None``, i.e. the whole table with a new leading
                  axis is handed to the layers.
              attention_mask: Forwarded unchanged to every layer.
              input_pos: Cache slots to write, forwarded to every layer.
              is_input_mask: Accepted for call-site compatibility; not used.
              kv_caches: One cache entry per layer, indexed by layer number.

          Returns:
              The logits (batch axis first again) and a tuple holding the
              cache entry returned by each layer.

          Raises:
              ValueError: If ``kv_caches`` is not supplied.
          """
          if kv_caches is None:
              raise ValueError("kv_caches is required: pass one cache entry per layer")

          hidden = self.embedding(input_ids).transpose(0, 1).contiguous()

          # NOTE: the full rotary table for ``seq_length`` positions is rebuilt
          # on every call and then gathered by ``position_ids``.
          rope = self.rotary_pos_emb(self.seq_length)
          rope = rope[position_ids].transpose(0, 1).contiguous()

          presents = []
          for i, layer in enumerate(self.layers):
              hidden, layer_cache = layer(
                  hidden,
                  rotary_pos_emb=rope,
                  input_pos=input_pos,
                  attention_mask=attention_mask,
                  kv_cache=kv_caches[i],
              )
              presents.append(layer_cache)

          lm_logits = self.output_layer(self.final_layernorm(hidden))
          lm_logits = lm_logits.transpose(0, 1).contiguous()
          return lm_logits, tuple(presents)



  class KVCache(nn.Module):
      """Pre-allocated key/value buffers that are overwritten in place.

      The buffers are stored as ``[heads, batch, seq, head_dim]``; ``update``
      accepts and returns tensors laid out as ``[seq, batch, heads, head_dim]``.
      """

      def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16,
                   n_heads=2, head_dim=128):
          """Allocate zero-filled ``k_cache`` / ``v_cache`` buffers.

          Args:
              max_batch_size: Size of the batch axis.
              max_seq_length: Number of sequence slots.
              dtype: Element type of both buffers.
              n_heads: Number of key/value heads (defaults to the previously
                  hard-coded 2).
              head_dim: Width of one head (defaults to the previously
                  hard-coded 128).
          """
          super().__init__()
          cache_shape = (n_heads, max_batch_size, max_seq_length, head_dim)
          self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
          self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

      def update(self, input_pos, k_val, v_val):
          """Write ``k_val`` / ``v_val`` at ``input_pos`` and return the full caches.

          Args:
              input_pos: 1-D tensor of sequence slots, one per row of ``k_val``.
              k_val: New keys, ``[S, B, H, D]``.
              v_val: New values, ``[S, B, H, D]``.

          Returns:
              ``(k, v)``: contiguous copies of the whole caches with the
              sequence axis first.
          """
          # input_pos: S, k_val: [S, B, H, D]
          assert input_pos.shape[0] == k_val.shape[0]
          # Index assignment copies the values, so no contiguous()/clone() of
          # the source is needed.
          self.k_cache[:, :, input_pos] = k_val.transpose(0, 2)
          self.v_cache[:, :, input_pos] = v_val.transpose(0, 2)
          k_out = self.k_cache.transpose(0, 2).contiguous()
          v_out = self.v_cache.transpose(0, 2).contiguous()
          return k_out, v_out




  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from dataclasses import dataclass
  6. from typing import Optional, Tuple
  7. import torch
  8. import torch.nn as nn
  9. from torch import Tensor
  10. from torch.nn import functional as F
  11. from math import gcd
  12. from functools import reduce
  13. import math
  def find_multiple(n: int, *args: int) -> int:
      """Round ``n`` up to a multiple of the least common multiple of ``args``.

      Args:
          n: Value to round up.
          *args: Integers whose least common multiple defines the step. With no
              arguments the step is 1 and ``n`` is returned unchanged.

      Returns:
          ``n`` itself when it is already a multiple of the step, otherwise the
          next larger multiple.
      """
      # Fold the arguments into their LCM; the trailing 1 makes the empty case work.
      step = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
      remainder = n % step
      if remainder == 0:
          return n
      return n + step - remainder
  class CoreAttention(torch.nn.Module):
      """Attention kernel built on ``scaled_dot_product_attention``."""

      def __init__(self, config, layer_number):
          """Store the head geometry derived from ``config``.

          Args:
              config: Reads ``apply_query_key_layer_scaling``, ``kv_channels``,
                  ``num_attention_heads`` and ``attention_dropout``.
              layer_number: Layer index; clamped to at least 1.
          """
          super(CoreAttention, self).__init__()
          self.config = config
          self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
          # The original code read the config value and immediately overwrote
          # it with True; only the effective value is kept.
          self.attention_softmax_in_fp32 = True
          self.layer_number = max(1, layer_number)

          projection_size = config.kv_channels * config.num_attention_heads
          self.hidden_size_per_partition = projection_size
          self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
          self.num_attention_heads_per_partition = config.num_attention_heads

          # norm_factor / coeff are kept as attributes but are not used by
          # forward(), which relies on the default scaling of
          # scaled_dot_product_attention.
          coeff = self.layer_number
          self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) * coeff
          self.coeff = coeff

          # Constructed but not applied in forward().
          self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

      def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
          """Attend and merge the heads.

          Args:
              query_layer, key_layer, value_layer: 4-D tensors; the callers in
                  this file pass ``[seq, batch, heads, head_dim]``.
              attention_mask: Optional boolean mask in which ``True`` marks
                  positions to hide; it is inverted before being handed to
                  ``scaled_dot_product_attention``. When omitted, causal
                  attention is used.

          Returns:
              Tensor whose last two axes (heads, head_dim) are merged into
              ``hidden_size_per_partition``.
          """
          # Move axes 1 and 2 to the front: (0, 1, 2, 3) -> (1, 2, 0, 3).
          q, k, v = (t.permute(1, 2, 0, 3) for t in (query_layer, key_layer, value_layer))
          if attention_mask is None:
              context = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
          else:
              context = torch.nn.functional.scaled_dot_product_attention(q, k, v, ~attention_mask)
          # Undo the permutation, then flatten (heads, head_dim).
          context = context.permute(2, 0, 1, 3)
          merged_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
          return context.reshape(*merged_shape)
  def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
      """Rotate the leading channels of ``x`` by the angles stored in ``rope_cache``.

      Args:
          x: ``[sq, b, np, hn]`` tensor.
          rope_cache: Tensor whose last axis holds ``(cos, sin)`` and whose
              second-to-last axis has one entry per rotated channel pair; it is
              truncated to the first ``sq`` entries along axis 0.

      Returns:
          Tensor shaped like ``x``: the first ``2 * rope_cache.shape[-2]``
          channels are rotated pair-wise, the remaining channels pass through.
      """
      # x: [sq, b, np, hn]
      seq_len, n_heads = x.size(0), x.size(2)
      rot_dim = rope_cache.shape[-2] * 2
      x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
      # truncate to support variable sizes
      rope_cache = rope_cache[:seq_len]
      pairs = x_rot.reshape(seq_len, -1, n_heads, rot_dim // 2, 2)
      # Size-1 head axis so the same angles broadcast over all heads.
      rope_cache = rope_cache.view(seq_len, -1, 1, pairs.size(3), 2)
      cos, sin = rope_cache[..., 0], rope_cache[..., 1]
      even, odd = pairs[..., 0], pairs[..., 1]
      rotated = torch.stack(
          [
              even * cos - odd * sin,
              odd * cos + even * sin,
          ],
          -1,
      )
      # Merge the (pair, 2) axes back into one channel axis.
      rotated = rotated.flatten(3)
      return torch.cat((rotated, x_pass), dim=-1)
  class RotaryEmbedding(nn.Module):
      """Builds the ``(cos, sin)`` lookup table used for rotary position embedding."""

      def __init__(self, dim, original_impl=False, device=None, dtype=None):
          """Register the inverse-frequency buffer.

          Args:
              dim: Number of channels covered by the table (``dim // 2`` pairs).
              original_impl: Stored on the instance; not consulted in this class.
              device: Device of the ``inv_freq`` buffer.
              dtype: Dtype of the ``inv_freq`` buffer; ``forward`` builds the
                  table in this dtype.
          """
          super().__init__()
          inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
          self.register_buffer("inv_freq", inv_freq)
          self.dim = dim
          self.original_impl = original_impl

      def forward_impl(
          self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
      ):
          """Enhanced Transformer with Rotary Position Embedding.

          Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
          transformers/rope/__init__.py. MIT License:
          https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

          Returns:
              Tensor of shape ``[seq_len, n_elem // 2, 2]`` holding
              ``(cos, sin)`` of ``position * theta_i``.
          """
          # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
          theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

          # Create position indexes `[0, 1, ..., seq_len - 1]`
          # NOTE(review): the indexes are created in `dtype`; half-precision
          # types cannot represent large integers exactly, so confirm whether
          # float32 should be used for long sequences.
          seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

          # Calculate the product of position index and $\theta_i$
          idx_theta = torch.outer(seq_idx, theta).float()

          cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

          # this is to mimic the behaviour of complex32, else we will get different results
          if dtype in (torch.float16, torch.bfloat16, torch.int8):
              cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
          return cache

      def forward(self, max_seq_len, offset=0):
          """Return the table for ``max_seq_len`` positions (``offset`` is unused)."""
          return self.forward_impl(
              max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
          )
  class KVCache(nn.Module):
      """Pre-allocated key/value buffers that are overwritten in place.

      The buffers are stored as ``[heads, batch, seq, head_dim]``; ``update``
      accepts and returns tensors laid out as ``[seq, batch, heads, head_dim]``.
      """

      def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16,
                   n_heads=2, head_dim=128):
          """Allocate zero-filled ``k_cache`` / ``v_cache`` buffers.

          Args:
              max_batch_size: Size of the batch axis.
              max_seq_length: Number of sequence slots.
              dtype: Element type of both buffers.
              n_heads: Number of key/value heads (defaults to the previously
                  hard-coded 2).
              head_dim: Width of one head (defaults to the previously
                  hard-coded 128).
          """
          super().__init__()
          cache_shape = (n_heads, max_batch_size, max_seq_length, head_dim)
          self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
          self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

      def update(self, input_pos, k_val, v_val):
          """Write ``k_val`` / ``v_val`` at ``input_pos`` and return the full caches.

          Args:
              input_pos: 1-D tensor of sequence slots, one per row of ``k_val``.
              k_val: New keys, ``[S, B, H, D]``.
              v_val: New values, ``[S, B, H, D]``.

          Returns:
              ``(k, v)``: contiguous copies of the whole caches with the
              sequence axis first.
          """
          # input_pos: S, k_val: [S, B, H, D]
          assert input_pos.shape[0] == k_val.shape[0]
          # Index assignment copies the values, so no contiguous()/clone() of
          # the source is needed.
          self.k_cache[:, :, input_pos] = k_val.transpose(0, 2)
          self.v_cache[:, :, input_pos] = v_val.transpose(0, 2)
          k_out = self.k_cache.transpose(0, 2).contiguous()
          v_out = self.v_cache.transpose(0, 2).contiguous()
          return k_out, v_out
  class RMSNorm(torch.nn.Module):
      """Root-mean-square normalisation over the last axis with a learned scale."""

      def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
          """Create the scale parameter.

          Args:
              normalized_shape: Shape of the ``weight`` parameter.
              eps: Constant added to the mean square before the reciprocal root.
              device: Device of ``weight``.
              dtype: Dtype of ``weight``.
              **kwargs: Accepted and ignored.
          """
          super().__init__()
          # Initialise to ones (identity scale) instead of uninitialised memory,
          # so the module is well defined even before a checkpoint is loaded.
          self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
          self.eps = eps

      def forward(self, hidden_states: torch.Tensor):
          """Normalise ``hidden_states`` and return it in its original dtype."""
          input_dtype = hidden_states.dtype
          # The mean square is accumulated in float32.
          mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
          normed = hidden_states * torch.rsqrt(mean_square + self.eps)
          return (self.weight * normed).to(input_dtype)
  class TransformerGLM(nn.Module):
      """GLM decoder stack: embedding -> transformer blocks -> RMSNorm -> LM head.

      Hidden states travel through the blocks with the sequence axis first
      (``inputs_embeds.transpose(0, 1)``); the logits are transposed back
      before they are returned.
      """

      def __init__(self, config, device) -> None:
          """Create the sub-modules described by ``config``.

          Args:
              config: Model configuration. This constructor reads
                  ``vocab_size``, ``hidden_size``, ``kv_channels``,
                  ``original_rope``, ``torch_dtype``, ``num_layers``,
                  ``layernorm_epsilon`` and ``seq_length``.
              device: Passed on to the rotary embedding, the transformer
                  blocks and the final norm. The embedding and the output
                  projection are created without an explicit device.
          """
          super().__init__()
          self.config = config
          self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
          # Width of one attention head. The attention blocks derive the same
          # quantity from ``config.kv_channels``, so read it from the config
          # rather than hard-coding 128.
          rotary_dim = config.kv_channels
          self.rotary_pos_emb = RotaryEmbedding(
              rotary_dim // 2,
              original_impl=config.original_rope,
              device=device,
              dtype=config.torch_dtype,
          )
          self.layers = nn.ModuleList(
              TransformerBlock(config, i, device) for i in range(config.num_layers)
          )
          self.final_layernorm = RMSNorm(
              config.hidden_size,
              eps=config.layernorm_epsilon,
              device=device,
              dtype=config.torch_dtype,
          )
          self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          self.seq_length = config.seq_length

      def forward(
          self,
          input_ids,
          position_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.BoolTensor] = None,
          input_pos=None,
          is_input_mask=False,
          kv_caches=None,
      ) -> Tuple[Tensor, tuple]:
          """Run the stack and return ``(lm_logits, presents)``.

          Args:
              input_ids: Token ids; assumed ``[batch, seq]`` (the embedding
                  output is transposed to put the sequence axis first).
              position_ids: Indices into the rotary table. ``None`` indexes the
                  table with ``None``, i.e. the whole table with a new leading
                  axis is handed to the layers.
              attention_mask: Forwarded unchanged to every layer.
              input_pos: Cache slots to write, forwarded to every layer.
              is_input_mask: Accepted for call-site compatibility; not used.
              kv_caches: One cache entry per layer, indexed by layer number.

          Returns:
              The logits (batch axis first again) and a tuple holding the
              cache entry returned by each layer.

          Raises:
              ValueError: If ``kv_caches`` is not supplied.
          """
          if kv_caches is None:
              raise ValueError("kv_caches is required: pass one cache entry per layer")

          hidden = self.embedding(input_ids).transpose(0, 1).contiguous()

          # NOTE: the full rotary table for ``seq_length`` positions is rebuilt
          # on every call and then gathered by ``position_ids``.
          rope = self.rotary_pos_emb(self.seq_length)
          rope = rope[position_ids].transpose(0, 1).contiguous()

          presents = []
          for i, layer in enumerate(self.layers):
              hidden, layer_cache = layer(
                  hidden,
                  rotary_pos_emb=rope,
                  input_pos=input_pos,
                  attention_mask=attention_mask,
                  kv_cache=kv_caches[i],
              )
              presents.append(layer_cache)

          lm_logits = self.output_layer(self.final_layernorm(hidden))
          lm_logits = lm_logits.transpose(0, 1).contiguous()
          return lm_logits, tuple(presents)
  class MLP(torch.nn.Module):
      """Gated feed-forward block (SwiGLU).

      The input of width ``hidden_size`` is projected to
      ``2 * ffn_hidden_size``, split into two halves that are combined as
      ``silu(first) * second``, and projected back to ``hidden_size``.
      """

      def __init__(self, config, device=None):
          """Create the two projections.

          Args:
              config: Reads ``add_bias_linear``, ``hidden_size`` and
                  ``ffn_hidden_size``.
              device: Device of both linear layers.
          """
          super(MLP, self).__init__()
          self.add_bias = config.add_bias_linear

          # Up-projection. The output width is doubled because SwiGLU consumes
          # two halves, see https://arxiv.org/pdf/2002.05202.pdf
          self.dense_h_to_4h = nn.Linear(
              config.hidden_size,
              config.ffn_hidden_size * 2,
              bias=self.add_bias,
              device=device,
          )

          self.activation_func = self._swiglu

          # Project back to h.
          self.dense_4h_to_h = nn.Linear(
              config.ffn_hidden_size,
              config.hidden_size,
              bias=self.add_bias,
              device=device,
          )

      @staticmethod
      def _swiglu(x):
          """Split the last axis in two and return ``silu(first) * second``."""
          gate, value = torch.chunk(x, 2, dim=-1)
          return F.silu(gate) * value

      def forward(self, hidden_states):
          """Apply up-projection, SwiGLU and down-projection."""
          # [s, b, 2 * ffn_hidden_size]
          intermediate = self.dense_h_to_4h(hidden_states)
          # [s, b, ffn_hidden_size]
          intermediate = self.activation_func(intermediate)
          # [s, b, h]
          return self.dense_4h_to_h(intermediate)
  class SelfAttention(torch.nn.Module):
      """Grouped (multi-query) self-attention with rotary embedding and a KV cache.

      Self-attention layer takes input with size [s, b, h]
      and returns output of the same size.
      """

      def __init__(self, config, layer_number, device=None):
          """Create the fused QKV projection, the attention core and the output projection.

          Args:
              config: Reads ``kv_channels``, ``num_attention_heads``,
                  ``multi_query_attention``, ``multi_query_group_num``,
                  ``hidden_size``, ``add_bias_linear`` and ``add_qkv_bias``.
              layer_number: Layer index; clamped to at least 1.
              device: Device of the linear layers.
          """
          super(SelfAttention, self).__init__()
          self.config = config
          self.layer_number = max(1, layer_number)

          self.projection_size = config.kv_channels * config.num_attention_heads
          # Per attention head and per partition values.
          self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
          self.num_attention_heads_per_partition = config.num_attention_heads

          self.multi_query_attention = config.multi_query_attention
          self.num_multi_query_groups_per_partition = config.multi_query_group_num
          # Full-width queries plus one key and one value slice per group.
          self.qkv_hidden_size = (
              self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
          )
          self.query_key_value = nn.Linear(
              config.hidden_size,
              self.qkv_hidden_size,
              bias=config.add_bias_linear or config.add_qkv_bias,
              device=device,  # same device as the output projection below
          )

          self.core_attention = CoreAttention(config, self.layer_number)

          # Output.
          self.dense = nn.Linear(
              self.projection_size,
              config.hidden_size,
              bias=config.add_bias_linear,
              device=device,
          )

      def _expand_kv_groups(self, layer):
          """Repeat each key/value group so that ``layer`` has one slice per query head."""
          n_heads = self.num_attention_heads_per_partition
          n_groups = self.num_multi_query_groups_per_partition
          layer = layer.unsqueeze(-2)
          layer = layer.expand(-1, -1, -1, n_heads // n_groups, -1)
          return layer.contiguous().view(
              layer.size()[:2] + (n_heads, self.hidden_size_per_attention_head)
          )

      def forward(
          self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None
      ):
          """Attend over the cached keys/values after writing the current ones.

          Args:
              hidden_states: ``[sq, b, h]`` input.
              rotary_pos_emb: Rotary table slice handed to ``apply_rotary_pos_emb``.
              input_pos: Index into the first axis of the caches at which the
                  current keys/values are stored.
              attention_mask: Forwarded to the attention core.
              kv_cache: ``(cache_k, cache_v)`` pair; both tensors are modified
                  in place.

          Returns:
              ``(output, kv_cache)`` where ``kv_cache`` holds clones of the
              updated caches.

          Raises:
              ValueError: If ``kv_cache`` is not supplied.
          """
          if kv_cache is None:
              raise ValueError("kv_cache is required: pass a (cache_k, cache_v) pair")

          n_heads = self.num_attention_heads_per_partition
          n_groups = self.num_multi_query_groups_per_partition
          head_dim = self.hidden_size_per_attention_head

          # =====================
          # Query, Key, and Value
          # =====================
          # [sq, b, h] --> [sq, b, qkv_hidden_size]
          mixed_x_layer = self.query_key_value(hidden_states)
          query_layer, key_layer, value_layer = mixed_x_layer.split(
              [n_heads * head_dim, n_groups * head_dim, n_groups * head_dim],
              dim=-1,
          )
          query_layer = query_layer.view(query_layer.size()[:-1] + (n_heads, head_dim))
          key_layer = key_layer.view(key_layer.size()[:-1] + (n_groups, head_dim))
          value_layer = value_layer.view(value_layer.size()[:-1] + (n_groups, head_dim))

          # apply relative positional encoding (rotary embedding)
          query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
          key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

          # Update the KV cache in place, then continue with clones of it.
          cache_k, cache_v = kv_cache
          cache_k[input_pos] = key_layer
          cache_v[input_pos] = value_layer
          key_layer = cache_k.clone()
          value_layer = cache_v.clone()
          kv_cache = (key_layer, value_layer)

          key_layer = self._expand_kv_groups(key_layer)
          value_layer = self._expand_kv_groups(value_layer)

          # ==================================
          # core attention computation
          # ==================================
          context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask=attention_mask)

          # =================
          # Output. [sq, b, h]
          # =================
          output = self.dense(context_layer)
          return output, kv_cache
  class TransformerBlock(nn.Module):
      """One pre-norm decoder layer: RMSNorm -> attention -> residual -> RMSNorm -> MLP -> residual."""

      def __init__(self, config, layer_number, device) -> None:
          """Create the two norms, the attention module and the MLP.

          Args:
              config: Reads ``hidden_dropout``, ``hidden_size``,
                  ``layernorm_epsilon`` and ``torch_dtype``; also passed on to
                  the attention module and the MLP.
              layer_number: Forwarded to the attention module.
              device: Forwarded to every sub-module.
          """
          super().__init__()
          self.hidden_dropout = config.hidden_dropout
          self.input_layernorm = RMSNorm(
              config.hidden_size,
              eps=config.layernorm_epsilon,
              device=device,
              dtype=config.torch_dtype,
          )
          self.self_attention = SelfAttention(config, layer_number, device=device)
          self.post_attention_layernorm = RMSNorm(
              config.hidden_size,
              eps=config.layernorm_epsilon,
              device=device,
              dtype=config.torch_dtype,
          )
          self.mlp = MLP(config, device=device)

      def forward(self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None):
          """Apply the layer and return ``(output, kv_cache)``.

          ``rotary_pos_emb``, ``input_pos``, ``attention_mask`` and ``kv_cache``
          are forwarded to the attention module; the cache it returns is passed
          back to the caller unchanged.
          """
          # Attention sub-layer; the residual is taken from the un-normalised input.
          attention_output, kv_cache = self.self_attention(
              self.input_layernorm(hidden_states),
              rotary_pos_emb,
              input_pos,
              attention_mask=attention_mask,
              kv_cache=kv_cache,
          )
          attention_output = torch.nn.functional.dropout(
              attention_output, p=self.hidden_dropout, training=self.training
          )
          after_attention = hidden_states + attention_output

          # Feed-forward sub-layer, again with a residual around it.
          mlp_output = self.mlp(self.post_attention_layernorm(after_attention))
          mlp_output = torch.nn.functional.dropout(
              mlp_output, p=self.hidden_dropout, training=self.training
          )
          output = after_attention + mlp_output
          return output, kv_cache
  # NOTE: this is a second definition of RMSNorm in the same listing; being the
  # later one, it is the definition the module-level name refers to at runtime.
  class RMSNorm(torch.nn.Module):
      """Root-mean-square normalisation over the last axis with a learned scale."""

      def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
          """Create the scale parameter.

          Args:
              normalized_shape: Shape of the ``weight`` parameter.
              eps: Constant added to the mean square before the reciprocal root.
              device: Device of ``weight``.
              dtype: Dtype of ``weight``.
              **kwargs: Accepted and ignored.
          """
          super().__init__()
          # Initialise to ones (identity scale) instead of uninitialised memory,
          # so the module is well defined even before a checkpoint is loaded.
          self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
          self.eps = eps

      def forward(self, hidden_states: torch.Tensor):
          """Normalise ``hidden_states`` and return it in its original dtype."""
          input_dtype = hidden_states.dtype
          # The mean square is accumulated in float32.
          mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
          normed = hidden_states * torch.rsqrt(mean_square + self.eps)
          return (self.weight * normed).to(input_dtype)
  if __name__ == '__main__':
      # Smoke test: build the model from a local checkpoint's config and list
      # its parameter names. No weights are loaded here.
      import os

      # Restrict the process to the GPU with index 1.
      os.environ['CUDA_VISIBLE_DEVICES'] = "1"

      from transformers import AutoConfig

      model_path = "./chatglm2-6b-merge"
      config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
      model = TransformerGLM(config, device=None)
      for name, _ in model.named_parameters():
          print(name)


推理方法:也就是重写 transformer 模型中的 generate 方法。一次生成可以分为第一次解码的 forward 阶段和余下逐 token 解码的 forward 阶段,两个阶段的实现分别如下(只实现了 greedy search 策略):

  @torch.no_grad()
  def first_decode_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
      """Run the first forward pass of a generation and pick the next token greedily.

      Args:
          model: Callable returning ``(logits, kv_caches)``; it is invoked with
              keyword arguments only and ``is_input_mask=False``.
          input_ids, position_ids, input_pos, attention_mask, kv_caches:
              Forwarded to ``model`` unchanged.

      Returns:
          ``(next_tok, kv_caches)``: the arg-max token id at the last sequence
          position (the sequence axis is kept with length 1) and the caches
          returned by ``model``.
      """
      logits, kv_caches = model(
          input_ids=input_ids,
          position_ids=position_ids,
          input_pos=input_pos,
          is_input_mask=False,
          attention_mask=attention_mask,
          kv_caches=kv_caches,
      )
      # Only the last position predicts the next token.
      last_logits = logits[:, -1:]
      next_tok = torch.argmax(last_logits, dim=-1)
      return next_tok, kv_caches
  @torch.no_grad()
  def decode_one_token_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
      """Run one incremental decoding step and pick the next token greedily.

      Args:
          model: Callable returning ``(logits, kv_caches)``; ``input_ids`` is
              passed positionally, everything else by keyword, with
              ``is_input_mask=True``.
          input_ids, position_ids, input_pos, attention_mask, kv_caches:
              Forwarded to ``model`` unchanged.

      Returns:
          ``(next_tok, kv_caches)``: the arg-max token id at the last sequence
          position (the sequence axis is kept with length 1) and the caches
          returned by ``model``.
      """
      logits, kv_caches = model(
          input_ids,
          position_ids=position_ids,
          input_pos=input_pos,
          is_input_mask=True,
          attention_mask=attention_mask,
          kv_caches=kv_caches,
      )
      # Only the last position predicts the next token.
      last_logits = logits[:, -1:]
      next_tok = torch.argmax(last_logits, dim=-1)
      return next_tok, kv_caches



  def generate_own_batch(model,
                         inputs,
                         sampling_kwargs,
                         eos_token,
                         max_seq_length, max_batch_size):
      """Greedy batch generation with pre-allocated KV caches.

      Args:
          model: Model exposing ``config.num_layers``,
              ``config.multi_query_group_num`` and ``config.kv_channels``.
          inputs: Mapping with ``'input_ids'`` and ``'position_ids'``; the
              batch size is expected to equal ``max_batch_size``.
          sampling_kwargs: Mapping; only ``"max_length"`` is read, as the number
              of incremental decoding steps after the first forward pass.
          eos_token: Token id that marks the end of a sequence.
          max_seq_length: Number of slots in each KV cache.
          max_batch_size: Batch axis of the caches and of the attention mask.

      Returns:
          ``(input_ids, ori_input_ids)``: the prompt extended with every
          generated token, and a clone of the original prompt ids.
      """
      input_ids = inputs['input_ids']
      device = input_ids.device
      ori_input_ids = input_ids.clone()
      prompt_len = input_ids.shape[1]

      # One (key, value) pair per layer. The last two axes must match the
      # key/value tensors written by the attention layers, so they are taken
      # from the model config instead of being hard-coded to (2, 128).
      cfg = model.config
      cache_shape = (max_seq_length, max_batch_size, cfg.multi_query_group_num, cfg.kv_channels)
      dtype = torch.bfloat16
      kv_caches = [
          (torch.zeros(cache_shape, dtype=dtype, device=device),
           torch.zeros(cache_shape, dtype=dtype, device=device))
          for _ in range(cfg.num_layers)
      ]

      # ---- first forward pass over the whole prompt -------------------------
      position_ids = inputs['position_ids']
      input_pos = torch.arange(prompt_len, device=device).repeat(max_batch_size, 1)
      # NOTE: no mask is passed here, so padding tokens in the prompt are not
      # hidden during this first pass.
      next_token, kv_caches = first_decode_batch(model, input_ids, position_ids, input_pos, None, kv_caches)

      # ---- attention mask for the incremental steps (True == hidden) --------
      full_attention_mask = torch.ones(max_batch_size, 1, 1, max_seq_length, dtype=torch.bool, device=device)
      # Prompt slots are visible unless they hold padding (token id 0).
      full_attention_mask[:, 0, 0, :prompt_len] = input_ids == 0

      input_ids = torch.cat((input_ids, next_token.clone()), dim=1)

      # Each step writes one cache slot starting at `prompt_len`; never run
      # past the end of the cache.
      num_new_tokens = min(sampling_kwargs["max_length"], max_seq_length - prompt_len)

      # The first generated token sits one position after the last prompt
      # token. `+ 1` also creates a new tensor, so the in-place increments
      # below do not modify the caller's `inputs['position_ids']`.
      position_ids = position_ids[:, -1:] + 1
      # Cache slot written by the next step (the first free slot after the prompt).
      input_pos = torch.full((max_batch_size, 1), prompt_len, device=next_token.device, dtype=torch.long)

      for _ in range(num_new_tokens):
          # Actually better for Inductor to codegen attention here
          with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
              # Reveal the slot that this step is about to fill.
              full_attention_mask[:, :, :, input_pos] = False
              next_token, kv_caches = decode_one_token_batch(model, next_token, position_ids, input_pos,
                                                             full_attention_mask, kv_caches)
          input_ids = torch.cat((input_ids, next_token.clone()), dim=1)
          # Stop once every row contains the EOS token somewhere.
          if (input_ids == eos_token).sum(dim=1).all():
              break
          position_ids += 1
          input_pos += 1
      return input_ids, ori_input_ids


  # Driver excerpt: build the model, prepare tokenizer and generation settings,
  # compile the two decoding functions and run batched generation.
  # NOTE(review): `config`, `inputs`, `max_seq_length` and `max_batch_size` are
  # not defined in this excerpt; they presumably come from the omitted parts.
  from pathlib import Path

  from transformers import AutoTokenizer

  # Defined first: it is needed to build `checkpoint_dir` below.
  model_path = "./chatglm2-6b-merge"

  model = TransformerGLM(config=config, device=None)
  checkpoint_dir = Path(model_path)
  model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
  converted_state_dict = model.state_dict()

  gen_kwargs = {"max_length": 200, "num_beams": 1,
                "do_sample": False, "top_p": 0.8,
                "temperature": 0.95
                }
  device = "cuda:0"
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  eos_token = tokenizer.eos_token_id

  # ...... (code omitted in the original article)

  # Compile for speed.
  # NOTE(review): the `global` statement suggests this excerpt originally lived
  # inside a function; confirm before using it at module level.
  global first_decode_batch, decode_one_token_batch
  decode_one_token_batch = torch.compile(decode_one_token_batch, mode="reduce-overhead", fullgraph=True)
  first_decode_batch = torch.compile(first_decode_batch, dynamic=True, fullgraph=True)

  # ...... (code omitted in the original article)

  generate_own_batch(model, inputs, gen_kwargs, eos_token, max_seq_length, max_batch_size)



下面展示 6B GLM 模型在 ori(原始)、compile、compile+int8 三种设置下(bs=1,max_seq_length=1000)推理速度和生成效果的对比,模型输入 prompt 如下:





可以看到,相同的模型和相同的数据在 bs=1 下,原始模型的推理速度为 31.7 tokens/s,compile 的推理速度为 68.1 tokens/s,compile+int8 的推理速度为 110.9 tokens/s;加速效果确实比较明显。

业务领域上的实验,这里也可以给一个结论,数据就不展示了,业务上生成的token数目每次推理大都在20 tokens以内,结果如下:






