当前位置:   article > 正文

大模型推理框架 vLLM 源码解析(一)_vllm源码分析

vllm源码分析

大模型推理框架 vLLM 源码解析(一)

原创 marsggbo AutoML机器学习 2024-02-04 18:13

1. Quick Start

创建如下代码,命名为 run.py

  1. from vllm import LLM, SamplingParams
  2. prompts = [
  3.  "Have you followed marsggbo in Zhihu?",
  4.  "你一键三连了吗?"
  5. ] # 输入prompts
  6. sampling_params = SamplingParams(temperature=0.8top_k=50) # 采样策略
  7. llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2) # 初始化 LLM
  8. outputs = llm.generate(prompts, sampling_params) # 完成推理
  9. for output in outputs:
  10.  prompt = output.prompt
  11.     generated_text = output.outputs[0].text
  12.     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

执行命令:python run.py。该脚本会自动将模型以张量并行的方式在两个 GPU 上进行推理计算

整个推理过程大大致流程如下图所示,即 1 给定一定数量的 prompts(字符串数组) 2. vllm 会使用 Scheduler 模块自动对需要推理句子进行调度

3. 根据调度的结果,使用 tokenizer 将字符串转换成 prompt id,然后喂给 model 进行计算得到 logits 预测结果 4. 根据 logits 预测结果和提前设置好的采样策略对结果进行采样得到新的 token id 5. 将采样结果保存到 output

inferencce pipeline

2. 整体核心模块

上图给出了 vLLM 核心模块之间的结构关系。接下来我们从简单的模块(即输入、采样和输出)开始介绍,最后详细介绍 LLM 模块。

3. Sequence

如上图我们可以看到 vLLM 为输入的句子设计了很多子模块,这些模块的用处各不相同,但是有彼此之间有关系,下面分别详细介绍一下。

3.1 SequenceStatus

首先看到 SequenceStatus,其源代码如下:

  1. class SequenceStatus(enum.Enum):
  2.     """Status of a sequence."""
  3.     WAITING = enum.auto() # 等待中,句子还没开始推理,或者推理还未结束
  4.     RUNNING = enum.auto() # 运行中
  5.     SWAPPED = enum.auto() # 已交换
  6.     FINISHED_STOPPED = enum.auto() # 已停止
  7.     FINISHED_LENGTH_CAPPED = enum.auto() # 已长度限制
  8.     FINISHED_ABORTED = enum.auto() # 已中止
  9.     FINISHED_IGNORED = enum.auto() # 已忽略
  10.     @staticmethod
  11.     def is_finished(status: "SequenceStatus") -> bool:
  12.         # 判断状态是否为已停止、已长度限制、已中止或已忽略
  13.         return status in [
  14.             SequenceStatus.FINISHED_STOPPED,
  15.             SequenceStatus.FINISHED_LENGTH_CAPPED,
  16.             SequenceStatus.FINISHED_ABORTED,
  17.             SequenceStatus.FINISHED_IGNORED,
  18.         ]

3.2 SequenceData

SequenceData 用于存储与序列相关的数据。这个类有三个属性:prompt_token_ids(提示词的标记ID)、output_token_ids(生成文本的标记ID)和cumulative_logprob(累计对数概率)。

  1. class SequenceData:
  2.     def __init__(
  3.         self,
  4.         prompt_token_ids: List[int],
  5.     ) -> None:
  6.         self.prompt_token_ids = prompt_token_ids
  7.         self.output_token_ids: List[int] = []
  8.         self.cumulative_logprob = 0.0

3.3 Sequence

Sequence 用于存储序列的数据、状态和块信息,且每个序列有唯一标识,即seq_id。注意看下面的代码:

  • 数据其实是通过上面的 SequenceData 保存的

  • 默认初始化状态,所有句子序列的状态都是 SequenceStatus.WAITING

  • 所谓块信息,其实就是 vLLM 会在初始化阶段预留出一定数量的CPU 和 GPU 内存,一般是以 token 为单位的,例如在初始化的时候会使用值全为 0,大小为 (256, 128)的 prompt_ids做 warm up。每个序列会按照实际大小申请 block 来记录内存使用情况,即序列 token 数越多,属性logical_token_blocks包含的 block 个数也就越多。

  1. class Sequence:
  2.     def __init__(
  3.         self,
  4.         seq_id: int,
  5.         prompt: str,
  6.         prompt_token_ids: List[int],
  7.         block_size: int,
  8.     ) -> None:
  9.         self.seq_id = seq_id
  10.         self.prompt = prompt
  11.         self.block_size = block_size
  12.         self.data = SequenceData(prompt_token_ids) # 数据
  13.         self.logical_token_blocks: List[LogicalTokenBlock] = []
  14.         # Initialize the logical token blocks with the prompt token ids.
  15.         self._append_tokens_to_blocks(prompt_token_ids) # 块信息
  16.         self.status = SequenceStatus.WAITING # 状态
  17.   ...

