
FastChat Deployment Service Architecture (Part 1)


The code for FastChat's deployment service lives in fastchat/serve. There are three core files:

  • controller.py: implements the Controller, which registers new Workers, removes Workers, and assigns a Worker to each request
  • model_worker.py: implements the Worker, which runs the model to process a request and returns the result to the Server. Each Worker holds its own complete copy of a model, and multiple Workers can serve the same model: for example, Worker 1 and Worker 2 can both serve Model A, which raises Model A's request throughput. A Worker can also map to multiple GPUs, e.g. when tensor parallelism is used to shard one model across several GPUs
  • openai_api_server.py: implements an OpenAI-compatible RESTful API

Their relationship is shown in the figure below:

Figure 1: https://github.com/lm-sys/FastChat/blob/main/docs/server_arch.md

Let's walk through the flow of handling a single request:

1. The user sends a request to the Server (e.g., the OpenAI API Server). The request contains the model name and the input, for example:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Llama-3-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
  }'
```

2. The Server sends a request to the Controller to obtain the address of a Worker that serves the requested model

3. The Controller assigns a Worker according to its load-balancing strategy

4. The Server sends the request to that Worker

5. The Worker processes the request and returns the result to the Server

6. The Server returns the result to the user
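
To make steps 2 to 5 concrete, here is a minimal sketch of the two HTTP hops behind one request, written with the requests library; the endpoint names and default ports are the ones used by the Mini FastChat implementation later in this article:

```python
import requests

# Steps 2-3: ask the Controller which Worker serves this model
res = requests.post(
    'http://localhost:21001/get_worker_address',
    json={'model': 'Llama-3-8B-Instruct'},
)
worker_addr = res.json()['address']

# Steps 4-5: forward the request to that Worker and wait for the output
res = requests.post(
    worker_addr + '/worker_generate',
    json={'messages': [{'role': 'user', 'content': 'Hello! What is your name?'}]},
)
print(res.json()['output'])
```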

That is the whole flow of FastChat handling one request. Next, we will implement a minimal FastChat.

Implementing Mini FastChat

Mini FastChat supports similar functionality and follows the same design as FastChat, but in simplified form; the code is adapted from FastChat.

Mini FastChat's directory structure is as follows:

```
mini-fastchat
├── controller.py
├── worker.py
└── openai_api_server.py
```

Controller

Create a controller.py file. Its main content is the Controller class, which registers Workers and randomly assigns a Worker to each request. controller.py also exposes two endpoints, register_worker and get_worker_address: the former is called by a Worker to register itself with the Controller, and the latter is called by the API Server to obtain a Worker's address.

```python
import argparse
import random

import uvicorn
from fastapi import FastAPI, Request
from loguru import logger


class Controller:
    def __init__(self):
        # Maps worker address -> name of the model that worker serves
        self.worker_info = {}

    def register_worker(self, worker_addr: str, model_name: str):
        logger.info(f'Register worker: {worker_addr} {model_name}')
        self.worker_info[worker_addr] = model_name

    def get_worker_address(self, model_name: str):
        # Collect all workers that serve the requested model
        worker_addr_list = []
        for worker_addr, _model_name in self.worker_info.items():
            if _model_name == model_name:
                worker_addr_list.append(worker_addr)
        assert len(worker_addr_list) > 0, f'No worker for model: {model_name}'
        # Assign a worker at random
        worker_addr = random.choice(worker_addr_list)
        return worker_addr


app = FastAPI()


@app.post('/register_worker')
async def register_worker(request: Request):
    data = await request.json()
    controller.register_worker(
        worker_addr=data['worker_addr'],
        model_name=data['model_name'],
    )


@app.post('/get_worker_address')
async def get_worker_address(request: Request):
    data = await request.json()
    addr = controller.get_worker_address(data['model'])
    return {'address': addr}


def create_controller():
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=21001)
    args = parser.parse_args()
    logger.info(f'args: {args}')
    controller = Controller()
    return args, controller


if __name__ == '__main__':
    # The module-level `controller` global is what the route handlers use
    args, controller = create_controller()
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')
```
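
With the controller running on its default port, the two endpoints can be exercised by hand. Here is a quick interactive check; the worker address below is a placeholder, since normally worker.py registers itself on startup:

```python
import requests

# Register a placeholder worker by hand (worker.py normally does this itself)
requests.post(
    'http://localhost:21001/register_worker',
    json={'worker_addr': 'http://localhost:21002',
          'model_name': 'Llama-3-8B-Instruct'},
)

# Ask the controller for a worker that serves this model
res = requests.post(
    'http://localhost:21001/get_worker_address',
    json={'model': 'Llama-3-8B-Instruct'},
)
print(res.json()['address'])  # -> http://localhost:21002
```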

Worker

Create a worker.py file. Its main content is the Worker class; it also exposes an api_generate endpoint (served at /worker_generate), which the API Server calls to handle user requests.

```python
import argparse
import asyncio
from typing import Optional

import requests
import torch
import uvicorn
from fastapi import FastAPI, Request
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(model_path: str):
    logger.info(f'Load model from {model_path}')
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
    logger.info(f'model device: {model.device}')
    return model, tokenizer


def generate(model, tokenizer, params: dict):
    input_ids = tokenizer.apply_chat_template(
        params['messages'],
        add_generation_prompt=True,
        return_tensors='pt',
    ).to(model.device)
    # <|eot_id|> is Llama-3's end-of-turn token; treat it as an extra stop token
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|eot_id|>')]
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Strip the prompt tokens and decode only the newly generated part
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


class Worker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        model_path: str,
        model_name: Optional[str] = None,
    ) -> None:
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.model, self.tokenizer = load_model(model_path)
        self.model_name = model_name
        self.register_to_controller()

    def register_to_controller(self) -> None:
        logger.info('Register to controller')
        url = self.controller_addr + '/register_worker'
        data = {
            'worker_addr': self.worker_addr,
            'model_name': self.model_name,
        }
        response = requests.post(url, json=data)
        assert response.status_code == 200

    def generate_gate(self, params: dict):
        return generate(self.model, self.tokenizer, params)


app = FastAPI()


@app.post('/worker_generate')
async def api_generate(request: Request):
    params = await request.json()
    # model.generate blocks, so run it in a thread to keep the event loop free
    output = await asyncio.to_thread(worker.generate_gate, params)
    return {'output': output}


def create_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path', type=str, help='Path to the model')
    parser.add_argument('model_name', type=str)
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=21002)
    parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
    args = parser.parse_args()
    logger.info(f'args: {args}')
    args.worker_address = f'http://{args.host}:{args.port}'
    worker = Worker(
        worker_addr=args.worker_address,
        controller_addr=args.controller_address,
        model_path=args.model_path,
        model_name=args.model_name,
    )
    return args, worker


if __name__ == '__main__':
    args, worker = create_worker()
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')
```
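
For reference, with the Llama-3 tokenizer, apply_chat_template with add_generation_prompt=True renders the messages from the earlier curl example into a prompt of roughly this shape (which is why <|eot_id|> is used as a stop token):

```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Hello! What is your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

```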

Server
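
Create an openai_api_server.py file. It implements the /v1/chat/completions endpoint: for each request, it asks the Controller for the address of a Worker that serves the requested model, forwards the request to that Worker, and returns the Worker's output to the user.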

```python
import argparse

import aiohttp
import uvicorn
from fastapi import FastAPI, Request
from loguru import logger

app = FastAPI()
app_settings = {}


async def fetch_remote(url, payload):
    # POST the payload to a remote service and return its JSON response
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            return await response.json()


async def generate_completion(payload, worker_addr: str):
    return await fetch_remote(worker_addr + '/worker_generate', payload)


async def get_worker_address(model_name: str) -> str:
    controller_address = app_settings['controller_address']
    res = await fetch_remote(
        controller_address + '/get_worker_address', {'model': model_name}
    )
    return res['address']


@app.post('/v1/chat/completions')
async def create_chat_completion(request: Request):
    data = await request.json()
    # Ask the controller for a worker, then forward the request to it
    worker_addr = await get_worker_address(data['model'])
    return await generate_completion(data, worker_addr)


def create_openai_api_server():
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost')
    parser.add_argument('--port', type=int, default=8000)
    parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
    args = parser.parse_args()
    logger.info(f'args: {args}')
    app_settings['controller_address'] = args.controller_address
    return args


if __name__ == '__main__':
    args = create_openai_api_server()
    uvicorn.run(app, host=args.host, port=args.port, log_level='info')
```
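
One design note: fetch_remote opens a fresh aiohttp.ClientSession for every call. That keeps the code short, but a production server would typically create a single session at startup and reuse it across requests to avoid the per-call connection setup overhead.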

Running Mini FastChat

Set up the environment

  • Create a conda environment

```bash
conda create -n fastchat python=3.10 -y
conda activate fastchat
```

  • Install torch 2.2.1

```bash
conda install pytorch==2.2.1 pytorch-cuda=12.1 -c pytorch -c nvidia
```

  • Install the remaining dependencies (accelerate is added here because the worker loads the model with device_map='auto', which requires it)

```bash
pip install requests aiohttp uvicorn fastapi loguru transformers accelerate
```

Run

  • Start the controller

```bash
python mini-fastchat/controller.py
```

  • Start a worker

```bash
python mini-fastchat/worker.py meta-llama/Meta-Llama-3-8B-Instruct Llama-3-8B-Instruct
# If there is a spare GPU, you can start another worker for the same model
CUDA_VISIBLE_DEVICES=1 python mini-fastchat/worker.py meta-llama/Meta-Llama-3-8B-Instruct Llama-3-8B-Instruct --port 21003
```

  • Start the API server

```bash
python mini-fastchat/openai_api_server.py
```

  • Test

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Llama-3-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
  }'
```

If the command above prints a response, everything is up and running.

Possible improvements

Mini FastChat is a simple FastChat-like serving system. Compared with FastChat proper, there are still many things that could be improved, for example:

  • Load-balancing strategy: Mini FastChat's Controller only supports assigning Workers at random, while FastChat's Controller supports LOTTERY and SHORTEST_QUEUE strategies (see the sketch after this list)
  • Robustness: to keep the implementation simple, Mini FastChat does not handle error cases such as malformed input or network failures
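
As an illustration of the second strategy, here is a sketch of what a shortest-queue policy could look like on top of Mini FastChat's Controller. The queue_length counter is hypothetical and not part of the code above, and FastChat's real implementation works differently (workers report their load back to the controller), so treat this only as a sketch of the idea:

```python
# Sketch only: extends the Controller class defined in controller.py above.
# The queue_length counter is hypothetical and not part of Mini FastChat.
class ShortestQueueController(Controller):
    """Route each request to the worker with the fewest in-flight requests."""

    def __init__(self):
        super().__init__()
        # Hypothetical counter: worker address -> number of in-flight requests
        self.queue_length = {}

    def register_worker(self, worker_addr: str, model_name: str):
        super().register_worker(worker_addr, model_name)
        self.queue_length.setdefault(worker_addr, 0)

    def get_worker_address(self, model_name: str):
        candidates = [
            addr for addr, name in self.worker_info.items() if name == model_name
        ]
        assert len(candidates) > 0, f'No worker for model: {model_name}'
        # Pick the least-loaded worker instead of a random one
        worker_addr = min(candidates, key=lambda addr: self.queue_length[addr])
        # The counter would need to be decremented when the worker finishes
        self.queue_length[worker_addr] += 1
        return worker_addr
```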

Based on my own study of the code, I also put together a flowchart.

Next article

FastChat Load-Balancing Strategies

