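In earlier articles, we used the ChatGLM3-6B model by typing commands. Now we can use Gradio to interact with the model through a web interface. This avoids repeatedly reloading the model and editing code,
which makes it more convenient to try out the model.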
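Gradio is a Python library for building interactive interfaces. It makes it easier to create quick prototypes and to build and share machine learning models in Python.
Gradio's main function is to give a machine learning model an instant web interface, so users can interact with the model, enter data, and view the results without writing complex front-end code. It provides a simple API that binds inputs and outputs to a model function or method and generates the user interface automatically.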
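To make the input/output binding concrete, here is a minimal sketch, assuming Gradio is installed. The `greet` function and the `build_demo` helper are hypothetical names used only for illustration; `greet` stands in for a real model call.

```python
def greet(name: str) -> str:
    """Stand-in for a model call: receives the textbox value, returns the reply."""
    return f"Hello, {name}!"


def build_demo():
    # Imported here so greet() can be exercised without Gradio installed.
    import gradio as gr

    # Bind one text input and one text output to greet(); Gradio generates the page.
    return gr.Interface(fn=greet, inputs="text", outputs="text")


# To serve the page locally, run:
#     build_demo().launch()
```

The same idea scales up to the chat demo later in this article: only the bound function and the components change.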
Download the model from Hugging Face: https://huggingface.co/THUDM/chatglm3-6b/tree/main
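As an alternative to downloading the files by hand, the repository can probably be fetched from Python. This is a sketch under assumptions: it relies on the `huggingface_hub` package (normally installed alongside transformers), the `download_model` helper is a hypothetical name, and the target directory simply mirrors the `/model/chatglm3-6b` path used by the script in this article.

```python
def download_model(local_dir: str = "/model/chatglm3-6b") -> str:
    """Download the THUDM/chatglm3-6b repository and return the local folder path."""
    # Imported lazily so this module loads even before huggingface_hub is installed.
    from huggingface_hub import snapshot_download

    # repo_id comes from the URL above; local_dir is where the files are written.
    return snapshot_download(repo_id="THUDM/chatglm3-6b", local_dir=local_dir)


# Example (downloads several GB, so it is left commented out):
#     path = download_model()
```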
- conda create --name chatglm3 python=3.10
- conda activate chatglm3
- pip install protobuf transformers==4.39.3 cpm_kernels "torch>=2.0" sentencepiece accelerate
- pip install gradio
- # -*- coding: utf-8 -*-
- import gradio as gr
- import torch
- from threading import Thread
-
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     StoppingCriteria,
-     StoppingCriteriaList,
-     TextIteratorStreamer,
- )
-
- # Local directory that holds the downloaded chatglm3-6b files
- modelPath = "/model/chatglm3-6b"
-
- def loadTokenizer():
-     tokenizer = AutoTokenizer.from_pretrained(modelPath, use_fast=False, trust_remote_code=True)
-     return tokenizer
-
- def loadModel():
-     # device_map="auto" already places the weights on the available GPU(s),
-     # so the extra .cuda() call is dropped
-     model = AutoModelForCausalLM.from_pretrained(modelPath, device_map="auto", trust_remote_code=True)
-     model = model.eval()
-     return model
-
- class StopOnTokens(StoppingCriteria):
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         stop_ids = [0, 2]
-         for stop_id in stop_ids:
-             if input_ids[0][-1] == stop_id:
-                 return True
-         return False
-
- def parse_text(text):
-     lines = text.split("\n")
-     lines = [line for line in lines if line != ""]
-     count = 0
-     for i, line in enumerate(lines):
-         if "```" in line:
-             # A fence line toggles between opening and closing a code block
-             count += 1
-             items = line.split('`')
-             if count % 2 == 1:
-                 lines[i] = f'<pre><code class="language-{items[-1]}">'
-             else:
-                 lines[i] = '<br></code></pre>'
-         else:
-             if i > 0:
-                 if count % 2 == 1:
-                     # Inside a code block: escape characters that HTML or Markdown would interpret
-                     line = line.replace("`", "\\`")
-                     line = line.replace("<", "&lt;")
-                     line = line.replace(">", "&gt;")
-                     line = line.replace(" ", "&nbsp;")
-                     line = line.replace("*", "&ast;")
-                     line = line.replace("_", "&lowbar;")
-                     line = line.replace("-", "&#45;")
-                     line = line.replace(".", "&#46;")
-                     line = line.replace("!", "&#33;")
-                     line = line.replace("(", "&#40;")
-                     line = line.replace(")", "&#41;")
-                     line = line.replace("$", "&#36;")
-                 lines[i] = "<br>" + line
-     text = "".join(lines)
-     return text
-
- def predict(history, max_length, top_p, temperature):
-     stop = StopOnTokens()
-     # Convert Gradio's [user, assistant] pairs into role/content messages
-     messages = []
-     for idx, (user_msg, model_msg) in enumerate(history):
-         if idx == len(history) - 1 and not model_msg:
-             messages.append({"role": "user", "content": user_msg})
-             break
-         if user_msg:
-             messages.append({"role": "user", "content": user_msg})
-         if model_msg:
-             messages.append({"role": "assistant", "content": model_msg})
-
-     model_inputs = tokenizer.apply_chat_template(messages,
-                                                  add_generation_prompt=True,
-                                                  tokenize=True,
-                                                  return_tensors="pt").to(next(model.parameters()).device)
-     streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = {
-         "input_ids": model_inputs,
-         "streamer": streamer,
-         "max_new_tokens": int(max_length),  # the slider may return a float
-         "do_sample": True,
-         "top_p": top_p,
-         "temperature": temperature,
-         "stopping_criteria": StoppingCriteriaList([stop]),
-         "repetition_penalty": 1.2,
-     }
-     # Generate in a background thread so tokens can be streamed to the UI as they arrive
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     for new_token in streamer:
-         if new_token != '':
-             history[-1][1] += new_token
-             yield history
-
-
- with gr.Blocks() as demo:
-     gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
-     chatbot = gr.Chatbot()
-
-     with gr.Row():
-         with gr.Column(scale=4):
-             with gr.Column(scale=12):
-                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
-             with gr.Column(min_width=32, scale=1):
-                 submitBtn = gr.Button("Submit")
-         with gr.Column(scale=1):
-             emptyBtn = gr.Button("Clear History")
-             max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
-             top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
-             temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
-
-
-     def user(query, history):
-         return "", history + [[parse_text(query), ""]]
-
-
-     submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
-         predict, [chatbot, max_length, top_p, temperature], chatbot
-     )
-     emptyBtn.click(lambda: None, None, chatbot, queue=False)
-
- if __name__ == '__main__':
-     model = loadModel()
-     tokenizer = loadTokenizer()
-
-     demo.queue()
-     demo.launch(server_name="0.0.0.0", server_port=8989, inbrowser=True, share=False)
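The predict function above runs model.generate in a background thread and reads tokens from TextIteratorStreamer as they arrive. The sketch below reproduces that producer/consumer pattern with the standard library only, so it can be run without a GPU. It is an illustration, not the library's implementation: `fake_generate` and `stream_reply` are hypothetical names, and a `queue.Queue` stands in for the streamer.

```python
import queue
from threading import Thread

_DONE = object()  # sentinel that marks the end of the stream


def fake_generate(tokens, out_queue):
    """Hypothetical stand-in for model.generate: pushes tokens one at a time."""
    for token in tokens:
        out_queue.put(token)
    out_queue.put(_DONE)


def stream_reply(tokens):
    """Yield the reply accumulated so far each time a non-empty token arrives."""
    out_queue = queue.Queue()
    worker = Thread(target=fake_generate, args=(tokens, out_queue))
    worker.start()

    reply = ""
    while True:
        # Block until the producer delivers something (60 s, like the streamer's timeout).
        token = out_queue.get(timeout=60)
        if token is _DONE:
            break
        if token != "":
            reply += token
            yield reply
    worker.join()
```

Iterating over `stream_reply(["Hel", "lo"])` yields `"Hel"` and then `"Hello"`, which mirrors how predict yields the growing chat history so the page updates token by token.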
Results:
Successful startup: [screenshot]
GPU usage: [screenshot]
Browser access: [screenshot]
Inference: [screenshot]
1. The installed transformers version is too old and needs to be upgraded:
pip install --upgrade transformers==4.39.3
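Before upgrading, it can help to confirm which version is actually installed in the active environment. The following is a small sketch using only the standard library. The helper names are hypothetical, and treating 4.39.3 as a minimum (rather than an exact pin) is an assumption based on the "version too old" diagnosis above.

```python
from importlib import metadata


def parse_version(text: str) -> tuple:
    """Turn '4.39.3' into (4, 39, 3); parsing stops at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if digits == "":
            break
        parts.append(int(digits))
    return tuple(parts)


def is_at_least(installed: str, required: str) -> bool:
    """Compare two dotted version strings numerically."""
    return parse_version(installed) >= parse_version(required)


def transformers_is_recent_enough(required: str = "4.39.3") -> bool:
    """Return False if transformers is missing or older than `required`."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return False
    return is_at_least(installed, required)
```

If `transformers_is_recent_enough()` returns False inside the chatglm3 environment, the upgrade command above is the next step.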
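1. The service must not listen only on 127.0.0.1, because that address accepts connections from the local machine alone.
2. Check the server's security policy or firewall configuration.
Server side: run lsof -i:8989 to check that the port is listening.
Client side: run telnet ip 8989 to check that a connection can be established.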
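If telnet is not available on the client, roughly the same check can be done from Python. This is a minimal sketch with the standard library; `can_connect` is a hypothetical helper, and the host passed to it should be the server's real IP address.

```python
import socket


def can_connect(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        # create_connection performs the TCP handshake; the socket is closed on exit.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False
```

For example, `can_connect("127.0.0.1", 8989)` run on the server itself should return True while the demo is running. A False result from a remote client when the local check succeeds points to the listen address or the firewall.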