3.3 SequenceGroup

Sequence只是单个序列的表示方式,seq_id是它的唯一标识。SequenceGroup则是为了表示多个序列,request_id是它的唯一标识,表示是第几个请求。

具体而言,可以看到__init__函数有个参数是 seqs: List[Sequence],它表示由一个或多个 Sequence 组成的列表,然后会通过self.seqs_dict = {seq.seq_id: seq for seq in seqs}转化成字典方便管理,这个字典的 key 是每个 Sequence 的唯一标识seq_id

  1. class SequenceGroup:
  2.     def __init__(
  3.         self,
  4.         request_id: str,
  5.         seqs: List[Sequence],
  6.         sampling_params: SamplingParams,
  7.         arrival_time: float,
  8.         lora_request: Optional[LoRARequest] = None,
  9.         prefix: Optional[Prefix] = None,
  10.     ) -> None:
  11.         self.request_id = request_id
  12.         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  13.         self.sampling_params = sampling_params
  14.         self.arrival_time = arrival_time
  15.   ...

下面是 vLLm 中 LLMEngine 使用 Sequence 和 SequenceGroup 的场景示例:

  1. class LLMEngine:
  2.     def add_request(
  3.         self,
  4.         request_id: str,
  5.         prompt: Optional[str],
  6.         sampling_params: SamplingParams,
  7.         prompt_token_ids: Optional[List[int]] = None,
  8.         arrival_timeOptional[float] = None,
  9.         lora_request: Optional[LoRARequest] = None,
  10.         prefix_pos: Optional[int] = None,
  11.     ) -> None:
  12.         prompt_token_ids = self.encode_request(
  13.             request_id=request_id,
  14.             prompt=prompt,
  15.             prompt_token_ids=prompt_token_ids,
  16.             lora_request=lora_request) # 将字符串序列转换成 id
  17.         # Create the sequences.
  18.         block_size = self.cache_config.block_size
  19.         seq_id = next(self.seq_counter)
  20.         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
  21.                        lora_request)
  22.         # Create the sequence group.
  23.         seq_group = SequenceGroup(request_id, [seq], sampling_params,
  24.                                   arrival_time)
  25.         # Add the sequence group to the scheduler.
  26.         self.scheduler.add_seq_group(seq_group)

可以看到SequenceGroupseqs参数在最初阶段其实只是单个序列 ,即[seq]。但是我们知道其实一个 prompt 可以有多个输出结果,所以SequenceGroup的目的是管理一个输入 prompt的多个生成序列信息。如果我们设置SamplingParams.n=2(第 4 节会介绍),那么在推理过程中,SequenceGroup会新增一个 Sequence,这个新增的 Sequence 的 seq_id 和原来的那个 Sequence 不一样,具体的代码细节会在下一篇文章中介绍。

3.5 SequenceGroupMetadata

  1. class SequenceGroupMetadata:
  2.     def __init__(
  3.         self,
  4.         request_id: str,
  5.         is_prompt: bool,
  6.         seq_data: Dict[int, SequenceData],
  7.         sampling_params: SamplingParams,
  8.         block_tables: Dict[int, List[int]],
  9.     ) -> None:
  10.         self.request_id = request_id
  11.         self.is_prompt = is_prompt
  12.         self.seq_data = seq_data
  13.         self.sampling_params = sampling_params
  14.         self.block_tables = block_tables
  15.   ...

