当前位置:   article > 正文

ChatGLM模型通过api方式调用响应时间慢怎么破,Fastapi流式接口来解惑,能快速提升响应速度_chatglm api

chatglm api

ChatGLM-6B流式HTTP API

本工程仿造OpneAI Chat Completion API(即GPT3.5 API)的实现,为ChatGLM-6B提供流式HTTP API。



前言

现在市面上好多教chatglm-6b本地化部署,命令行部署,webui部署的,但是api部署的方式企业用的很多,官方给的api没有直接支持流式接口,调用起来时间响应很慢,这次给大家讲一下流式服务接口如何写,大大提升响应速度


一、下载代码安装环境

依赖环境
实际版本以ChatGLM-6B官方为准。但是这里需要提醒一下:
官方更新stream_chat方法后,已不能使用4.25.1的transformers包,故transformers==4.27.1。
cpm_kernel需要本机安装CUDA(若torch中的CUDA使用集成方式,如各种一键安装包时,需要注意这点。)
为了获得更好的性能,建议使用CUDA11.6或11.7配合PyTorch 1.13和torchvision 0.14.1。

我使用的是3090显卡
python版本是3.9

先安装这个protobuf>=3.18,<3.20.1

transformers==4.27.1 transformers版本必须是4.27.1要不然会报错
torch安装命令用conda方式,不要使用pip要不然cpm_kernels会报错
安装命令
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
torch==1.12.1+cu113
torchvision==0.13.1
安装完以上的环境再安装下面的,保证万无一失
icetk
cpm_kernels
uvicorn==0.18.1必须这个版本,不然会报错
fastapi

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

二、接口服务脚本代码

from fastapi import FastAPI, Request
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
from transformers import AutoTokenizer, AutoModel
import argparse
import logging
import os
import json
import sys

def getLogger(name, file_name, use_formatter=True):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    console_handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('%(asctime)s    %(message)s')
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging.INFO)
    logger.addHandler(console_handler)
    if file_name:
        handler = logging.FileHandler(file_name, encoding='utf8')
        handler.setLevel(logging.INFO)
        if use_formatter:
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
            handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

logger = getLogger('ChatGLM', 'chatlog.log')

MAX_HISTORY = 5

class ChatGLM():
    def __init__(self, quantize_level, gpu_id) -> None:
        logger.info("Start initialize model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm-6b", trust_remote_code=True)
        self.model = self._model(quantize_level, gpu_id)
        self.model.eval()
        _, _ = self.model.chat(self.tokenizer, "你好", history=[])
        logger.info("Model initialization finished.")
    
    def _model(self, quantize_level, gpu_id):
        model_name = "THUDM/chatglm-6b"
        quantize = int(args.quantize)
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        model = None
        if gpu_id == '-1':
            if quantize == 8:
                print('CPU模式下量化等级只能是16或4,使用4')
                model_name = "THUDM/chatglm-6b-int4"
            elif quantize == 4:
                model_name = "THUDM/chatglm-6b-int4"
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float()
        else:
            gpu_ids = gpu_id.split(",")
            self.devices = ["cuda:{}".format(id) for id in gpu_ids]
            if quantize == 16:
                model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda()
            else:
                model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().quantize(quantize).cuda()
        return model
    
    def clear(self) -> None:
        if torch.cuda.is_available():
            for device in self.devices:
                with torch.cuda.device(device):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
    
    def answer(self, query: str, history):
        response, history = self.model.chat(self.tokenizer, query, history=history)
        history = [list(h) for h in history]
        return response, history

    def stream(self, query, history):
        if query is None or history is None:
            yield {"query": "", "response": "", "history": [], "finished": True}
        size = 0
        response = ""
        for response, history in self.model.stream_chat(self.tokenizer, query, history):
            this_response = response[size:]
            history = [list(h) for h in history]
            size = len(response)
            yield {"delta": this_response, "response": response, "finished": False}
        logger.info("Answer - {}".format(response))
        yield {"query": query, "delta": "[EOS]", "response": response, "history": history, "finished": True}


