当前位置:   article > 正文

开源模型应用落地-chatglm3-6b-gradio-入门篇(七)_from transformers import chatglmgenerator, chatglm

from transformers import chatglmgenerator, chatglmconfig, chatglmmodel impor

一、前言

    早前的文章,我们都是通过输入命令的方式来使用Chatglm3-6b模型。现在,我们可以通过使用gradio,通过一个界面与模型进行交互。这样做可以减少重复加载模型和修改代码的麻烦,
让我们更方便地体验模型的效果。


二、术语

2.1、Gradio

    是一个用于构建交互式界面的Python库。它使得在Python中创建快速原型、构建和共享机器学习模型变得更加容易。

    Gradio的主要功能是为机器学习模型提供一个即时的Web界面,使用户能够与模型进行交互,输入数据并查看结果,而无需编写复杂的前端代码。它提供了一个简单的API,可以将输入和输出绑定到模型的函数或方法,并自动生成用户界面。


三、前置条件

3.1. windows or linux操作系统均可

3.2. 下载chatglm3-6b模型

从huggingface下载:https://huggingface.co/THUDM/chatglm3-6b/tree/main

从魔搭下载:魔搭社区汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/fileshttps://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files

 3.3. 创建虚拟环境&安装依赖

  1. conda create --name chatglm3 python=3.10
  2. conda activate chatglm3
  3. pip install protobuf transformers==4.39.3 cpm_kernels torch>=2.0 sentencepiece accelerate
  4. pip install gradio

四、技术实现

  1. # -*- coding = utf-8 -*-
  2. import gradio as gr
  3. import torch
  4. from threading import Thread
  5. from transformers import (
  6. AutoModelForCausalLM,
  7. AutoTokenizer,
  8. StoppingCriteria,
  9. StoppingCriteriaList,
  10. TextIteratorStreamer
  11. )
  12. modelPath = "/model/chatglm3-6b"
  13. def loadTokenizer():
  14. tokenizer = AutoTokenizer.from_pretrained(modelPath, use_fast=False, trust_remote_code=True)
  15. return tokenizer
  16. def loadModel():
  17. model = AutoModelForCausalLM.from_pretrained(modelPath, device_map="auto", trust_remote_code=True).cuda()
  18. model = model.eval()
  19. return model
  20. class StopOnTokens(StoppingCriteria):
  21. def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
  22. stop_ids = [0, 2]
  23. for stop_id in stop_ids:
  24. if input_ids[0][-1] == stop_id:
  25. return True
  26. return False
  27. def parse_text(text):
  28. lines = text.split("\n")
  29. lines = [line for line in lines if line != ""]
  30. count = 0
  31. for i, line in enumerate(lines):
  32. if "```" in line:
  33. count += 1
  34. items = line.split('`')
  35. if count % 2 == 1:
  36. lines[i] = f'<pre><code class="language-{items[-1]}">'
  37. else:
  38. lines[i] = f'<br></code></pre>'
  39. else:
  40. if i > 0:
  41. if count % 2 == 1:
  42. line = line.replace("`", "\`")
  43. line = line.replace("<", "&lt;")
  44. line = line.replace(">", "&gt;")
  45. line = line.replace(" ", "&nbsp;")
  46. line = line.replace("*", "&ast;")
  47. line = line.replace("_", "&lowbar;")
  48. line = line.replace("-", "&#45;")
  49. line = line.replace(".", "&#46;")
  50. line = line.replace("!", "&#33;")
  51. line = line.replace("(", "&#40;")
  52. line = line.replace(")", "&#41;")
  53. line = line.replace("$", "&#36;")
  54. lines[i] = "<br>" + line
  55. text = "".join(lines)
  56. return text
  57. def predict(history, max_length, top_p, temperature):
  58. stop = StopOnTokens()
  59. messages = []
  60. for idx, (user_msg, model_msg) in enumerate(history):
  61. if idx == len(history) - 1 and not model_msg:
  62. messages.append({"role": "user", "content": user_msg})
  63. break
  64. if user_msg:
  65. messages.append({"role": "user", "content": user_msg})
  66. if model_msg:
  67. messages.append({"role": "assistant", "content": model_msg})
  68. model_inputs = tokenizer.apply_chat_template(messages,
  69. add_generation_prompt=True,
  70. tokenize=True,
  71. return_tensors="pt").to(next(model.parameters()).device)
  72. streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
  73. generate_kwargs = {
  74. "input_ids": model_inputs,
  75. "streamer": streamer,
  76. "max_new_tokens": max_length,
  77. "do_sample": True,
  78. "top_p": top_p,
  79. "temperature": temperature,
  80. "stopping_criteria": StoppingCriteriaList([stop]),
  81. "repetition_penalty": 1.2,
  82. }
  83. t = Thread(target=model.generate, kwargs=generate_kwargs)
  84. t.start()
  85. for new_token in streamer:
  86. if new_token != '':
  87. history[-1][1] += new_token
  88. yield history
  89. with gr.Blocks() as demo:
  90. gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
  91. chatbot = gr.Chatbot()
  92. with gr.Row():
  93. with gr.Column(scale=4):
  94. with gr.Column(scale=12):
  95. user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
  96. with gr.Column(min_width=32, scale=1):
  97. submitBtn = gr.Button("Submit")
  98. with gr.Column(scale=1):
  99. emptyBtn = gr.Button("Clear History")
  100. max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
  101. top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
  102. temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
  103. def user(query, history):
  104. return "", history + [[parse_text(query), ""]]
  105. submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
  106. predict, [chatbot, max_length, top_p, temperature], chatbot
  107. )
  108. emptyBtn.click(lambda: None, None, chatbot, queue=False)
  109. if __name__ == '__main__':
  110. model = loadModel()
  111. tokenizer = loadTokenizer()
  112. demo.queue()
  113. demo.launch(server_name="0.0.0.0", server_port=8989, inbrowser=True, share=False)

调用结果:

启动成功:

GPU使用情况:

浏览器访问:

推理:


五、附带说明

5.1. 问题:AttributeError: 'ChatGLMTokenizer' object has no attribute 'apply_chat_template'

1. transformers的版本太低,需要升级

pip install --upgrade transformers==4.39.3

5.2. 界面无法打开

1. 服务监听地址不能是127.0.0.1

2. 检查服务器的安全策略或防火墙配置

 服务端:lsof -i:8989 查看端口是否正常监听

 客户端:telnet ip 8989 查看是否可以正常连接

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/991282
推荐阅读
相关标签
  

闽ICP备14008679号