vLLM + FastAPI 推理服务示例(含上下文拼接与 Qwen 对话格式处理)
import os

# Must be set before vLLM / CUDA initialisation reads the device list.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import argparse
import json
from typing import AsyncGenerator

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
# BUG FIX: the module is `vllm.sampling_params`; the original imported the
# misspelled `vllm.sampling_parms`, which raises ImportError at startup.
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds
# Fixed typo (was TIMEOUT_TO_PEVENT_DEADLOCK); currently unused in this file.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

app = FastAPI()
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate a completion for the request.

    The request body is a JSON object with the fields:
    - data.context: list of chat turns passed to process_context_qwen.
    - salt_uuid: opaque id echoed back in the response (optional).
    - stream: whether to stream results (optional, default False).

    Returns a StreamingResponse when streaming, otherwise a JSONResponse with
    keys data/code/message/salt_uuid (code 5200 on success, 5201 on failure).
    """
    # BUG FIX: bind salt_uuid before the try-block so the except-branch can
    # always reference it (the original raised UnboundLocalError if parsing
    # failed before salt_uuid was assigned).
    salt_uuid = "null"
    try:
        request_dict = await request.json()
        contexts = request_dict.get("data", {}).get("context")
        salt_uuid = request_dict.pop("salt_uuid", "null")
        prompt, message_doctor = process_context_qwen(contexts)
        # BUG FIX: was misspelled `stgream`, while the code below tested
        # `stream` — every request crashed with NameError.
        stream = request_dict.pop("stream", False)
        sampling_params = SamplingParams(
            n=1,
            temperature=0,
            best_of=5,
            top_p=1.0,
            top_k=-1,
            use_beam_search=True,
            max_tokens=128,
        )

        # BUG FIX: one id (`request_id`) is used consistently; the original
        # generated `request_uuid` but aborted an undefined `request_id`.
        request_id = random_uuid()
        results_generator = engine.generate(prompt, sampling_params, request_id)

        # Streaming case: yield NUL-terminated JSON chunks as tokens arrive.
        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                out_prompt = request_output.prompt
                text_outputs = [
                    out_prompt + output.text for output in request_output.outputs
                ]
                yield (json.dumps({"text": text_outputs}) + "\0").encode("utf-8")

        async def abort_request() -> None:
            await engine.abort(request_id)

        if stream:
            background_tasks = BackgroundTasks()
            # Abort the request if the client disconnects.
            background_tasks.add_task(abort_request)
            return StreamingResponse(stream_results(), background=background_tasks)

        # Non-streaming case: drain the generator, keeping only the last output.
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the engine request if the client disconnected.
                await engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output

        assert final_output is not None
        text_outputs = [output.text for output in final_output.outputs]
        print(f"output:{final_output.outputs[0].text}")
        ret = {"data": {"text": text_outputs}, "code": 5200, "message": "调试成功", "salt_uuid": salt_uuid}
    except Exception as e:
        ret = {"data": {"text": ""}, "code": 5201, "message": f"调用失败\n错误信息: {e}, ", "salt_uuid": salt_uuid}
    return JSONResponse(ret)
def process_context_qwen(contexts):
    """Build a Qwen ChatML prompt from raw context turns.

    Walks the context backwards, keeping roughly the most recent 1024
    characters of messages, groups them via preprocessing(), then renders
    the turns in Qwen's <|im_start|>/<|im_end|> chat format.

    Returns a tuple (prompt, doctor_messages) where doctor_messages is all
    assistant messages joined with newlines.
    """
    cur_index = 0
    char_count = 0
    # Find the earliest turn to keep: accumulate message lengths from the
    # newest turn backwards until the 1024-character budget is exceeded.
    for index, line_dict in enumerate(contexts[::-1]):
        char_count += len(line_dict["message"])
        if char_count >= 1024:
            cur_index = len(contexts) - index - 1
            break

    # BUG FIX: the original misspelled the variable (`converstaions_dataline`,
    # NameError) and read key "conversations" although preprocessing() emits
    # "conversation" (KeyError). Both are unified here.
    conversations_dataline = preprocessing(merged_json=contexts[cur_index:])[0]
    turns = conversations_dataline["conversation"]

    query = ''
    message_doctor = []
    for idx, datalines in enumerate(turns):
        if idx != len(turns) - 1:
            if "human" in datalines:
                human = datalines["human"]
                query += f"<|im_start|>user\n{human}<|im_end|>\n"
            if "assistant" in datalines:
                assistant = datalines["assistant"]
                message_doctor.append(assistant)
                query += f"<|im_start|>assistant\n{assistant}<|im_end|>\n"
            if "system" in datalines:
                # Ensure the system text ends with a full stop.
                system = datalines["system"] + "。" if not datalines["system"].endswith("。") else datalines["system"]
                query += f"<|im_start|>system\n{system}<|im_end|>\n"
        else:
            # Last turn: if it is an assistant turn, leave it open-ended;
            # otherwise emit the user turn and open the assistant turn so
            # the model continues from there.
            if "assistant" in datalines:
                assistant = datalines["assistant"]
                message_doctor.append(assistant)
                query += f"<|im_start|>assistant\n{assistant}\n"
            else:
                # NOTE(review): assumes the last non-assistant turn is always
                # "human" — a trailing system turn would raise KeyError; confirm
                # against callers.
                human = datalines["human"]
                query += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n"
    return query, "\n".join(message_doctor)
def preprocessing(merged_json):
    """Group consecutive same-role messages into a conversation dataline.

    Each entry of merged_json is a dict with "role" ("助手" = assistant,
    "用户" = user/human, "system") and "message". Consecutive messages by the
    same role are merged with newlines; unknown roles are folded into the
    current run. Returns a list with a single dataline dict (keys:
    conversations_id, category, conversation, dataset), or an empty list for
    empty input.
    """
    assistant_prefix = "助手"
    patient_prefix = "用户"
    system_prefix = "system"
    conversations_datalines = []
    conversations_id = 1
    conversations = []
    content = ''
    prev_role = None
    for idx, sentence in enumerate(merged_json):
        cur_role = None
        if sentence["role"] == assistant_prefix:
            cur_role = "assistant"
        elif sentence["role"] == system_prefix:
            cur_role = "system"
        elif sentence["role"] == patient_prefix:
            cur_role = "human"

        # On a role change, flush the accumulated run of the previous role.
        if cur_role is not None and prev_role is not None and cur_role != prev_role:
            conversations.append({prev_role: content})
            content = ""
        content += "\n" + sentence["message"].strip() if content else sentence["message"].strip()

        if idx == len(merged_json) - 1:
            # Flush the final run.
            if not conversations:
                # BUG FIX: the original indexed conversations[-1] here and
                # crashed with IndexError when no role change ever happened
                # (e.g. a single message). Fall back to any known role.
                conversations.append({cur_role or prev_role or "human": content})
            elif (cur_role is not None and list(conversations[-1].keys())[0] == cur_role) or cur_role is None:
                # Same role as the last flushed run (or unknown role): merge.
                conversations[-1][list(conversations[-1].keys())[0]] += '\n' + content
            else:
                conversations.append({cur_role: content})
        if cur_role is not None:
            prev_role = cur_role

    if conversations:
        conversations_datalines.append({
            "conversations_id": conversations_id,
            "category": "qwen",
            "conversation": conversations,
            "dataset": "yyds"
        })
        conversations_id += 1
    return conversations_datalines
if __name__ == "__main__":
    # CLI: HTTP host/port plus every vLLM engine argument (--model,
    # --tensor-parallel-size, ...) registered by AsyncEngineArgs.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=12356)
    cli = AsyncEngineArgs.add_cli_args(cli)
    args = cli.parse_args()

    # Build the async vLLM engine from the parsed CLI arguments; `engine`
    # is module-global and used by the /generate handler.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs.from_cli_args(args))

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
    )
将上述代码按顺序组合在一个 py 文件中(例如 server.py),然后运行(注意:`python -m xxx.py` 是错误写法,`-m` 接模块名且不带 `.py` 后缀,直接用脚本路径即可):
python server.py --model your_model_path --tensor-parallel-size 4 --gpu-memory-utilization 0.95 --trust-remote-code --dtype half
# Sample input for preprocessing(): each entry is one chat turn with a
# "role" ("system", "用户" = user, "助手" = assistant) and its "message".
merged_json = [
    {
        "role": "system",
        "message": "时间:上午8点,性别:女"
    },
    {
        "role": "用户",
        "message": "您好"
    },
    {
        "role": "助手",
        "message": "您需要什么吗"
    }
]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。