Currently, most LLMs have a context window limited to around 4k tokens, which means performance degrades once the input grows past that length. This is a real limitation for scenarios that depend on large amounts of context. A pretrained LLM's context window can be extended by fine-tuning on longer texts, but pushing it much further runs into three main challenges: (1) untrained new position indices introduce many out-of-distribution positional values, which makes fine-tuning hard to converge; (2) fine-tuning normally requires training texts as long as the target window, and such very long texts are scarce and expensive to train on; (3) as the window becomes extremely long, attention is spread thinly over a huge number of positions, which degrades performance on the original short contexts.
Paper: LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens
Link: https://arxiv.org/abs/2402.13753
The non-uniformity problem in position interpolation concerns how, when extending the context window of large language models (LLMs), positional embeddings should be assigned to the newly added token positions so that the model keeps (or improves) its performance on longer sequences. In the LongRoPE paper, the authors identify and exploit two main forms of non-uniformity to improve position interpolation:
Non-uniformity across RoPE dimensions: different RoPE dimensions rotate at different frequencies and tolerate different amounts of interpolation, so LongRoPE searches a separate rescale factor λ_i for each dimension instead of scaling all dimensions uniformly.
Non-uniformity across token positions: the first tokens of a sequence benefit from little or no interpolation, so LongRoPE keeps the initial n̂ positions at their original values and only rescales positions beyond that threshold (a short sketch follows this list).
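As a minimal sketch of how these two non-uniformities change the RoPE rotation angles (the function name and the example lambda factors below are illustrative choices, not values from the paper):

```python
import torch


def rescaled_rope_angles(position, d_model, lambda_factors, n_hat, base=10000):
    """Rotation angles for one token position under LongRoPE-style rescaling.

    Positions below n_hat keep the original RoPE angles; beyond that, each
    dimension pair i is divided by its own rescale factor lambda_factors[i].
    """
    i = torch.arange(d_model // 2)
    freqs = base ** (-2 * i / d_model)   # per-pair rotation frequency
    angles = position * freqs            # standard RoPE angles
    if position >= n_hat:
        angles = angles / lambda_factors  # non-uniform interpolation
    return angles


# Example: dimension pairs near index 0 (high frequency) get smaller factors,
# i.e. less interpolation, than the low-frequency pairs at the end.
lams = torch.linspace(1.0, 8.0, steps=64)
print(rescaled_rope_angles(position=10000, d_model=128, lambda_factors=lams, n_hat=128)[:4])
```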
LongRoPE addresses these non-uniformities with three ingredients: (1) an efficient evolutionary search that finds a rescale factor for each RoPE dimension together with the number of initial positions n̂ left uninterpolated; (2) a progressive extension strategy that first fine-tunes a 256k model and then runs a second interpolation search to reach a 2048k window without fine-tuning at that extreme length; (3) a readjustment of the rescale factors on the original short window so that performance on short inputs is preserved. The simplified Python sketch below (an illustrative re-implementation, not the authors' released code) walks through this pipeline:
```python
import gzip
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class RoPEPositionalEncoding(nn.Module):
    """Rotary Position Encoding (RoPE) module."""

    def __init__(self, d_model, max_len=5000, base=10000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.base = base
        # One rotation frequency per dimension pair; the cos/sin pair in
        # forward() brings the output back up to d_model dimensions.
        self.register_buffer(
            "theta",
            torch.tensor([base ** (-2 * i / d_model) for i in range(d_model // 2)]),
        )

    def forward(self, positions):
        angles = positions.unsqueeze(-1) * self.theta
        return torch.stack([angles.cos(), angles.sin()], dim=-1).flatten(-2)


def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):
    """
    Perform non-uniform interpolation on position embeddings.

    Args:
        pos_embed (torch.Tensor): Position embeddings of shape (seq_len, d_model).
        extension_ratio (float): Extension ratio for the context window.
        lambda_factors (torch.Tensor): Per-dimension rescale factors.
        n_hat (int): Number of leading positions left uninterpolated.

    Returns:
        torch.Tensor: Interpolated position embeddings.
    """
    d_model = pos_embed.shape[-1]
    interpolated_pos = pos_embed.clone()
    # Positions before n_hat keep their original RoPE values.
    mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat
    for i in range(d_model // 2):
        scale = torch.where(
            mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i]
        )
        interpolated_pos[..., 2 * i] *= scale
        interpolated_pos[..., 2 * i + 1] *= scale
    return interpolated_pos


def search_lambda_factors(
    model,
    data,
    extension_ratio,
    population_size=64,
    num_mutations=16,
    num_crossovers=16,
    max_iterations=10,
    max_length=None,
    n_hat=0,
):
    """
    Search for good lambda factors with a simple evolutionary search.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of tokenized input sequences.
        extension_ratio (float): Extension ratio for the context window.
        population_size (int): Size of the search population.
        num_mutations (int): Number of mutated children per iteration.
        num_crossovers (int): Number of crossover children per iteration.
        max_iterations (int): Maximum number of search iterations.
        max_length (int, optional): If set, truncate evaluation sequences to
            this length (used for the short-context readjustment step).
        n_hat (int): Number of leading positions left uninterpolated. Fixed
            here for simplicity; the paper searches it as well.

    Returns:
        tuple: (best lambda factors found, n_hat).
    """
    if max_length is not None:
        data = [seq[:max_length] for seq in data]
    population = initialize_population(
        population_size, extension_ratio, dim=model.d_model
    )
    for _ in range(max_iterations):
        perplexities = evaluate_population(model, data, population)
        parents = select_topk(population, perplexities, k=population_size // 2)
        population = mutate(parents, num_mutations) + crossover(parents, num_crossovers)
    best = min(population, key=lambda x: evaluate_individual(model, data, x))
    return best, n_hat


def progressive_extension(model, data, base_length, target_length):
    """
    Progressively extend the context window of the model.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of tokenized input sequences.
        base_length (int): Base context window length.
        target_length (int): Target context window length.

    Returns:
        tuple: (extended model, lambda factors, base lambda factors)
    """
    curr_model = model
    curr_length = base_length
    lambda_factors, n_hat = None, 0
    # Double the window, search rescale factors, then fine-tune at each step.
    while curr_length < target_length:
        curr_length = min(curr_length * 2, target_length)
        lambda_factors, n_hat = search_lambda_factors(
            curr_model, data, curr_length / base_length
        )
        curr_model = fine_tune(curr_model, data, curr_length, lambda_factors, n_hat)
    # Re-search factors for the original window so short inputs are not degraded.
    lambda_factors_base, _ = search_lambda_factors(
        curr_model, data, curr_length / base_length, max_length=base_length
    )
    return curr_model, lambda_factors, lambda_factors_base


class LongRoPEModel(nn.Module):
    """
    Long Range Rotary Position Encoding (LongRoPE) model.

    Extends the context window of a transformer encoder beyond its original
    limit by applying non-uniform interpolation to rotary position embeddings,
    so the model can handle longer inputs while still capturing long-range
    dependencies.

    Attributes:
        d_model (int): Dimension of the model.
        n_heads (int): Number of attention heads.
        num_layers (int): Number of transformer layers.
        embedding (nn.Embedding): Token embedding layer.
        rope (RoPEPositionalEncoding): Rotary Position Encoding (RoPE) module.
        transformers (nn.ModuleList): List of transformer encoder layers.
        lambda_factors (torch.Tensor): Factors for non-uniform interpolation.
        lambda_factors_base (torch.Tensor): Factors for the original window.
        extension_ratio (float): Extension ratio for the context window.
        n_hat (int): Number of leading positions left uninterpolated.

    Methods:
        forward(input_ids): Run the model on a batch of token ids.
        extend_context(data_path, target_length, max_sequence_length, tokenizer):
            Extend the context window of the model.
    """

    def __init__(self, d_model, n_heads, num_layers, vocab_size, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.rope = RoPEPositionalEncoding(d_model, max_len)
        self.transformers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=d_model, nhead=n_heads, batch_first=True
                )
                for _ in range(num_layers)
            ]
        )
        self.lambda_factors = None
        self.lambda_factors_base = None
        self.extension_ratio = 1.0
        self.n_hat = 0

    def forward(self, input_ids):
        """Embed token ids, add (interpolated) RoPE values, run the encoder."""
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        pos_embeddings = self.rope(positions)
        if self.lambda_factors is not None:
            pos_embeddings = non_uniform_interpolation(
                pos_embeddings, self.extension_ratio, self.lambda_factors, self.n_hat
            )
        # For simplicity the RoPE values are added to the token embeddings;
        # a faithful RoPE implementation rotates attention queries/keys instead.
        hidden = self.embedding(input_ids) + pos_embeddings
        for transformer in self.transformers:
            hidden = transformer(hidden)
        return hidden

    def extend_context(self, data_path, target_length, max_sequence_length, tokenizer):
        """
        Extend the context window of the model.

        Args:
            data_path (str): Path to the input data file.
            target_length (int): Target context window length.
            max_sequence_length (int): Maximum sequence length for input data.
            tokenizer: Tokenizer object for encoding input data.

        Returns:
            LongRoPEModel: Extended LongRoPE model.
        """
        if tokenizer is None:
            raise ValueError("Tokenizer is required for extending context.")
        self.extension_ratio = target_length / self.rope.max_len
        data = load_data(data_path, tokenizer, max_sequence_length)
        model, lambda_factors, lambda_factors_base = progressive_extension(
            self, data, self.rope.max_len, target_length
        )
        self.lambda_factors = lambda_factors
        self.lambda_factors_base = lambda_factors_base
        self.n_hat = self.rope.max_len // 2
        return model


def load_data(data_path, tokenizer, max_sequence_length):
    """
    Load and preprocess the input data.

    Args:
        data_path (str): Path to the input data file (optionally gzip-compressed).
        tokenizer: Tokenizer object for encoding input data.
        max_sequence_length (int): Maximum sequence length for input data.

    Returns:
        list: List of tokenized input sequences as tensors.
    """
    if data_path is None or tokenizer is None:
        raise ValueError("Data path and tokenizer are required for loading data.")
    if data_path.endswith(".gz"):
        with gzip.open(data_path, "rt", encoding="utf-8") as file:
            text_data = file.read()
    else:
        with open(data_path, "r", encoding="utf-8") as file:
            text_data = file.read()
    tokenized_data = tokenizer.encode(text_data)
    sequences = [
        tokenized_data[i : i + max_sequence_length]
        for i in range(0, len(tokenized_data), max_sequence_length)
    ]
    return [torch.tensor(seq, dtype=torch.long) for seq in sequences]


def initialize_population(population_size, extension_ratio, dim=512):
    """
    Initialize the population for the evolutionary search.

    The first three individuals correspond to known interpolation schemes
    (uniform PI, NTK-style, YaRN-style); the rest are random perturbations.

    Args:
        population_size (int): Size of the population.
        extension_ratio (float): Extension ratio for the context window.
        dim (int): Number of rescale factors per individual (usually d_model).

    Returns:
        list: Initialized population of factor tensors.
    """
    population = []
    # Uniform position interpolation: every dimension shares the same factor.
    population.append(torch.ones(dim) * extension_ratio)
    # NTK-style: factors grow with the dimension index.
    ntk_factors = torch.tensor([extension_ratio ** (2 * i / dim) for i in range(dim)])
    population.append(ntk_factors)
    # YaRN-style: three frequency bands with different factors.
    yarn_factors = torch.ones(dim)
    yarn_factors[: dim // 4] = 1.0
    yarn_factors[dim // 4 : dim // 2] = extension_ratio ** (1 / 3)
    yarn_factors[dim // 2 :] = extension_ratio
    population.append(yarn_factors)
    # Random perturbations around uninterpolated factors.
    for _ in range(population_size - 3):
        factors = torch.ones(dim)
        for i in range(dim):
            if random.random() < 0.1:
                factors[i] = random.uniform(1, extension_ratio)
        population.append(factors)
    return population


def evaluate_individual(model, data, individual):
    """
    Evaluate one lambda-factor configuration.

    Note: this sketch has no language-modeling head, so the "perplexity" below
    is only a crude proxy computed from the raw encoder outputs.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of tokenized input sequences.
        individual (torch.Tensor): Lambda factor configuration.

    Returns:
        float: Proxy perplexity score for the individual (lower is better).
    """
    model.lambda_factors = individual
    perplexities = []
    with torch.no_grad():
        for seq in data:
            input_ids = seq.unsqueeze(0)
            output = model(input_ids)
            perplexity = torch.exp(torch.mean(output))
            perplexities.append(perplexity.item())
    return float(np.mean(perplexities))


def evaluate_population(model, data, population):
    """
    Evaluate every individual in the population.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of tokenized input sequences.
        population (list): Population of lambda factor configurations.

    Returns:
        list: Proxy perplexity scores, one per individual.
    """
    return [evaluate_individual(model, data, individual) for individual in population]


def select_topk(population, perplexities, k):
    """
    Select the k individuals with the lowest perplexity.

    Args:
        population (list): Population of lambda factor configurations.
        perplexities (list): Perplexity scores for each individual.
        k (int): Number of individuals to keep.

    Returns:
        list: Top-k individuals.
    """
    indices = np.argsort(perplexities)[:k]
    return [population[i] for i in indices]


def mutate(parents, num_mutations):
    """
    Create new individuals by randomly perturbing parents.

    Args:
        parents (list): Parent population.
        num_mutations (int): Number of mutated children to create.

    Returns:
        list: Mutated individuals.
    """
    mutated_population = []
    for _ in range(num_mutations):
        parent = random.choice(parents)
        child = parent.clone()
        for i in range(len(child)):
            if random.random() < 0.1:
                child[i] *= random.uniform(0.8, 1.2)
        mutated_population.append(child)
    return mutated_population


def crossover(parents, num_crossovers):
    """
    Create new individuals by mixing pairs of parents.

    Args:
        parents (list): Parent population.
        num_crossovers (int): Number of crossover children to create.

    Returns:
        list: Crossover individuals.
    """
    crossover_population = []
    for _ in range(num_crossovers):
        parent1, parent2 = random.sample(parents, 2)
        child = parent1.clone()
        for i in range(len(child)):
            if random.random() < 0.5:
                child[i] = parent2[i]
        crossover_population.append(child)
    return crossover_population


def fine_tune(model, data, target_length, lambda_factors, n_hat, num_epochs=3):
    """
    Fine-tune the LongRoPE model at a given context length.

    Note: without a language-modeling head, the loss below is only a
    placeholder objective kept to illustrate the training loop.

    Args:
        model (nn.Module): LongRoPE model.
        data (list): List of tokenized input sequences.
        target_length (int): Target context window length.
        lambda_factors (torch.Tensor): Lambda factors for interpolation.
        n_hat (int): Number of leading positions left uninterpolated.
        num_epochs (int, optional): Number of fine-tuning epochs. Defaults to 3.

    Returns:
        nn.Module: Fine-tuned LongRoPE model.
    """
    model.lambda_factors = lambda_factors
    model.n_hat = n_hat
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    for _ in range(num_epochs):
        for seq in data:
            optimizer.zero_grad()
            seq_len = seq.size(0)
            if seq_len <= target_length:
                input_ids = seq.unsqueeze(0)
            else:
                # Sample a random window of the target length from long sequences.
                start_idx = random.randint(0, seq_len - target_length)
                input_ids = seq[start_idx : start_idx + target_length].unsqueeze(0)
            output = model(input_ids)
            loss = torch.mean(output)  # placeholder loss for this sketch
            loss.backward()
            optimizer.step()
    return model


# Example usage (the dataset path and tokenizer are placeholders).
if __name__ == "__main__":
    data_path = "path/to/your/dataset"
    d_model = 512
    n_heads = 8
    num_layers = 6
    base_length = 4096
    target_length = 2048 * 1024

    tokenizer = ...  # any tokenizer exposing `encode` and `vocab_size`

    model = LongRoPEModel(
        d_model, n_heads, num_layers, vocab_size=tokenizer.vocab_size, max_len=base_length
    )
    model = model.extend_context(
        data_path=data_path,
        target_length=target_length,
        max_sequence_length=base_length,
        tokenizer=tokenizer,
    )

    input_ids = torch.randint(0, tokenizer.vocab_size, (2, base_length))
    output = model(input_ids)
    print(output.shape)  # (batch_size, seq_len, d_model)
```
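One detail the sketch above stores but never uses is lambda_factors_base: LongRoPE re-searches a second set of rescale factors at the original window length so that short inputs are not hurt by the long-context rescaling, and switches between the two sets at inference time depending on the input length. Continuing the example usage above (this snippet reuses `model`, `input_ids`, and `base_length` from that example and is only an illustration, not part of the paper's code):

```python
# Keep the extended-window factors around, then pick a set per input length.
long_factors = model.lambda_factors
is_short = input_ids.size(1) <= base_length
model.lambda_factors = model.lambda_factors_base if is_short else long_factors
output = model(input_ids)
```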