当前位置:   article > 正文

使用 FastAPI 实现聊天完成 API 详解_社区聊天 api

社区聊天 api

简介

在这篇博客中,我们将详细解释一段使用 FastAPI 构建的聊天完成 API 代码。这段代码实现了一个 POST 请求的 API 端点,用于处理聊天消息并生成响应。我们将逐行解析代码,并提供必要的背景知识和示例代码。

基础概念

FastAPI

FastAPI 是一个用于构建 API 的现代、快速(高性能)的 Web 框架,基于 Python 3.6+。它使用类型提示来自动生成文档和验证请求数据。

Pydantic

Pydantic 是一个用于数据验证和设置管理的库,常与 FastAPI 一起使用。它通过 Python 的类型提示来定义数据模型,并自动验证输入数据的类型和格式。

PyTorch

PyTorch 是一个开源的深度学习框架,广泛用于研究和生产环境。它提供了灵活的张量计算和自动求导功能。

代码详解

下面是完整的代码,我们将逐段进行解释。

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = await anext(predict_stream_generator)
        if output:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
        logger.debug(f"First result output:\n{output}")

        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, request.tools, use_tool=True)
            except:
                logger.warning("Failed to parse tool call")

        if isinstance(function_call, dict):
            function_call = ChoiceDeltaToolCallFunction(**function_call)
            generate = parse_output_text(request.model, output, function_call=function_call)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
    response = ""
    async for response in generate_stream_glm4(gen_params):
        pass

    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()

    function_call, finish_reason = None, "stop"
    tool_calls = None
    if request.tools:
        try:
            function_call = process_response(response["text"], request.tools, use_tool=True)
        except Exception as e:
            logger.warning(f"Failed to parse tool call: {e}")
    if isinstance(function_call, dict):
        finish_reason = "tool_calls"
        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
        function_call_instance = FunctionCall(
            name=function_call_response.name,
            arguments=function_call_response.arguments
        )
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=generate_id('call_', 24),
                function=function_call_instance,
                type="function")]

    message = ChatMessage(
        role="assistant",
        content=None if tool_calls else response["text"],
        function_call=None,
        tool_calls=tool_calls,
    )

    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92

1. 定义 API 端点

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
  • 1
  • 2

这行代码定义了一个 POST 请求的 API 端点 /v1/chat/completions,并指定了请求的响应模型为 ChatCompletionResponsecreate_chat_completion 函数将处理传入的 ChatCompletionRequest 请求。

2. 请求验证

    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")
  • 1
  • 2

这里我们进行请求验证:

  • 检查 messages 列表的长度是否小于 1。
  • 检查最后一条消息的角色是否为 “assistant”。

如果以上任一条件为真,则抛出 HTTP 400 错误。

3. 生成参数字典

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    logger.debug(f"==== request ====\n{gen_params}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

将请求中的参数转化为一个字典 gen_params,用于后续的生成操作。同时,记录调试信息。

4. 处理流式响应

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = await anext(predict_stream_generator)
        if output:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
        logger.debug(f"First result output:\n{output}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

如果请求中指定了流式响应,则调用 predict_stream 函数生成流式响应生成器,并返回 EventSourceResponse。如果第一个输出存在,则直接返回生成器作为事件流响应。

5. 工具调用处理

        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, request.tools, use_tool=True)
            except:
                logger.warning("Failed to parse tool call")

        if isinstance(function_call, dict):
            function_call = ChoiceDeltaToolCallFunction(**function_call)
            generate = parse_output_text(request.model, output, function_call=function_call)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

如果输出存在且请求包含工具调用,则尝试解析工具调用。如果解析成功,则处理工具调用并生成新的事件流响应;否则,继续返回原始的事件流生成器。

6. 非流式响应处理

    response = ""
    async for response in generate_stream_glm4(gen_params):
        pass

    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

如果请求未指定流式响应,则调用 generate_stream_glm4 生成响应。在生成响应后,去掉开头的换行符并修剪两端空白。

7. 处理使用信息和工具调用

    usage = UsageInfo()

    function_call, finish_reason = None, "stop"
    tool_calls = None
    if request.tools:
        try:
            function_call = process_response(response["text"], request.tools, use_tool=True)
        except Exception as e:
            logger.warning(f"Failed to parse tool call: {e}")
    if isinstance(function_call, dict):
        finish_reason = "tool_calls"
        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
        function_call_instance = FunctionCall(
            name=function_call_response.name,
            arguments=function_call_response.arguments
        )
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=generate_id('call_', 24),
                function=function_call_instance,
                type="function")]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

在处理响应后,创建 UsageInfo 实例并检查是否有工具调用。如果有工具调用,则解析并生成工具调用响应。

8. 构建聊天消息

    message = ChatMessage(
        role="assistant",
        content=None if tool_calls else response["text"],
        function_call=None,
        tool_calls=tool_calls,
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

根据生成的响应和工具调用信息,创建一个 ChatMessage 实例。

9. 构建响应选择

    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

这段代码将生成的 ChatMessage 实例记录到日志中,并且创建一个 ChatCompletionResponseChoice 实例,其中包含了消息的索引、消息内容和完成原因。

10. 更新使用信息

    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
  • 1
  • 2
  • 3

从响应中提取使用信息,并将其添加到 usage 实例中。UsageInfo.model_validate 方法用于验证并创建一个包含使用信息的实例。

11. 返回最终响应

    return ChatCompletionResponse(
        model=request.model,
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

最后,创建并返回一个 ChatCompletionResponse 实例,其中包含了模型名称、选项列表和使用信息。

总结

通过这篇博客,我们详细解析了一个基于 FastAPI 实现的聊天完成 API 的代码。我们逐行解释了代码的功能,并介绍了相关的基础概念和库。

示例代码

为了帮助理解,我们提供一个简化版的示例代码,用于实现类似的聊天完成 API:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

app = FastAPI()

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: float
    max_tokens: int

class ChatCompletionResponse(BaseModel):
    message: ChatMessage

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    # Simplified response generation logic
    response_text = "This is a response."
    response_message = ChatMessage(role="assistant", content=response_text)
    
    return ChatCompletionResponse(message=response_message)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

这段简化代码定义了一个基本的聊天完成 API 端点,处理请求并返回简单的响应。通过这个示例,可以更好地理解完整代码的工作原理。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/816942
推荐阅读
相关标签
  

闽ICP备14008679号