
Deploying ChatGLM2-6B Locally with a Streaming SSE API

The steps below are based on the official ChatGLM2-6B repository.

Clone the code:
git clone https://github.com/THUDM/ChatGLM2-6B.git
cd ChatGLM2-6B
Install the dependencies:
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
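
Note that the server code below also imports sse_starlette, which may not be covered by requirements.txt; if it is missing from your environment, install it as well:

pip install sse-starlette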

Because of network restrictions, it is best to download the model files from Hugging Face yourself and put all of them in the ChatGLM2-6B/model directory.
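
One way to fetch the weights, for example, is with git-lfs (a sketch assuming git-lfs is installed; the files can also be downloaded manually from https://huggingface.co/THUDM/chatglm2-6b). Run this from inside the ChatGLM2-6B directory:

git lfs install
git clone https://huggingface.co/THUDM/chatglm2-6b model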

The code that serves the API:

from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


# Streaming inference: yield partial responses as the model generates them
def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
    for response, new_history in model.stream_chat(tokenizer,
                                                   prompt,
                                                   history=history,
                                                   max_length=max_length if max_length else 2048,
                                                   top_p=top_p if top_p else 0.7,
                                                   temperature=temperature if temperature else 0.95):
        now = datetime.datetime.now()
        time = now.strftime("%Y-%m-%d %H:%M:%S")
        yield json.dumps({
            'response': response,
            'history': new_history,
            'status': 200,
            'time': time,
            'sse_status': 1
        })
    log = "[" + time + "] " + "---来自流式推理的消息---" + "prompt:" + prompt + ", response:" + repr(response)
    print(log, flush=True)
    # 推理完成后,发送最后一包数据,sse_statu=2标识sse结束
    yield json.dumps({
        'response': response,
        'history': new_history,
        'status': 200,
        'time': time,
        'sse_status': 2
    })
    torch_gc()


# Convert model output to HTML, escaping special characters inside code blocks
def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


# Release cached GPU memory after inference
def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


# SSE streaming endpoint
@app.post("/chatglm/server/text2text/sse")
async def create_item_sse(request: Request):
    json_post_list = await request.json()
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
    return EventSourceResponse(res)


if __name__ == '__main__':
    # CPU/GPU inference; GPU is strongly recommended, since CPU inference is extremely slow
    DEVICE = "cuda"
    DEVICE_ID = "0"
    CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    model = AutoModel.from_pretrained("model", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
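
With the server running (for example, save the script as api.py — the file name is arbitrary — and start it with python api.py), the stream can be consumed by POSTing a JSON body and reading the SSE frames line by line. A minimal client sketch using requests (not part of the original article):

import json
import requests

# Hypothetical client for the SSE endpoint defined above; adjust host/port to your deployment.
url = "http://127.0.0.1:8000/chatglm/server/text2text/sse"
payload = {"prompt": "你好", "history": []}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # sse_starlette frames each message as a "data: <json>" line
        if not line or not line.startswith("data:"):
            continue
        message = json.loads(line[len("data:"):].strip())
        print(message["response"], flush=True)
        if message.get("sse_status") == 2:  # the final packet marks the end of the stream
            break

Each frame carries the full response so far plus the updated history; the sse_status field distinguishes intermediate packets (1) from the final one (2).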