def start_server(quantize_level, http_address: str, port: int, gpu_id: str):
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

    bot = ChatGLM(quantize_level, gpu_id)
    
    app = FastAPI()
    app.add_middleware( CORSMiddleware,
        allow_origins = ["*"],
        allow_credentials = True,
        allow_methods=["*"],
        allow_headers=["*"]
    )
    
    @app.get("/")
    def index():
        return {'message': 'started', 'success': True}

    @app.post("/chat")
    async def answer_question(arg_dict: dict):
        result = {"query": "", "response": "", "success": False}
        try:
            text = arg_dict["query"]
            ori_history = arg_dict["history"]
            logger.info("Query - {}".format(text))
            if len(ori_history) > 0:
                logger.info("History - {}".format(ori_history))
            history = ori_history[-MAX_HISTORY:]
            history = [tuple(h) for h in history] 
            response, history = bot.answer(text, history)
            logger.info("Answer - {}".format(response))
            ori_history.append((text, response))
            result = {"query": text, "response": response,
                      "history": ori_history, "success": True}
        except Exception as e:
            logger.error(f"error: {e}")
        return result

    @app.post("/stream")
    def answer_question_stream(arg_dict: dict):
        def decorate(generator):
            for item in generator:
                yield ServerSentEvent(json.dumps(item, ensure_ascii=False), event='delta')
        result = {"query": "", "response": "", "success": False}
        try:
            text = arg_dict["query"]
            ori_history = arg_dict["history"]
            logger.info("Query - {}".format(text))
            if len(ori_history) > 0:
                logger.info("History - {}".format(ori_history))
            history = ori_history[-MAX_HISTORY:]
            history = [tuple(h) for h in history]
            return EventSourceResponse(decorate(bot.stream(text, history)))
        except Exception as e:
            logger.error(f"error: {e}")
            return EventSourceResponse(decorate(bot.stream(None, None)))

    @app.get("/clear")
    def clear():
        history = []
        try:
            bot.clear()
            return {"success": True}
        except Exception as e:
            return {"success": False}

    @app.get("/score")
    def score_answer(score: int):
        logger.info("score: {}".format(score))
        return {'success': True}

    logger.info("starting server...")
    uvicorn.run(app=app, host=http_address, port=port, debug = False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Stream API Service for ChatGLM-6B')
    parser.add_argument('--device', '-d', help='device,-1 means cpu, other means gpu ids', default='0')
    parser.add_argument('--quantize', '-q', help='level of quantize, option:16, 8 or 4', default=16)
    parser.add_argument('--host', '-H', help='host to listen', default='0.0.0.0')
    parser.add_argument('--port', '-P', help='port of this service', default=8800)
    args = parser.parse_args()
    start_server(args.quantize, args.host, int(args.port), args.device)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173

三、运行启动命令

python3 -u chatglm_service_fastapi.py --host 127.0.0.1 --port 8800 --quantize 8 --device 0
参数中,--device 为 -1 表示 cpu,其他数字i表示第i张卡。
根据自己的显卡配置来决定参数,--quantize 16 需要12g显存,显存小的话可以切换到4或者8

  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

接口请求方式
流式接口,使用server-sent events技术。

接口URL: http://{host_name}/stream

请求方式:POST(JSON body)

返回方式:

使用Event Stream格式,返回服务端事件流,
事件名称:delta
数据类型:JSON
返回结果:

字段名	类型	说明
delta	string	产生的字符
query	string	用户问题,为省流,finished为true时返回
response	string	目前为止的回复,finished为true时,为完整的回复
history	array[string]	会话历史,为省流,finished为true时返回
finished	boolean	true 表示结束,false 表示仍然有数据流。

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

在这里插入图片描述

curl  调用方式
curl --location --request POST 'http://hostname:8800/stream' \
--header 'Host: localhost:8001' \
--header 'User-Agent: python-requests/2.24.0' \
--header 'Accept: */*' \
--header 'Content-Type: application/json' \
--data-raw '{"query": "给我写个广告" ,"history": [] }'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

总结

以上就是今天要讲的内容,更多大语言模型知识关注微信公众号:CV算法小屋
加我微信:Lh1141755859,加交流群,备注:进群

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/260448
推荐阅读
相关标签
  

闽ICP备14008679号