Based on LLaMA, this post builds a simple Python server that exposes the model as an API service.
Select the GPUs to use before launching:

export CUDA_VISIBLE_DEVICES=0,1,2...

Key points (see the code for details):

MultiProcessLlama is responsible for managing the worker threads/processes (one per model-parallel rank).

Start the server with: python <server-name>.py --llama-version 7B
Implemented APIs (an example request flow is sketched below):

/api/initchat : initialize a chat session; returns a uuid identifying it.
/api/chat : send a chat message within a session (requires the uuid from /api/initchat).
/api/reset : reset the chat state, keeping only the system prompt (requires the uuid from /api/initchat).
/api/chat_once : single-turn chat; equivalent to calling /api/chat and then /api/reset (requires the uuid from /api/initchat).
/api/close : close a chat session and discard its state on the server.
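For illustration, here is a minimal client sketch of the request flow using the requests library. It is not part of the original post; it assumes the server is running locally on port 8088 (the host and port hard-coded in the script below) and uses the form fields and response keys defined by the handlers.

```python
# Minimal sketch of the API flow, assuming the server runs at 127.0.0.1:8088.
import requests

BASE = "http://127.0.0.1:8088"

# 1. Create a session with a system prompt; the server returns a uuid.
uuid = requests.post(f"{BASE}/api/initchat",
                     data={"content": "You are a helpful assistant."}).json()["uuid"]

# 2. Multi-turn chat: the history is kept on the server under this uuid.
ans = requests.post(f"{BASE}/api/chat",
                    data={"uuid": uuid, "content": "Hello!"}).json()
print(ans["status"], ans["response"])

# 3. Single-turn chat: only the system prompt plus this message is used.
once = requests.post(f"{BASE}/api/chat_once",
                     data={"uuid": uuid, "content": "Tell me a joke."}).json()
print(once["response"])

# 4. Reset the history (keeps only the system prompt), then close the session.
requests.post(f"{BASE}/api/reset", data={"uuid": uuid})
requests.post(f"{BASE}/api/close", data={"uuid": uuid})
```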
Server implementation:

```python
import argparse
import multiprocessing
import os
import uuid as libuuid
from multiprocessing import Queue

import bottle
import torch
import torch.distributed
from bottle import request
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import Llama

_arg = argparse.ArgumentParser("Server LLAMA Chat Web")
_arg.add_argument("--llama-version", type=str, default="7B",
                  help="LLAMA version, available: [7B, 13B, 70B]")
args = _arg.parse_args()

if args.llama_version not in ["7B", "13B", "70B"]:
    raise ValueError("LLaMA version not found. Supported: 7B, 13B, 70B")


class MultiProcessLlama:
    """Runs one worker process per model-parallel rank and fans commands out to all of them."""

    def __init__(self, world_size):
        self.in_queue_list = [Queue(8) for _ in range(world_size)]
        self.out_queue = Queue(8)
        self.world_size = world_size
        self.process_list = []
        print("init Done")

    def chat_completion(self, *args, **kwargs):
        # Every rank must receive the same command so the model-parallel forward pass stays in sync.
        for ique in self.in_queue_list:
            # print("Call chat_completion")
            ique.put(("chat_completion", args, kwargs))
        out = self.out_queue.get()
        return out

    def start(self, *args, **kwargs):
        def __loop(rank: int, world_size: int, in_queue: Queue, out_queue: Queue, args, kwargs):
            if rank == 0:
                assert out_queue is not None
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "65288"
            # os.environ["RANK"] = str(rank)
            os.environ["LOCAL_RANK"] = str(rank)
            # os.environ["WORLD_SIZE"] = str(world_size)
            torch.distributed.init_process_group(
                "nccl",
                rank=rank,
                world_size=world_size,
            )
            initialize_model_parallel(world_size)
            generator = Llama.build(*args, **kwargs)
            while True:
                # print(f"[{rank}] in queue wait")
                cmd, args, kwargs = in_queue.get()
                # print(f"[{rank}] in queue get", cmd)
                out = None
                if cmd is None:
                    break
                if cmd == "chat_completion":
                    out = generator.chat_completion(*args, **kwargs)
                elif cmd == "text_completion":
                    out = generator.text_completion(*args, **kwargs)
                else:
                    print("Warning, unknown command", cmd)
                # All ranks produce the same response; only rank 0 writes it back.
                if rank == 0:
                    out_queue.put(out)
                # print(f"[{rank}] {cmd} {args}, {kwargs} => {out}")

        for i in range(self.world_size):
            pi = multiprocessing.Process(
                target=__loop,
                args=(
                    i,
                    self.world_size,
                    self.in_queue_list[i],
                    self.out_queue if i == 0 else None,
                    args,
                    kwargs,
                ),
            )
            self.process_list.append(pi)
            pi.start()

    def join(self):
        for _ in range(4):  # put the stop signal several times so no worker misses it
            for que in self.in_queue_list:
                que.put((None, None, None))
        for pi in self.process_list:
            pi.join()


# Maps a chat uuid to its message history (list of {"role", "content"} dicts).
chat_uuid_dict = dict()


def generate_chat(generator, chat_info):
    if generator is None:
        # Fallback when no model is loaded; shaped like a chat_completion result.
        return {"generation": {"role": "assistant", "content": "(ᗜ_ᗜ)"}}
    return generator.chat_completion(
        [chat_info],  # type: ignore
        max_gen_len=None,
        temperature=0.6,
        top_p=0.99,
    )[0]


Global_generator = None
app = bottle.Bottle()


@app.route("/")
def index():
    return bottle.template("./web/statics/index.html")


@app.route("/statics/<filename:path>")
def serve_static(filename):
    return bottle.static_file(filename, "web/statics")


@app.post("/api/close")
def api_close():
    uuid = request.forms.get("uuid", "")
    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    del chat_uuid_dict[uuid]
    return {"uuid": uuid, "status": 1}


@app.post("/api/chat")
def api_chat():
    uuid = request.forms.get("uuid", "")
    content = request.forms.get("content", "")
    if content == "":
        return {"uuid": uuid, "status": 1, "response": "(ᗜ_ᗜ)"}
    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    chat_hist = chat_uuid_dict[uuid]
    chat_hist.append({"role": "user", "content": content})
    result = generate_chat(Global_generator, chat_hist)
    answer = result["generation"]["content"]
    chat_hist.append({"role": "assistant", "content": answer})
    return {"uuid": uuid, "status": 1, "response": answer}


@app.post("/api/initchat")
def api_initchat():
    content = request.forms.get("content", "Feel free to answer the question.")
    while True:
        uuid = str(libuuid.uuid4())
        if uuid not in chat_uuid_dict:
            chat_uuid_dict[uuid] = [{"role": "system", "content": content}]
            break
    return {"uuid": uuid}


@app.post("/api/chat_once")
def api_chat_once():
    uuid = request.forms.get("uuid", "")
    content = request.forms.get("content", "")
    if content == "":
        return {"uuid": uuid, "status": 1, "response": "(ᗜ_ᗜ)"}
    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    # Use only the stored system prompt plus the new message; the history is left untouched.
    chat_hist = [chat_uuid_dict[uuid][0], {"role": "user", "content": content}]
    result = generate_chat(Global_generator, chat_hist)
    answer = result["generation"]["content"]
    return {"uuid": uuid, "status": 1, "response": answer}


@app.post("/api/reset")
def api_reset():
    uuid = request.forms.get("uuid", "")
    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    chat_hist = chat_uuid_dict[uuid]  # type: list
    init_msg = chat_hist[0]
    chat_hist.clear()
    chat_hist.append(init_msg)
    return {"uuid": uuid, "status": 1}


if __name__ == "__main__":
    print("System loading...")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if Global_generator is None:
        if args.llama_version == "7B":
            # 7B fits on a single GPU, so build the generator in-process.
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "65288"
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            Global_generator = Llama.build(
                ckpt_dir="./downloads/llama-2-7b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=9000,
                max_batch_size=1,
                model_parallel_size=1,
            )
        elif args.llama_version == "13B":
            Global_generator = MultiProcessLlama(2)
            Global_generator.start(
                ckpt_dir="./downloads/llama-2-13b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=2048,
                max_batch_size=1,
                model_parallel_size=2,
            )
        elif args.llama_version == "70B":
            print("Use torch run")
            Global_generator = MultiProcessLlama(8)
            Global_generator.start(
                ckpt_dir="./downloads/llama-2-70b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=2048,
                max_batch_size=1,
                model_parallel_size=8,
            )
    print("Init with", args.llama_version)
    app.run(host="0.0.0.0", port=8088, debug=False, reloader=False)
    if args.llama_version != "7B":
        try:
            Global_generator.join()
        except Exception as e:
            print(e)
```
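Note that the worker loop above already dispatches a "text_completion" command, but MultiProcessLlama only exposes chat_completion. As a hypothetical extension (not part of the original script), a matching wrapper could reuse the same fan-out pattern:

```python
# Hypothetical extension: expose text_completion through the same
# broadcast-to-all-ranks / read-from-rank-0 pattern used by chat_completion.
class MultiProcessLlamaWithText(MultiProcessLlama):
    def text_completion(self, *args, **kwargs):
        for ique in self.in_queue_list:
            ique.put(("text_completion", args, kwargs))
        return self.out_queue.get()
```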
A client-side utility class:
```python
import http

import requests


class ChatLlama:
    """Thin HTTP client for the chat server's /api endpoints."""

    def __init__(self, addr, content: str = ""):
        self.addr = addr
        self.chat_uuid = self.init_chat(content)

    def init_chat(self, content: str):
        resp = requests.post(f"{self.addr}/api/initchat", data={"content": content})
        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError(resp.status_code)
        chat_uuid = resp.json()["uuid"]
        resp.close()
        print("init UUID", chat_uuid)
        return chat_uuid

    def chat_request(self, context) -> str:
        # Multi-turn chat: the server appends to the stored history.
        resp = requests.post(f"{self.addr}/api/chat", data={
            "uuid": self.chat_uuid,
            "content": context,
        })
        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)
        ans = resp.json()
        if ans["status"] == 0:
            raise ValueError("UUID does not exist")
        return ans["response"]

    def chat_once(self, context) -> str:
        # Single-turn chat: only the system prompt plus this message is used.
        resp = requests.post(f"{self.addr}/api/chat_once", data={
            "uuid": self.chat_uuid,
            "content": context,
        })
        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)
        ans = resp.json()
        if ans["status"] == 0:
            raise ValueError("UUID does not exist")
        return ans["response"]

    def chat_reset(self) -> bool:
        # Keep only the system prompt on the server side.
        resp = requests.post(f"{self.addr}/api/reset", data={"uuid": self.chat_uuid})
        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)
        ans = resp.json()
        return ans["status"] == 1

    def close_chat(self):
        resp = requests.post(f"{self.addr}/api/close", data={"uuid": self.chat_uuid})
        resp.close()
```
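A usage sketch for this class, assuming the server above is reachable at http://127.0.0.1:8088 (the host and port it binds); the prompts are placeholders:

```python
# Example usage, assuming the server runs locally on port 8088.
if __name__ == "__main__":
    bot = ChatLlama("http://127.0.0.1:8088", content="You are a helpful assistant.")
    try:
        print(bot.chat_request("Hi, who are you?"))               # multi-turn chat
        print(bot.chat_once("Summarize LLaMA in one sentence."))  # single-turn, ignores history
        bot.chat_reset()                                          # keep only the system prompt
    finally:
        bot.close_chat()                                          # release server-side state
```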