
Deploying the Qwen Large Language Model with vLLM

Required Python imports

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # expose 4 GPUs; must be set before vLLM initializes CUDA

import argparse
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

Global variables

TIMEOUT_KEEP_ALIVE = 5  # seconds
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds (reserved; not used below)
app = FastAPI()

The generate function

@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate a completion for the request.

    The request should be a JSON object with the following fields:
    - data.context: the conversation turns used to build the prompt.
    - salt_uuid: an identifier echoed back in the response.
    - stream: whether to stream the results or not.
    (Sampling parameters are currently hardcoded below.)
    """
    salt_uuid = "null"  # default so the error path below can still echo it
    try:
        request_dict = await request.json()
        # contexts = request_dict.pop("contexts")
        contexts = request_dict.get("data", {}).get("context")
        salt_uuid = request_dict.pop("salt_uuid", "null")
        prompt, message_doctor = process_context_qwen(contexts)
        stream = request_dict.pop("stream", False)
        # sampling_params = SamplingParams(**request_dict)
        # sampling_params = SamplingParams(n=1, temperature=0.95, top_p=0.65, top_k=20, max_tokens=128)
        # sampling_params = SamplingParams(best_of=1, temperature=1e-6, top_p=1, top_k=-1, max_tokens=256, ignore_eos=False)
        # Deterministic beam search over 5 candidates, returning the best one.
        sampling_params = SamplingParams(n=1, temperature=0, best_of=5, top_p=1.0,
                                         top_k=-1, use_beam_search=True, max_tokens=128)
        request_id = random_uuid()
        results_generator = engine.generate(prompt, sampling_params, request_id)

        # Streaming case
        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                prompt = request_output.prompt
                text_outputs = [
                    prompt + output.text for output in request_output.outputs
                ]
                ret = {"text": text_outputs}
                yield (json.dumps(ret) + "\0").encode("utf-8")

        async def abort_request() -> None:
            await engine.abort(request_id)

        if stream:
            background_tasks = BackgroundTasks()
            # Abort the request if the client disconnects.
            background_tasks.add_task(abort_request)
            return StreamingResponse(stream_results(), background=background_tasks)

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output
        assert final_output is not None
        text_outputs = [output.text for output in final_output.outputs]
        print(f"output: {final_output.outputs[0].text}")
        ret = {"data": {"text": text_outputs}, "code": 5200, "message": "success", "salt_uuid": salt_uuid}
    except Exception as e:
        ret = {"data": {"text": ""}, "code": 5201, "message": f"call failed\nerror: {e}", "salt_uuid": salt_uuid}
    return JSONResponse(ret)
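For reference, here is a minimal sketch of a client consuming the streaming mode, assuming the server runs locally on the default port 12356 and using the hypothetical salt_uuid "demo-stream". Each chunk is a JSON message terminated by a null byte, mirroring stream_results above:

import json
import requests

payload = {
    "salt_uuid": "demo-stream",
    "stream": True,
    "data": {"context": [{"role": "用户", "message": "您好"}]},
}
with requests.post("http://localhost:12356/generate", json=payload, stream=True) as resp:
    # Chunks are separated by "\0", as emitted by stream_results().
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"])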

Qwen prompt context processing function

def process_context_qwen(contexts):
    # Walk the context backwards and keep only the most recent turns,
    # up to roughly 1024 characters, so the prompt stays bounded.
    cur_index = 0
    char_count = 0
    for index, line_dict in enumerate(contexts[::-1]):
        char_count += len(line_dict["message"])
        if char_count >= 1024:
            cur_index = len(contexts) - index - 1
            break
    conversation_dataline = preprocessing(merged_json=contexts[cur_index:])[0]
    query = ''
    message_doctor = []
    for idx, datalines in enumerate(conversation_dataline["conversation"]):
        if idx != len(conversation_dataline["conversation"]) - 1:
            if "human" in datalines:
                human = datalines["human"]
                query += f"<|im_start|>user\n{human}<|im_end|>\n"
            if "assistant" in datalines:
                assistant = datalines["assistant"]
                message_doctor.append(assistant)
                query += f"<|im_start|>assistant\n{assistant}<|im_end|>\n"
            if "system" in datalines:
                # Make sure the system message ends with a full stop.
                system = datalines["system"] + "。" if not datalines["system"].endswith("。") else datalines["system"]
                query += f"<|im_start|>system\n{system}<|im_end|>\n"
        else:
            # Last turn: leave the assistant turn open so the model continues it.
            if "assistant" in datalines:
                assistant = datalines["assistant"]
                message_doctor.append(assistant)
                query += f"<|im_start|>assistant\n{assistant}\n"
            else:
                human = datalines["human"]
                query += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n"
    return query, "\n".join(message_doctor)
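As a concrete example, the three-message merged_json sample from the "Input data format" section at the end of this article produces the following ChatML-style prompt. Note the trailing 。 appended to the system message, and that the final assistant turn is left unclosed so the model continues it:

prompt, message_doctor = process_context_qwen(merged_json)
# prompt is:
#   <|im_start|>system
#   时间:上午8点,性别:女。<|im_end|>
#   <|im_start|>user
#   您好<|im_end|>
#   <|im_start|>assistant
#   您需要什么吗
# message_doctor is "您需要什么吗"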

Prompt preprocessing function

def preprocessing(merged_json):
    assistant_prefix = "助手"  # role label for the assistant in the raw data
    patient_prefix = "用户"    # role label for the user in the raw data
    system_prefix = "system"
    conversations_datalines = []
    conversations_id = 1
    conversations = []
    content = ''
    prev_role = None
    for idx, sentence in enumerate(merged_json):
        cur_role = None
        if sentence["role"] == assistant_prefix:
            cur_role = "assistant"
        elif sentence["role"] == system_prefix:
            cur_role = "system"
        elif sentence["role"] == patient_prefix:
            cur_role = "human"
        # Flush the accumulated content whenever the speaker changes.
        if cur_role is not None and prev_role is not None and cur_role != prev_role:
            conversations.append({prev_role: content})
            content = ""
        content += "\n" + sentence["message"].strip() if content else sentence["message"].strip()
        # Flush the trailing content on the last message.
        if idx == len(merged_json) - 1:
            if conversations and ((cur_role is not None and list(conversations[-1].keys())[0] == cur_role) or cur_role is None):
                conversations[-1][list(conversations[-1].keys())[0]] += '\n' + content
            else:
                conversations.append({cur_role: content})
        if cur_role is not None:
            prev_role = cur_role
    if conversations:
        conversations_datalines.append({
            "conversations_id": conversations_id,
            "category": "qwen",
            "conversation": conversations,
            "dataset": "yyds"
        })
        conversations_id += 1
    return conversations_datalines
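For the same sample input, preprocessing returns a single dataline; consecutive messages from the same speaker would be merged into one turn:

print(preprocessing(merged_json))
# [{"conversations_id": 1,
#   "category": "qwen",
#   "conversation": [{"system": "时间:上午8点,性别:女"},
#                    {"human": "您好"},
#                    {"assistant": "您需要什么吗"}],
#   "dataset": "yyds"}]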

Main function

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=12356)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    engine_args = AsyncEngineArgs.from_cli_args(args)
    # engine is a module-level global, consumed by the generate handler above.
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

Terminal command

Combine the code above, in order, into a single .py file (referred to as xxx.py below), then start the server:

python xxx.py --model your_model_path --tensor-parallel-size 4 --gpu-memory-utilization 0.95 --trust-remote-code --dtype half
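Here --tensor-parallel-size 4 should match the number of GPUs exposed through CUDA_VISIBLE_DEVICES at the top of the script; if you change one, adjust the other accordingly.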

Input data format

merged_json = [
    {
        "role": "system",
        "message": "时间:上午8点,性别:女"  # system context: "time: 8 a.m., gender: female"
    },
    {
        "role": "用户",  # "user"; must match patient_prefix in preprocessing()
        "message": "您好"  # "Hello"
    },
    {
        "role": "助手",  # "assistant"; must match assistant_prefix in preprocessing()
        "message": "您需要什么吗"  # "Do you need anything?"
    }
]
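To tie everything together, here is a minimal hypothetical client that wraps this sample context in the request body the generate handler expects; the host and port are assumptions matching the defaults above:

import requests

merged_json = [
    {"role": "system", "message": "时间:上午8点,性别:女"},
    {"role": "用户", "message": "您好"},
    {"role": "助手", "message": "您需要什么吗"},
]

# POST the context to the non-streaming endpoint and print the JSON reply.
resp = requests.post(
    "http://localhost:12356/generate",  # assumed host; port matches the --port default
    json={"salt_uuid": "demo-001", "stream": False, "data": {"context": merged_json}},
)
print(resp.json())  # e.g. {"data": {"text": [...]}, "code": 5200, "message": "success", "salt_uuid": "demo-001"}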
