In this post, we will walk through a chat completion API built with FastAPI. The code implements a POST endpoint that accepts chat messages and generates a response. We will go through the code section by section, supplying the necessary background knowledge and example code along the way.
FastAPI is a modern, fast (high-performance) web framework for building APIs with Python 3.6+. It uses type hints to generate documentation and validate request data automatically.
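As a quick, self-contained illustration (a minimal sketch, separate from the API analyzed in this post):

from fastapi import FastAPI

app = FastAPI()

@app.get("/ping")
async def ping():
    # FastAPI serializes the returned dict to JSON automatically.
    return {"message": "pong"}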
Pydantic is a library for data validation and settings management, commonly used together with FastAPI. It defines data models through Python type hints and automatically validates the type and format of input data.
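For example, a minimal sketch of Pydantic validation (illustrative only):

from pydantic import BaseModel, ValidationError

class User(BaseModel):
    name: str
    age: int

print(User(name="Ada", age=36))  # valid input -> validated instance

try:
    User(name="Ada", age="not a number")  # wrong type for age
except ValidationError as e:
    print(e)  # Pydantic reports exactly which field failed and why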
PyTorch is an open-source deep learning framework widely used in both research and production. It provides flexible tensor computation and automatic differentiation.
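A tiny sketch of both features (not used directly in the API code below, shown only for background):

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()  # y = x1^2 + x2^2
y.backward()        # autograd computes dy/dx
print(x.grad)       # tensor([4., 6.])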
Below is the complete code; we will then explain it section by section.
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = await anext(predict_stream_generator)
        if output:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
        logger.debug(f"First result output:\n{output}")
        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, request.tools, use_tool=True)
            except:
                logger.warning("Failed to parse tool call")
        if isinstance(function_call, dict):
            function_call = ChoiceDeltaToolCallFunction(**function_call)
            generate = parse_output_text(request.model, output, function_call=function_call)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

    response = ""
    async for response in generate_stream_glm4(gen_params):
        pass

    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    function_call, finish_reason = None, "stop"
    tool_calls = None
    if request.tools:
        try:
            function_call = process_response(response["text"], request.tools, use_tool=True)
        except Exception as e:
            logger.warning(f"Failed to parse tool call: {e}")
    if isinstance(function_call, dict):
        finish_reason = "tool_calls"
        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
        function_call_instance = FunctionCall(
            name=function_call_response.name,
            arguments=function_call_response.arguments
        )
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=generate_id('call_', 24),
                function=function_call_instance,
                type="function")]

    message = ChatMessage(
        role="assistant",
        content=None if tool_calls else response["text"],
        function_call=None,
        tool_calls=tool_calls,
    )
    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
This decorator defines a POST endpoint at /v1/chat/completions and declares ChatCompletionResponse as its response model. The create_chat_completion function handles the incoming ChatCompletionRequest.
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
    raise HTTPException(status_code=400, detail="Invalid request")
Here the request is validated against two conditions: the messages list must not be empty, and the last message must not already come from the assistant (the model should respond to a user or tool message, not to its own output). If either check fails, an HTTP 400 error is raised.
gen_params = dict(
    messages=request.messages,
    temperature=request.temperature,
    top_p=request.top_p,
    max_tokens=request.max_tokens or 1024,
    echo=False,
    stream=request.stream,
    repetition_penalty=request.repetition_penalty,
    tools=request.tools,
    tool_choice=request.tool_choice,
)
logger.debug(f"==== request ====\n{gen_params}")
The request parameters are gathered into a dictionary, gen_params, which drives the subsequent generation step; note that max_tokens falls back to 1024 when the client omits it, and echo is always disabled. The dictionary is also logged for debugging.
if request.stream:
    predict_stream_generator = predict_stream(request.model, gen_params)
    output = await anext(predict_stream_generator)
    if output:
        return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
    logger.debug(f"First result output:\n{output}")
If the request asks for a streaming response, predict_stream builds an async generator and anext awaits its first item. When that first item is non-empty, the rest of the generator is returned immediately as a server-sent-events stream via EventSourceResponse; otherwise execution continues below to check whether the output is a tool call.
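EventSourceResponse comes from the sse-starlette package; here is a minimal sketch of how it turns an async generator into a server-sent-events stream. The /stream endpoint and echo_stream generator are illustrative, not part of the original code:

import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

async def echo_stream(text: str):
    # Yield one chunk at a time, the way a streaming model server would.
    for word in text.split():
        await asyncio.sleep(0.1)  # simulate generation latency
        yield word

@app.get("/stream")
async def stream():
    # Each yielded value becomes one server-sent event on the wire.
    return EventSourceResponse(echo_stream("hello streaming world"), media_type="text/event-stream")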
function_call = None
if output and request.tools:
    try:
        function_call = process_response(output, request.tools, use_tool=True)
    except:
        logger.warning("Failed to parse tool call")
if isinstance(function_call, dict):
    function_call = ChoiceDeltaToolCallFunction(**function_call)
    generate = parse_output_text(request.model, output, function_call=function_call)
    return EventSourceResponse(generate, media_type="text/event-stream")
else:
    return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
If the first output is non-empty and the request declares tools, the code attempts to parse a tool call from it. On success, the parsed dict is wrapped in ChoiceDeltaToolCallFunction and re-emitted through parse_output_text as a new event stream; otherwise the original stream generator is returned unchanged.
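process_response is defined elsewhere in the project and its exact format is not shown here. As a rough, hypothetical sketch of what such a parser might do, assuming OpenAI-style tool definitions and a model that emits the function name on the first line followed by JSON arguments:

import json

def parse_tool_call(output: str, tools: list):
    """Hypothetical stand-in for process_response: return a dict when the
    text looks like a call to a declared tool, else the raw text."""
    lines = output.strip().split("\n", 1)
    tool_names = {t["function"]["name"] for t in tools}
    if len(lines) == 2 and lines[0].strip() in tool_names:
        try:
            args = json.loads(lines[1])  # arguments must parse as JSON
            return {"name": lines[0].strip(),
                    "arguments": json.dumps(args, ensure_ascii=False)}
        except json.JSONDecodeError:
            pass
    return output  # plain text: no tool call detected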
response = ""
async for response in generate_stream_glm4(gen_params):
pass
if response["text"].startswith("\n"):
response["text"] = response["text"][1:]
response["text"] = response["text"].strip()
If the request does not ask for streaming, generate_stream_glm4 is still an async generator, but the `async for ... pass` loop simply drains it, so that after the loop `response` holds the final, complete result. A leading newline is then removed and the text is stripped of surrounding whitespace.
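The `async for ... pass` idiom deserves a standalone demonstration: exhausting an async generator leaves the loop variable bound to the last yielded item. A self-contained sketch, where count_up stands in for generate_stream_glm4:

import asyncio

async def count_up(n: int):
    # Stand-in for generate_stream_glm4: yields progressively longer results.
    for i in range(1, n + 1):
        yield {"text": f"partial {i}", "usage": {"total_tokens": i}}

async def main():
    response = None
    # Drain the generator; the loop variable keeps the final yielded value.
    async for response in count_up(3):
        pass
    print(response)  # {'text': 'partial 3', 'usage': {'total_tokens': 3}}

asyncio.run(main())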
usage = UsageInfo()
function_call, finish_reason = None, "stop"
tool_calls = None
if request.tools:
    try:
        function_call = process_response(response["text"], request.tools, use_tool=True)
    except Exception as e:
        logger.warning(f"Failed to parse tool call: {e}")
if isinstance(function_call, dict):
    finish_reason = "tool_calls"
    function_call_response = ChoiceDeltaToolCallFunction(**function_call)
    function_call_instance = FunctionCall(
        name=function_call_response.name,
        arguments=function_call_response.arguments
    )
    tool_calls = [
        ChatCompletionMessageToolCall(
            id=generate_id('call_', 24),
            function=function_call_instance,
            type="function")]
After the response text is cleaned up, a UsageInfo instance is created and the text is checked for a tool call. If process_response returns a dict, finish_reason switches to "tool_calls" and the parsed call is packaged into a ChatCompletionMessageToolCall with a freshly generated id.
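generate_id is a project helper not shown in this excerpt; a plausible implementation, assuming it appends a random alphanumeric suffix of the requested length to the prefix:

import random
import string

def generate_id(prefix: str, k: int) -> str:
    # Hypothetical reimplementation: prefix + k random alphanumeric characters,
    # e.g. generate_id('call_', 24) -> 'call_kX9...'.
    suffix = "".join(random.choices(string.ascii_letters + string.digits, k=k))
    return prefix + suffix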
message = ChatMessage(
    role="assistant",
    content=None if tool_calls else response["text"],
    function_call=None,
    tool_calls=tool_calls,
)
A ChatMessage is then built from the generated text and any tool calls. Note that content is set to None whenever tool calls are present, mirroring the OpenAI response format.
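ChatMessage and the related models are Pydantic classes defined elsewhere in the project. A plausible minimal definition (assumed, not the project's actual code) in which content may legitimately be None:

from typing import List, Literal, Optional
from pydantic import BaseModel

class FunctionCall(BaseModel):
    name: str
    arguments: str  # JSON-encoded argument string

class ChatCompletionMessageToolCall(BaseModel):
    id: str
    function: FunctionCall
    type: Literal["function"]

class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str] = None  # None when the model makes a tool call
    function_call: Optional[FunctionCall] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None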
logger.debug(f"==== message ====\n{message}")
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason=finish_reason,
)
The generated ChatMessage is logged, and a ChatCompletionResponseChoice is created containing the choice index, the message itself, and the finish reason.
task_usage = UsageInfo.model_validate(response["usage"])
for usage_key, usage_value in task_usage.model_dump().items():
    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
The token-usage statistics are extracted from the response and accumulated onto the usage instance. UsageInfo.model_validate validates the raw usage dict and turns it into a model instance, whose fields are then added one by one.
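model_validate and model_dump are the Pydantic v2 replacements for parse_obj and dict. A self-contained sketch of the same accumulation pattern, with an assumed field layout for UsageInfo matching the OpenAI usage object:

from pydantic import BaseModel

class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

usage = UsageInfo()
raw = {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}

task_usage = UsageInfo.model_validate(raw)  # validate the raw dict into a model
for key, value in task_usage.model_dump().items():
    setattr(usage, key, getattr(usage, key) + value)

print(usage)  # prompt_tokens=12 completion_tokens=30 total_tokens=42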
return ChatCompletionResponse(
    model=request.model,
    choices=[choice_data],
    object="chat.completion",
    usage=usage
)
Finally, a ChatCompletionResponse is created and returned, carrying the model name, the list of choices, the object type "chat.completion", and the aggregated usage information.
In this post we have walked through a chat completion API built on FastAPI, explaining what each section of the code does and introducing the underlying concepts and libraries.
To aid understanding, here is a simplified version that implements a similar chat completion API:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

app = FastAPI()

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: float
    max_tokens: int

class ChatCompletionResponse(BaseModel):
    message: ChatMessage

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")
    # Simplified response generation logic
    response_text = "This is a response."
    response_message = ChatMessage(role="assistant", content=response_text)
    return ChatCompletionResponse(message=response_message)
This simplified code defines a basic chat completion endpoint that validates the request and returns a fixed response. It should make the structure of the full implementation easier to follow.
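The simplified endpoint can be exercised with FastAPI's built-in test client. This snippet assumes the example above is saved as example_api.py (a hypothetical module name):

from fastapi.testclient import TestClient

from example_api import app  # the simplified app above (assumed module name)

client = TestClient(app)

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 64,
}
resp = client.post("/v1/chat/completions", json=payload)
print(resp.status_code)  # 200
print(resp.json())  # {'message': {'role': 'assistant', 'content': 'This is a response.'}}

# A request whose last message comes from the assistant is rejected.
bad = {"messages": [{"role": "assistant", "content": "Hi"}],
       "temperature": 0.7, "max_tokens": 64}
print(client.post("/v1/chat/completions", json=bad).status_code)  # 400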