
[Small Tool] A Simple Python Server Based on LLAMA


Overview

Build a simple Python server on top of LLAMA to expose it through an HTTP API.

Prerequisites

  1. Download LLAMA and obtain the weights by following the instructions at https://github.com/facebookresearch/llama
  2. Install the required Python libraries (e.g. pip install bottle fairscale torch)
    • bottle: serves the web API
    • fairscale: model-parallelism library
    • pytorch: deep-learning framework
  3. If necessary, choose which GPUs to use, e.g. export CUDA_VISIBLE_DEVICES=0,1,2... (a quick check follows this list)
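
A minimal sanity check, assuming PyTorch was installed with a CUDA build, to confirm that the GPUs exposed through CUDA_VISIBLE_DEVICES are actually visible before launching the server:

import torch

# Expect True and the number of GPUs exposed via CUDA_VISIBLE_DEVICES.
print(torch.cuda.is_available())
print(torch.cuda.device_count())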

Implementation Notes

Key points (see the code for details):

  1. Environment configuration
  • MASTER_ADDR: address of the master node, normally 127.0.0.1
  • MASTER_PORT: port of the master node; any free port in 1024-65535
  • LOCAL_RANK/RANK: rank of the current process; set dynamically in the program
  • WORLD_SIZE: number of GPUs to use, which depends on the model: 7B -> 1, 13B -> 2, 70B -> 8
  2. The 13B and 70B models need several GPUs, one process per GPU, and these processes must run in parallel during generation. The MultiProcessLlama class in the code manages them.
  3. A global dictionary records the conversation context of each session (a sketch of its layout follows this list).
  4. Run python <server-name>.py --llama-version 7B
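
For illustration, one entry of that global dictionary looks roughly like this (the UUID here is made up; the messages use the dialog format passed to chat_completion in the code below):

chat_uuid_dict = {
    "3b2c6c1e-....": [                                    # key returned by /api/initchat
        {"role": "system", "content": "Feel free to answer the question."},
        {"role": "user", "content": "Hello"},             # appended by /api/chat
        {"role": "assistant", "content": "Hi there!"},    # model answer appended after generation
    ]
}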

The implemented API endpoints (a request-flow sketch follows the list):

  • POST /api/initchat: initialize a chat session
    • Request parameters
      • content: system prompt
    • Response (JSON)
      • uuid: user identifier
  • POST /api/chat: send one user message within a session
    • Request parameters
      • uuid: user identifier returned by /api/initchat
      • content: the user's message
    • Response (JSON)
      • uuid: user identifier, same as in the request
      • status: 0 = failure, 1 = success
      • response: LLAMA's answer
  • POST /api/reset: reset the conversation state, keeping only the system prompt
    • Request parameters
      • uuid: user identifier returned by /api/initchat
    • Response (JSON)
      • uuid: user identifier, same as in the request
      • status: 0 = failure, 1 = success
  • POST /api/chat_once: single-turn chat that leaves the stored history untouched; roughly equivalent to /api/chat followed by /api/reset
    • Request parameters
      • uuid: user identifier returned by /api/initchat
      • content: the user's message
    • Response (JSON)
      • uuid: user identifier, same as in the request
      • status: 0 = failure, 1 = success
      • response: LLAMA's answer
  • POST /api/close: close a session and delete its context (implemented in the code below and used by the client's close_chat)
    • Request parameters
      • uuid: user identifier returned by /api/initchat
    • Response (JSON)
      • uuid: user identifier, same as in the request
      • status: 0 = failure, 1 = success
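
A minimal client-side sketch of the request flow, assuming the server runs on 127.0.0.1:8088 (the port used in the code below); parameters are sent as form data because the handlers read request.forms:

import requests

base = "http://127.0.0.1:8088"

# 1. Open a session with a system prompt; the server returns a uuid.
uuid = requests.post(f"{base}/api/initchat",
                     data={"content": "You are a helpful assistant."}).json()["uuid"]

# 2. Send a user message within that session.
reply = requests.post(f"{base}/api/chat",
                      data={"uuid": uuid, "content": "Hello!"}).json()
print(reply["status"], reply["response"])

# 3. Reset the history (keeps only the system prompt) and close the session.
requests.post(f"{base}/api/reset", data={"uuid": uuid})
requests.post(f"{base}/api/close", data={"uuid": uuid})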

Server Code

import argparse
import multiprocessing
import os
import uuid as libuuid
from multiprocessing import Queue

import bottle
import torch
import torch.distributed
from bottle import get, post, request
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import Llama

_arg = argparse.ArgumentParser("Server LLAMA Chat Web")
_arg.add_argument("--llama-version", type=str,
                  default="7B", help="LLAMA version, available: 7B, 13B, 70B")
args = _arg.parse_args()

if args.llama_version not in ["7B", "13B", "70B"]:
    raise ValueError("LLaMA version not supported. Choose one of: 7B, 13B, 70B")


class MultiProcessLlama:
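    """Run one Llama worker process per model-parallel rank and fan each
    chat_completion call out to all of them; only rank 0 writes results
    to the shared output queue."""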
    def __init__(self, world_size):
        self.in_queue_list = [
            Queue(8)
            for _ in range(world_size)
        ]
        self.out_queue = Queue(8)

        self.world_size = world_size
        self.process_list = []
        print("init Done")

    def chat_completion(self, *args, **kwargs):
        for ique in self.in_queue_list:
            # print("Call chat_completion")
            ique.put((
                "chat_completion",
                args,
                kwargs,
            ))

        out = self.out_queue.get()
        return out

    def start(self, *args, **kwargs):

        def __loop(rank: int, world_size: int, in_queue: Queue, out_queue: Queue, args, kwargs):
            if rank == 0:
                assert out_queue is not None

            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "65288"
            # os.environ["RANK"] = str(rank)
            os.environ["LOCAL_RANK"] = str(rank)
            # os.environ["WORLD_SIZE"] = str(world_size)
            torch.distributed.init_process_group(
                "nccl",
                rank=rank,
                world_size=world_size,
            )
            initialize_model_parallel(world_size)

            generator = Llama.build(
                *args,
                **kwargs,
            )

            while True:
                # print(f"[{rank}] in queue wait")
                cmd, args, kwargs = in_queue.get()
                # print(f"[{rank}] in queue get", cmd)
                out = None
                if cmd is None:
                    break
                if cmd == "chat_completion":
                    out = generator.chat_completion(*args, **kwargs)
                elif cmd == "text_completion":
                    out = generator.text_completion(*args, **kwargs)
                else:
                    print("Warnning, unknown command", cmd)
                # all responses are the same. write to rank 0 only
                if rank == 0:
                    out_queue.put(out)

                # print(f"[{rank}] {cmd} {args}, {kwargs} => {out}")

        for i in range(self.world_size):
            pi = multiprocessing.Process(
                target=__loop,
                args=(
                    i,
                    self.world_size,
                    self.in_queue_list[i],
                    self.out_queue if i == 0 else None,
                    args,
                    kwargs,
                )
            )
            self.process_list.append(pi)
            pi.start()

    def join(self):
        for _ in range(4):
            # put the stop sentinel several times so that no worker misses it
            for que in self.in_queue_list:
                que.put((None, None, None))

        for pi in self.process_list:
            pi.join()


chat_uuid_dict = dict()


def generate_chat(generator, chat_info):
    if generator is None:
        return [{"role": "assistant", "content": "(ᗜ_ᗜ)"}]

    return generator.chat_completion(
        [chat_info],  # type: ignore
        max_gen_len=None,
        temperature=0.6,
        top_p=0.99,
    )[0]


Global_generator = None

app = bottle.Bottle()


@app.route("/")
def index():
    return bottle.template("./web/statics/index.html")


@app.route("/statics/<filename:path>")
def serve_static(filename):
    return bottle.static_file(filename, "web/statics")


@app.post("/api/close")
def api_close():
    uuid = request.forms.get("uuid", "")
    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    del chat_uuid_dict[uuid]
    return {"uuid": uuid, "status": 1}


@app.post("/api/chat")
def api_chat():
    uuid = request.forms.get("uuid", "")
    content = request.forms.get("content", "")
    if content == "":
        return {
            "uuid": uuid,
            "status": 1,
            "response": "(ᗜ_ᗜ)",
        }

    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}

    chat_hist = chat_uuid_dict[uuid]

    chat_hist.append({
        "role": "user",
        "content": content,
    })

    result = generate_chat(Global_generator, chat_hist)
    answer = result["generation"]['content']
    chat_hist.append({
        "role": "assistant",
        "content": answer
    })

    return {
        "uuid": uuid,
        "status": 1,
        "response": answer,
    }


@app.post("/api/initchat")
def api_initchat():
    content = request.forms.get("content", "Feel free to answer the question.")

    while True:
        uuid = str(libuuid.uuid4())
        if uuid not in chat_uuid_dict:
            chat_uuid_dict[uuid] = [{
                "role": "system",
                "content": content,
            }]
            break
    return {
        "uuid": uuid,
    }


@app.post("/api/chat_once")
def api_initchat():
    uuid = request.forms.get("uuid", "")
    content = request.forms.get("content", "")
    if content == "":
        return {
            "uuid": uuid,
            "status": 1,
            "response": "(ᗜ_ᗜ)",
        }

    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}

    chat_hist = []
    chat_hist.append(chat_uuid_dict[uuid][0])

    chat_hist.append({
        "role": "user",
        "content": content,
    })

    result = generate_chat(Global_generator, chat_hist)
    answer = result["generation"]['content']

    return {
        "uuid": uuid,
        "status": 1,
        "response": answer,
    }


@app.post("/api/reset")
def api_reset():
    uuid = request.forms.get("uuid", "")
    if uuid not in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    chat_hist = chat_uuid_dict[uuid]  # type: list
    init_msg = chat_hist[0]
    chat_hist.clear()
    chat_hist.append(init_msg)
    return {"uuid": uuid, "status": 1}


if __name__ == "__main__":
    print("System loadding...")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if Global_generator is None:
        if args.llama_version == "7B":
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "65288"
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            Global_generator = Llama.build(
                ckpt_dir="./downloads/llama-2-7b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=9000,
                max_batch_size=1,
                model_parallel_size=1,
            )
        elif args.llama_version == "13B":
            Global_generator = MultiProcessLlama(2)
            Global_generator.start(
                ckpt_dir="./downloads/llama-2-13b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=2048,
                max_batch_size=1,
                model_parallel_size=2,
            )

        elif args.llama_version == "70B":
            print("Use torch run")
            Global_generator = MultiProcessLlama(8)
            Global_generator.start(
                ckpt_dir="./downloads/llama-2-70b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=2048,
                max_batch_size=1,
                model_parallel_size=8,
            )

    print("Init with", args.llama_version)
    app.run(host='0.0.0.0', port=8088, debug=False, reloader=False)
    if args.llama_version != "7B":
        try:
            Global_generator.join()
        except Exception as e:
            print(e)

Remote Invocation

A helper class for the client side; a short usage sketch follows it.

import requests
import http

class ChatLlama:
    def __init__(self, addr, content: str = ""):
        self.addr = addr
        self.chat_uuid = self.init_chat(content)

    def init_chat(self, content: str):
        resp = requests.post(f"{self.addr}/api/initchat", data={ "content": content })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError(resp.status_code)

        chat_uuid = resp.json()["uuid"]
        resp.close()

        print("init UUID", chat_uuid)
        return chat_uuid

    def chat_request(self, context) -> str:
        resp = requests.post(f"{self.addr}/api/chat", data={
            "uuid": self.chat_uuid,
            "content": context,
        })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)

        ans = resp.json()
        if ans["status"] == 0:
            raise ValueError("UUID does not exist")

        return ans["response"]

    def chat_once(self, context) -> str:
        resp = requests.post(f"{self.addr}/api/chat_once", data={
            "uuid": self.chat_uuid,
            "content": context,
        })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)

        ans = resp.json()
        if ans["status"] == 0:
            raise ValueError("UUID does not exist")

        return ans["response"]
    
    def chat_reset(self) -> bool:
        resp = requests.post(f"{self.addr}/api/reset", data={
            "uuid": self.chat_uuid,
        })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)

        ans = resp.json()
        return ans["status"] == 1

    def close_chat(self):
        resp = requests.post(f"{self.addr}/api/close", data={
            "uuid": self.chat_uuid,
        })

        resp.close()
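
A short usage sketch (the address and prompts here are made up; the server above listens on 0.0.0.0:8088):

if __name__ == "__main__":
    bot = ChatLlama("http://127.0.0.1:8088", "You are a helpful assistant.")
    print(bot.chat_request("Hello, who are you?"))   # multi-turn: history is kept on the server
    print(bot.chat_once("What is 2 + 2?"))           # single turn: stored history is untouched
    bot.chat_reset()                                 # keep only the system prompt
    bot.close_chat()                                 # delete the session on the server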