```python
import os
from typing import List, Optional, Tuple

import gradio as gr
from transformers import HfArgumentParser

from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments
from deep_training.nlp.models.chatglm import setup_model_profile, ChatGLMConfig
from deep_training.nlp.models.lora.v2 import LoraArguments

from data_utils import train_info_args, NN_DataHelper
from models import MyTransformer, ChatGLMTokenizer
from webui.context import ctx
from webui.device import torch_gc

css = "style.css"
script_path = "scripts"
# Keep a reference to Gradio's original TemplateResponse so reload_javascript()
# can wrap it later and inject custom <script> tags into every rendered page.
_gradio_template_response_orig = gr.routes.templates.TemplateResponse


# Load the model
train_info_args['seed'] = None
parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, LoraArguments))
model_args, training_args, data_args, _ = parser.parse_dict(train_info_args)

setup_model_profile()
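# (setup_model_profile() is deep_training's global ChatGLM setup hook; in the
#  upstream examples it always runs before the model is constructed.)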

dataHelper = NN_DataHelper(model_args, training_args, data_args)
tokenizer: ChatGLMTokenizer
tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config(tokenizer_class_name=ChatGLMTokenizer, config_class_name=ChatGLMConfig)

# Config, LoRA arguments, and LoRA weights should all come from the same
# checkpoint directory ('./last_ckpt' here, or './best_ckpt' if preferred).
config = ChatGLMConfig.from_pretrained('./last_ckpt')
# config = ChatGLMConfig.from_pretrained('./best_ckpt')
config.initializer_weight = False

lora_args = LoraArguments.from_pretrained('./last_ckpt')
# lora_args = LoraArguments.from_pretrained('./best_ckpt')
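
# Sanity check: this script expects a LoRA checkpoint exported for inference,
# not a P-Tuning v2 one (pre_seq_len is only set when P-Tuning v2 is used).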
assert lora_args.inference_mode and config.pre_seq_len is None

pl_model = MyTransformer(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args)
# Load the LoRA weights into the backbone
pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path='./last_ckpt',
                                  lora_config=lora_args)
# pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path='./best_ckpt', lora_config=lora_args)

model = pl_model.get_glm_model()
# Adjust as needed: fp16 on GPU here
model.half().cuda()
model = model.eval()
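# If GPU memory is tight, ChatGLM checkpoints that ship quantization support
# can instead be loaded as e.g. model.half().quantize(4).cuda()
# (quantize(4) is an assumption; verify your checkpoint supports it).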


def infer(query,
          history: Optional[List[Tuple]],
          max_length, top_p, temperature):
    # if cmd_opts.ui_dev:
    #     return "hello", "hello, dev mode!"

    if model is None:
        raise RuntimeError("Model not loaded")

    if history is None:
        history = []
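    # ChatGLM's chat() takes history as a list of (query, response) tuples,
    # folds it into the prompt, and returns the response plus updated history.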
    output, history = model.chat(
        tokenizer, query=query, history=history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature
    )
    print(output)
    torch_gc()
    return query, output


def predict(query, max_length, top_p, temperature):
    ctx.limit_round()
    _, output = infer(
        query=query,
        history=ctx.history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature
    )
    ctx.append(query, output)
    torch_gc()
    # The empty string clears the input textbox
    return ctx.history, ""


def clear_history():
    ctx.clear()
    return gr.update(value=[])


def apply_max_round_click(max_round):
    ctx.max_rounds = max_round


def reload_javascript():
    scripts_list = [os.path.join(script_path, i) for i in os.listdir(script_path) if i.endswith(".js")]
    javascript = ""
    # with open("script.js", "r", encoding="utf8") as js_file:
    #     javascript = f'<script>{js_file.read()}</script>'

    for path in scripts_list:
        with open(path, "r", encoding="utf8") as js_file:
            javascript += f"\n<script>{js_file.read()}</script>"

    # todo: theme
    # if cmd_opts.theme is not None:
    #     javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"

    # Monkey-patch Gradio's TemplateResponse so every rendered page gets the
    # collected <script> tags injected just before </head>.
    def template_response(*args, **kwargs):
        res = _gradio_template_response_orig(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response


def main():
    # Build the UI
    reload_javascript()

    with gr.Blocks(css=css, analytics_enabled=False) as chat_interface:
        prompt = "Type your message..."
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("""<h2><center>ChatGLM WebUI</center></h2>""")
                with gr.Row():
                    with gr.Column(variant="panel"):
                        with gr.Row():
                            max_length = gr.Slider(minimum=4, maximum=4096, step=4, label='Max Length', value=2048)
                            top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Top P', value=0.7)
                        with gr.Row():
                            temperature = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Temperature',
                                                    value=0.95)

                        with gr.Row():
                            max_rounds = gr.Slider(minimum=1, maximum=100, step=1,
                                                   label="Max conversation rounds (lowering this markedly reduces GPU OOM, at the cost of context)",
                                                   value=20)
                            apply_max_rounds = gr.Button("✔", elem_id="del-btn")

                with gr.Row():
                    with gr.Column(variant="panel"):
                        with gr.Row():
                            clear = gr.Button("Clear conversation (context)")

                        with gr.Row():
                            save_his_btn = gr.Button("Save conversation")
                            load_his_btn = gr.UploadButton("Load conversation", file_types=['file'], file_count='single')

            with gr.Column(scale=7):
                chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
                with gr.Row():
                    input_message = gr.Textbox(placeholder=prompt, show_label=False, lines=2, elem_id="chat-input")
                    clear_input = gr.Button("Clear input")  # label assumed: the source is cut off mid-line here
```
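
The listing breaks off at this point: the event wiring and the app launch never appear. Below is a minimal sketch of the missing tail, assuming the component names above and the Gradio 3.x event API; handlers for `save_his_btn` and `load_his_btn` are omitted because the source never shows how `ctx` serializes history.

```python
        # ...continuing inside `with gr.Blocks(...) as chat_interface:`
        input_message.submit(
            predict,
            inputs=[input_message, max_length, top_p, temperature],
            outputs=[chatbot, input_message],  # predict() returns (history, "")
        )
        clear_input.click(lambda: "", outputs=[input_message])
        clear.click(clear_history, outputs=[chatbot])
        apply_max_rounds.click(apply_max_round_click, inputs=[max_rounds])

    chat_interface.launch()


if __name__ == "__main__":
    main()
```

With this tail in place the script runs end to end: once './last_ckpt' holds the exported LoRA checkpoint, running it starts a local Gradio server serving the chat UI.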