This article is based on the official ChatGLM code.
Clone the code:
git clone https://github.com/THUDM/ChatGLM2-6B.git
cd ChatGLM2-6B
Install the dependencies:
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
Because of network issues, it is best to download the model files from Hugging Face yourself and put all of them in the ChatGLM2-6B/model directory.
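If you prefer to script the download, here is a minimal sketch using the huggingface_hub library (an extra dependency, not in requirements.txt). THUDM/chatglm2-6b is the official model repo; local_dir="model" assumes you run this from inside the ChatGLM2-6B directory:

from huggingface_hub import snapshot_download

# Download every file of the chatglm2-6b repo into ./model
# (the local_dir argument requires a reasonably recent huggingface_hub)
snapshot_download(repo_id="THUDM/chatglm2-6b", local_dir="model")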
The code that serves the API:
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import EventSourceResponse
import uvicorn, json, datetime
import torch

app = FastAPI()

# Streaming inference
def predict_stream(tokenizer, prompt, history, max_length, top_p, temperature):
    for response, new_history in model.stream_chat(tokenizer, prompt, history=history,
                                                   max_length=max_length if max_length else 2048,
                                                   top_p=top_p if top_p else 0.7,
                                                   temperature=temperature if temperature else 0.95):
        now = datetime.datetime.now()
        time = now.strftime("%Y-%m-%d %H:%M:%S")
        yield json.dumps({
            'response': response,
            'history': new_history,
            'status': 200,
            'time': time,
            'sse_status': 1
        })
    log = "[" + time + "] " + "---message from streaming inference---" + "prompt:" + prompt + ", response:" + repr(response)
    print(log, flush=True)
    # After inference completes, send one final packet; sse_status=2 marks the end of the SSE stream
    yield json.dumps({
        'response': response,
        'history': new_history,
        'status': 200,
        'time': time,
        'sse_status': 2
    })
    torch_gc()

# Escape model output as HTML
def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # HTML-escape special characters inside a code block
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text

# Reclaim GPU memory
def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

# SSE streaming endpoint
@app.post("/chatglm/server/text2text/sse")
async def create_item_sse(request: Request):
    json_post = await request.json()
    prompt = json_post.get('prompt')
    history = json_post.get('history')
    max_length = json_post.get('max_length')
    top_p = json_post.get('top_p')
    temperature = json_post.get('temperature')
    res = predict_stream(tokenizer, prompt, history, max_length, top_p, temperature)
    return EventSourceResponse(res)

if __name__ == '__main__':
    # CPU or GPU inference; GPU is strongly recommended, CPU is far too slow
    DEVICE = "cuda"
    DEVICE_ID = "0"
    CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    model = AutoModel.from_pretrained("model", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
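To check the stream end to end, here is a minimal client sketch. The URL, port, and payload fields mirror the server above; the requests package is an assumed extra dependency:

import json
import requests

payload = {
    "prompt": "你好",
    "history": [],
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}
with requests.post("http://127.0.0.1:8000/chatglm/server/text2text/sse",
                   json=payload, stream=True) as r:
    r.encoding = "utf-8"
    for line in r.iter_lines(decode_unicode=True):
        # sse-starlette emits frames as lines of the form "data: {...}"
        if line and line.startswith("data:"):
            packet = json.loads(line[len("data:"):])
            print(packet["response"])
            if packet.get("sse_status") == 2:  # final packet, stream is done
                break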