PrefixEncoder
# Builds prefix embeddings from prefix IDs.
# The prefix embeddings are concatenated onto K and V after the heads are split.
class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        # Whether to enable prefix projection, i.e. run the prefix
        # embeddings through a two-layer MLP
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # KVSize = NLayer * 2 * NGroup * HeadSize
            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            # Embedding layer mapping IDs to embeddings, [PreSeqLen, KVSize]
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            # MLP that processes the embeddings:
            # project to HidSize, apply tanh, then project back to KVSize
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size)
            )
        else:
            # Embedding layer mapping IDs to embeddings
            self.embedding = torch.nn.Embedding(config.pre_seq_len,
                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)

    def forward(self, prefix: torch.Tensor):
        # Prefix IDs have shape [BatchSize, PreSeqLen].
        # Look up their embeddings, shape [BatchSize, PreSeqLen, KVSize].
        # If projection is enabled, run the embeddings through the two-layer MLP.
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
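To make the shapes concrete, here is a minimal sketch of calling the class above. The config fields are toy values passed through types.SimpleNamespace as a stand-in for ChatGLMConfig (only the attributes PrefixEncoder actually reads are provided); the numbers are illustrative, not taken from the real model.

import types
import torch

# Hypothetical toy config: only the fields PrefixEncoder reads, with made-up sizes
cfg = types.SimpleNamespace(
    prefix_projection=True,
    num_layers=2,             # NLayer
    kv_channels=4,            # HeadSize
    multi_query_group_num=2,  # NGroup
    pre_seq_len=8,            # PreSeqLen
    hidden_size=16,           # HidSize
)

encoder = PrefixEncoder(cfg)

# Prefix IDs: [BatchSize=3, PreSeqLen=8]
prefix = torch.arange(cfg.pre_seq_len).unsqueeze(0).repeat(3, 1)
past_key_values = encoder(prefix)

# KVSize = NLayer * 2 * NGroup * HeadSize = 2 * 2 * 2 * 4 = 32
print(past_key_values.shape)  # torch.Size([3, 8, 32])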
ChatGLMPreTrainedModel
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    # Build the default (causal, upper-triangle-masked) attention mask
    # from the input token IDs and the KV cache
    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        # Token IDs have shape [BatchSize, SeqLen]
        batch_size, seq_length = input_ids.shape
        # Initialise the mask to all ones, shape [BatchSize, SeqLen, SeqLen],
        # one mask per input sequence
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        # Keep the lower-triangular elements and set the rest to 0
        full_attention_mask.tril_()
        # CacheSeqLen: sequence length stored in the KV cache.
        # Defaults to 0 when no cache is given; otherwise read it from the cache.
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        # If a KV cache is provided, pad each mask on the left with ones,
        # i.e. concatenate a [BatchSize, SeqLen, CacheSeqLen] block of ones
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        # If a padding mask ([BatchSize, (Cache)SeqLen]) is provided,
        # reshape it to [BatchSize, 1, (Cache)SeqLen] and multiply it into the mask,
        # zeroing out the columns where the padding mask is 0
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        # If a padding mask is provided and there is no KV cache,
        # reshape it to [BatchSize, SeqLen, 1] and
        # add 1 to the rows where the padding mask is 0
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        # Values below 0.5 become True and values above become False,
        # i.e. the mask is inverted, so the (nonzero) upper triangle ends up masked
        full_attention_mask = (full_attention_mask < 0.5).bool()
        # Add a head dimension, reshaping to [BatchSize, 1, SeqLen, SeqLen]
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    # Build the default (zero-based) position IDs from the input token IDs
    def get_position_ids(self, input_ids, device):
        # Token IDs have shape [BatchSize, SeqLen]
        batch_size, seq_length = input_ids.shape
        # Position IDs start as a 1-D array 0..(SeqLen-1),
        # reshaped to [1, SeqLen] and repeated BatchSize times along dim 0,
        # giving [BatchSize, SeqLen]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
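The following standalone sketch mirrors the steps of get_masks on a toy example (BatchSize=1, SeqLen=3, CacheSeqLen=2, no padding mask) to show what the final boolean mask looks like; the sizes are made up for illustration.

import torch

batch_size, seq_length, past_length = 1, 3, 2
full_attention_mask = torch.ones(batch_size, seq_length, seq_length)
full_attention_mask.tril_()                       # keep the causal (lower-triangular) part
full_attention_mask = torch.cat(                  # cached positions are always visible
    (torch.ones(batch_size, seq_length, past_length), full_attention_mask), dim=-1)
full_attention_mask = (full_attention_mask < 0.5).bool()  # invert: True = masked out
full_attention_mask.unsqueeze_(1)                 # add the head dimension

print(full_attention_mask.shape)        # torch.Size([1, 1, 3, 5])
print(full_attention_mask[0, 0].int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)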
ChatGLMForConditionalGeneration.stream_generate()
@torch.inference_mode()
def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        return_past_key_values=False,
        **kwargs,
):
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

    if generation_config is None:
        generation_config = self.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    model_kwargs["use_cache"] = generation_config.use_cache
    bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                UserWarning,
            )

    # Warn if SeqLen is already greater than or equal to the configured max length
    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # If no logits processors were supplied, start from an empty list
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    # If no stopping criteria were supplied, start from an empty list
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    # Build the logits processors from the generation config and related objects
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    # Build the stopping criteria from the generation config and related objects
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    # Build the logits warpers from the generation config
    logits_warper = self._get_logits_warper(generation_config)

    # Unfinished flags: one entry per sequence indicating whether it is still generating.
    # Initialised to all ones with shape [BatchSize]
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        # Assemble the inputs into a dict, see that method's definition
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # Run the model on the token IDs to get the next-token logits
        # for every prefix, [BatchSize, SeqLen, VocabSize]
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        # Take the last position along the SeqLen dimension to get the next-token
        # logits for the whole sequence, [BatchSize, VocabSize]
        next_token_logits = outputs.logits[:, -1, :]

        # Run the logits through the processors and warpers to adjust them
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Softmax to get probabilities
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        # If sampling is enabled, draw one sample from the multinomial distribution;
        # otherwise take the argmax.
        # The result is the next token IDs, shape [BatchSize]
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)
        # Reshape the next token IDs to [BatchSize, 1] and append them to the input IDs
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # Update the KV cache, attention mask and position IDs from the current output
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # `next_tokens` is reshaped to [1, BatchSize] and tiled NEOS times along dim 0, giving [NEOS, BatchSize].
        # `eos_token_id_tensor` is reshaped to [NEOS, 1] and broadcast along dim 1 to [NEOS, BatchSize].
        # The two are compared element-wise for inequality, giving a result of shape [NEOS, BatchSize].
        # Taking the product over dim 0 yields the unfinished flags, [BatchSize]:
        # if a sequence's new token equals any of the EOS tokens, one comparison becomes 0,
        # so that sequence's unfinished flag becomes 0.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        # If `return_past_key_values` was requested, yield the input IDs plus the
        # generated IDs together with the KV cache; otherwise yield only the IDs
        if return_past_key_values:
            yield input_ids, outputs.past_key_values
        else:
            yield input_ids
        # Stop when all unfinished flags are zero (every sequence is done)
        # or a stopping criterion is met
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
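For reference, a minimal sketch of how this generator might be consumed. The checkpoint path, prompt, and sampling parameters below are illustrative assumptions, not taken from the source; the key point is that each iteration yields the full sequence so far (prompt plus generated tokens), so the caller can decode incrementally.

from transformers import AutoModel, AutoTokenizer

model_path = "THUDM/chatglm2-6b"  # example checkpoint; adjust to your own path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

# do_sample and max_new_tokens are forwarded through **kwargs into
# generation_config.update(**kwargs) inside stream_generate
for output_ids in model.stream_generate(**inputs, do_sample=True, max_new_tokens=64):
    print(tokenizer.decode(output_ids[0].tolist()))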