当前位置:   article > 正文

Torchserve部署chatglm,实现batch_inference和stream_response_torchserve batchsize

torchserve batchsize

一、 Torchserve安装

torchserve的使用需要java环境,官方建议openjdk11。
使用pip安装Torchserve和模型打包工具torch-model-archiver
pip install torchserve torch-model-archiver

二、 handler.py文件编写

handler文件主要实现模型的加载、请求输入预处理、推理和返回,可继承BaseHandler类。

import torch
from transformers import AutoTokenizer, AutoModel
from ts.torch_handler.base_handler import BaseHandler
import time
import os
import json
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from pkg_resources import packaging
if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"):
    from torch.profiler import ProfilerActivity, profile, record_function
    PROFILER_AVAILABLE = True
else:
    PROFILER_AVAILABLE = False
   
   
class TransformersGpt2Handler(BaseHandler):
    def initialize(self, content):
       
        # 加载模型并设置到设备上
        self.tokenizer = AutoTokenizer.from_pretrained("/workspace/torchserve-examples/chatglm2", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("/workspace/torchserve-examples/chatglm2", trust_remote_code=True, device='cuda')
        self.model = self.model.eval()
       
    def preprocess(self, request):
        inputs = []
        for request_json in request:
            question_text = request_json["body"]["question"]
            history = request_json["body"]['history']
            prompt = self.tokenizer.build_prompt(question_text, history=history)
            inputs.append(prompt)        
        inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)
        inputs = inputs.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        # if isinstance(question_text, (bytes, bytearray)):
        #     text = text.decode("utf-8")
        return inputs
   
    def inference(self, inputs):
        with torch.no_grad():
            input_length = inputs["input_ids"].shape[-1]
            for outputs in self.model.stream_generate(
                **inputs,
                max_length=8192,
                num_beams=1,
                do_sample=True,
                top_p=0.8,
                temperature=0.8,
            ):
                outputs = outputs[:, input_length:]
                batch_out_sentence = self.tokenizer.batch_decode(outputs)
                send_intermediate_predict_response(batch_out_sentence, self.context.request_ids, "Intermediate Prediction success", 200, self.context)
               
            return batch_out_sentence
               
   
    def postprocess(self, batch_out_sentence):
       
        return batch_out_sentence
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

其中initialize方法加载模型。

preprocesss方法对请求输入进行预处理。torchserve可设置batch处理,所以将批次数据中每个用户的问题和历史通过tokenizer.build_prompt方法构造拼接起来,存入inputs列表,再通过inputs = self.tokenizer(inputs, return_tensors="pt", padding=True)获取chatglm模型的输入,即input_ids、position_ids等。

inference方法接受preprocess方法返回的inputs作为输入,通过model.stream_generate方法输出回答的id,再通过tokenizer的batch_decode方法将id转换为文字,最后通过send_intermediate_predict_response方法实时将回答返回。

最后postprocess不做处理,直接返回batch_out_sentence。

三、打包模型

利用torch-model-archiver把所有模型以及本地依赖文件打包到一个单独的文档中。

torch-model-archiver --model-name chatglm --version 1.0 --handler myhandler.py

  • 1
  • 2

创建新的文件夹,命名为model_store,将生成的chatglm.mar移动到model_store

四、启动torchserve

torchserve --start --model-store model_store
  • 1

注册模型

curl -X POST “http://localhost:8081/modelsurl=chatglm.mar&batch_size=50&max_batch_delay=50”
设置batchsize为50,max_batch_delay为50ms。max_batch_delay的意思就是
50ms内未集齐50个请求就将已有的请求输入模型。

设置线程数

curl -v -X PUT “http://localhost:8081/models/chatglm?min_worker=1”

至此,部署完毕,可通过http://localhost:8080/predictions/chatglm进行对话,通过curl http://localhost:8082/metrics监测其能力表现。TorchServe 定期收集系统级 metrics,并允许添加自定义 metrics。系统级 metrics 包括 CPU 利用率、主机上可用及已用的磁盘空间和内存等。

五、 并发量测试

笔者测试硬件为一张3090。

每秒增加一个用户,直到增加到50个用户。每个用户的输入的问题为“胃痛应该怎么办?”,该问题的回答一般为
“胃痛的时候,建议采取以下措施:

  1. 休息:尽量减少身体运动,躺下来放松身体,减少胃部的压力。

  2. 饮食:避免食用辛辣、油腻、难以消化、过咸等食物,喝些温水或淡盐水可以缓解一些胃部不适。如果感到饿了,可以适量食用易消化的食物,比如面包、饼干等。

  3. 药物:如果胃痛不是太严重,可以尝试口服一些消化不良的药物,如开塞露、铝碳酸镁等。如果症状较为严重,建议咨询医生并按医嘱服用。

  4. 就医:如果症状持续或加重,或者伴随呕吐、腹泻、发热等其他不适症状,建议尽快就医,以便得到专业的医疗建议和治疗。

如果感到胃痛,建议先采取适当的措施缓解不适,如休息、饮食、药物等。如症状持续或加重,应该及时就医并咨询医生的建议。”

其表现如下:RPS为2.3。
在这里插入图片描述

六、基于streamlit的前端设计

参考streamlit的示例代码
在这里插入图片描述

笔者的代码如下:

import requests
import json
import streamlit as st
import streamlit as st
import os
# Create a session_state to store chat history for multiple sessions
if "chat_sessions" not in st.session_state:
    st.session_state.chat_sessions = []
if "active_chat_session" not in st.session_state:
    st.session_state.active_chat_session = None
# Create a function to start a new chat session
def start_new_chat():
    chat_session = [{"role": "assistant", "content": "你好声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签