SequenceGroupMetadata 记录了一些元信息,下面代码展示了 Scheduler 模块是如何生成这些信息的:

  • request_id 就是 SequenceGroup的 request_id

  • seq_data 是一个字典,key 是每个 Sequence的 seq_id,value 则是对应的 data (即 SequenceData)

  • block_tables也是一个字典,key 也是每个 Sequence的 seq_id,value 这是对应 Sequence 申请的 block

  1. class Scheduler:
  2.     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
  3.         scheduler_outputs = self._schedule()
  4.         # Create input data structures.
  5.         seq_group_metadata_list: List[SequenceGroupMetadata] = []
  6.         for seq_group in scheduler_outputs.scheduled_seq_groups:
  7.             seq_data: Dict[int, SequenceData] = {}
  8.             block_tables: Dict[int, List[int]] = {}
  9.             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
  10.                 seq_id = seq.seq_id
  11.                 seq_data[seq_id] = seq.data # 单个 SequenceData
  12.                 block_tables[seq_id] = self.block_manager.get_block_table(seq) # 对应Sequenceblock信息
  13.             seq_group_metadata = SequenceGroupMetadata(
  14.                 request_id=seq_group.request_id,
  15.                 is_prompt=scheduler_outputs.prompt_run,
  16.                 seq_data=seq_data,
  17.                 sampling_params=seq_group.sampling_params,
  18.                 block_tables=block_tables,
  19.                 lora_request=seq_group.lora_request,
  20.                 prefix=seq_group.prefix,
  21.             )
  22.             seq_group_metadata_list.append(seq_group_metadata)
  23.         return seq_group_metadata_list, scheduler_outputs

3.6  SequenceOutput 和 SequenceGroupOutput

SequenceOutput 和 SequenceGroupOutput的关系就类似 Sequence 和 SequenceGroup。SequenceOutput其实就是记录了上一个 输入 token id 以及对应输出的 token id。

  1. class SequenceOutput:
  2.     def __init__(
  3.         self,
  4.         parent_seq_id: int,
  5.         output_token: int,
  6.         logprobs: Dict[int, float],
  7.     ) -> None:
  8.         self.parent_seq_id = parent_seq_id
  9.         self.output_token = output_token
  10.         self.logprobs = logprobs
  11. class SequenceGroupOutput:
  12.     def __init__(
  13.         self,
  14.         samples: List[SequenceOutput],
  15.         prompt_logprobs: Optional[PromptLogprobs],
  16.     ) -> None:
  17.         self.samples = samples
  18.         self.prompt_logprobs = prompt_logprobs

4. SamplingParams

SamplingParams

SamplingParams 包含以下参数:

  • n:要生成的序列的数量,默认为 1。

  • best_of:从多少个序列中选择最佳序列,需要大于 n,默认等于 n。

  • temperature:用于控制生成结果的随机性,较低的温度会使生成结果更确定性,较高的温度会使生成结果更随机。

  • top_p:用于过滤掉生成词汇表中概率低于给定阈值的词汇,控制随机性。

  • top_k:选择前 k 个候选 token,控制多样性。

  • presence_penalty:用于控制生成结果中特定词汇的出现频率。

  • frequency_penalty:用于控制生成结果中词汇的频率分布。

  • repetition_penalty:用于控制生成结果中的词汇重复程度。

  • use_beam_search:是否使用束搜索来生成序列。

  • length_penalty:用于控制生成结果的长度分布。

  • early_stopping:是否在生成过程中提前停止。

  • stop:要停止生成的词汇列表。

  • stop_token_ids:要停止生成的词汇的ID列表。

  • include_stop_str_in_output:是否在输出结果中包含停止字符串。

  • ignore_eos:在生成过程中是否忽略结束符号。

  • max_tokens:生成序列的最大长度。

  • logprobs:用于记录生成过程的概率信息。

  • prompt_logprobs:用于记录生成过程的概率信息,用于特定提示。

  • skip_special_tokens:是否跳过特殊符号。

  • spaces_between_special_tokens:是否在特殊符号之间添加空格。

这些参数的设置通常取决于具体需求和模型性能。以下是一些常见的设置指导方法:

  • temperature:较低的温度(如0.2)会产生更确定性的结果,而较高的温度(如0.8)会产生更随机的结果。您可以根据您的需求进行调整。

  • presence_penalty、frequency_penalty 和 repetition_penalty:这些参数可以用于控制生成结果中的词汇分布和重复程度。您可以根据您的需求进行调整。

  • use_beam_search:束搜索通常用于生成更高质量的结果,但可能会降低生成速度。您可以根据您的需求进行调整。

  • length_penalty:这个参数可以用于控制生成结果的长度。较高的值会产生更长的结果,而较低的值会产生更短的结果。您可以根据您的需求进行调整。

  • early_stopping:如果您不希望生成过长的结果,可以设置此参数为True。

  • stop 和 stop_token_ids:您可以使用这些参数来指定生成结果的结束条件。

5. Output 模块

Output模块

Output 主要用于表示语言模型(LLM)的生成结果,包含如下两个模块:

  • CompletionOutput

  • RequestOutput

通过上面的介绍我们知道一个 request 可能包含多个序列,CompletionOutput 用来表示一个 request 中某个序列的完整输出的数据,其中下面的index就表示该序列在 request 中的索引位置

  1. class CompletionOutput:
  2.     def __init__(
  3.         self,
  4.         index: int, # 输出结果在请求中的索引
  5.         text: str, # 生成的文本
  6.         token_ids: List[int], # 生成的文本对应的 token ID 列表
  7.         cumulative_logprob: float,
  8.         logprobs: Optional[SampleLogprobs],
  9.         finish_reason: Optional[str] = None, # 序列完成的原因(SequenceStatus)
  10.         lora_request: Optional[LoRARequest] = None,
  11.     ) -> None:
  12.         self.index = index
  13.         self.text = text
  14.         self.token_ids = token_ids
  15.         self.finish_reason = finish_reason
  16.   ...

RequestOutput则表示 request 所有序列的输出结果,有它的初始化函数可以看到它记录了对应的 request_id

  1. class RequestOutput:
  2.     def __init__(
  3.         self,
  4.         request_id: str,
  5.         prompt: str,
  6.         prompt_token_ids: List[int],
  7.         prompt_logprobs: Optional[PromptLogprobs],
  8.         outputs: List[CompletionOutput],
  9.         finished: bool,
  10.         lora_request: Optional[LoRARequest] = None,
  11.     ) -> None:
  12.         self.request_id = request_id
  13.         self.prompt = prompt
  14.         self.prompt_token_ids = prompt_token_ids
  15.         self.outputs = outputs
  16.         self.finished = finished
  17.   ...

我们看看RequestOutput的from_seq_group就能很好理解CompletionOutput和 RequestOutput是如何使用的了。为方便理解,代码有删减,但是不影响最终结果:

  1. class RequestOutput:
  2.     @classmethod
  3.     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  4.         # 1Get the top-n sequences.
  5.         n = seq_group.sampling_params.n # 每个序列返回的生成序列数量
  6.         seqs = seq_group.get_seqs()
  7.   # 根据累积 logprob 值来选择出前 n 个生成序列
  8.   sorting_key = lambda seq: seq.get_cumulative_logprob()
  9.         sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  10.         top_n_seqs = sorted_seqs[:n]
  11.         # 2. Create the outputs.
  12.         outputs: List[CompletionOutput] = []
  13.         for seq in top_n_seqs:
  14.             logprobs = seq.output_logprobs
  15.             finshed_reason = SequenceStatus.get_finished_reason(seq.status)
  16.             output = CompletionOutput(seqs.index(seq), seq.output_text,
  17.                                       seq.get_output_token_ids(),
  18.                                       seq.get_cumulative_logprob(), logprobs,
  19.                                       finshed_reason)
  20.             outputs.append(output)
  21.         # Every sequence in the sequence group should have the same prompt.
  22.         prompt = seq_group.prompt
  23.         prompt_token_ids = seq_group.prompt_token_ids
  24.         prompt_logprobs = seq_group.prompt_logprobs
  25.         finished = seq_group.is_finished()
  26.         return cls(seq_group.request_id,
  27.                    prompt,
  28.                    prompt_token_ids,
  29.                    prompt_logprobs,
  30.                    outputs,
  31.                    finished,
  32.                    lora_request=seq_group.lora_request)

RequestOutput是通过对传入的seq_group: SequenceGroup进行解析后得到的。解析过程主要有两个阶段:

  1. Get the top-n sequences:这一阶段就是对生成序列按照 cumulative_logprob 进行排序,最后选择出top-n 序列。

  2. Create the outputs:将所有top-n生成序列分别转换成 CompletionOutput列表,并作为RequestOutput的初始化参数。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/519741
推荐阅读
相关标签
  

闽ICP备14008679号