
WebUI code for the model inference page (torch_gc)

The script below builds a Gradio WebUI for chatting with a LoRA fine-tuned ChatGLM model. It loads the checkpoint once at import time, defines the inference handlers, injects custom JavaScript into Gradio's pages, and lays out the chat interface.
```python
from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments
from deep_training.nlp.models.chatglm import setup_model_profile, ChatGLMConfig
from deep_training.nlp.models.lora.v2 import LoraArguments
from transformers import HfArgumentParser
from typing import Optional, List, Tuple
from data_utils import train_info_args, NN_DataHelper
from models import MyTransformer, ChatGLMTokenizer
import os
import gradio as gr
from webui.context import ctx
from webui.device import torch_gc

css = "style.css"
script_path = "scripts"
# Keep a reference to Gradio's original TemplateResponse; it is monkey-patched in reload_javascript() below.
_gradio_template_response_orig = gr.routes.templates.TemplateResponse
```
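The torch_gc helper imported from webui.device is not shown in the article. Judging by its name and the fact that it is called after every generation, it is most likely a thin wrapper around PyTorch's CUDA cache-release calls; the following is a minimal sketch under that assumption, not the author's actual module:

```python
# Hypothetical sketch of webui/device.py -- inferred from usage, not the original code.
import torch


def torch_gc():
    """Release cached GPU memory between generations."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached blocks back to the CUDA driver
        torch.cuda.ipc_collect()  # reclaim memory held by dead CUDA IPC handles
```

Calling it after every reply keeps long chat sessions from slowly exhausting VRAM. Next, the model is loaded: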
```python
# Load the model
train_info_args['seed'] = None
parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, LoraArguments))
model_args, training_args, data_args, _ = parser.parse_dict(train_info_args)

setup_model_profile()
dataHelper = NN_DataHelper(model_args, training_args, data_args)
tokenizer: ChatGLMTokenizer
tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config(
    tokenizer_class_name=ChatGLMTokenizer, config_class_name=ChatGLMConfig)

config = ChatGLMConfig.from_pretrained('./best_ckpt')
config.initializer_weight = False

lora_args = LoraArguments.from_pretrained('./last_ckpt')
# lora_args = LoraArguments.from_pretrained('./best_ckpt')
assert lora_args.inference_mode == True and config.pre_seq_len is None

pl_model = MyTransformer(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args)

# Load the LoRA weights
pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path='./last_ckpt',
                                  lora_config=lora_args)
# pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path='./best_ckpt', lora_config=lora_args)

model = pl_model.get_glm_model()
# Adjust precision/device to your hardware as needed
model.half().cuda()
model = model.eval()


def infer(query,
          history: Optional[List[Tuple]],
          max_length, top_p, temperature):
    # if cmd_opts.ui_dev:
    #     return "hello", "hello, dev mode!"
    if not model:
        raise RuntimeError("Model not loaded")
    if history is None:
        history = []
    output, history = model.chat(
        tokenizer, query=query, history=history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature
    )
    print(output)
    torch_gc()
    return query, output
```
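With the model loaded, infer can be exercised directly, without the UI. A quick smoke test might look like this (the prompt is arbitrary and the sampling values mirror the slider defaults defined in main below; this is not part of the original article):

```python
# Hypothetical smoke test for infer(); values mirror the UI defaults.
query, output = infer("你好", history=[], max_length=2048, top_p=0.7, temperature=0.95)
print(output)
```

The handlers below wrap infer with the shared conversation context: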
```python
def predict(query, max_length, top_p, temperature):
    ctx.limit_round()
    _, output = infer(
        query=query,
        history=ctx.history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature
    )
    ctx.append(query, output)
    torch_gc()
    # The empty string clears the input textbox
    return ctx.history, ""


def clear_history():
    ctx.clear()
    return gr.update(value=[])


def apply_max_round_click(max_round):
    ctx.max_rounds = max_round
```
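The shared ctx object comes from webui.context, which the article does not include. From its usage above (history, max_rounds, limit_round, append, clear), a plausible minimal implementation would be something like the following sketch (an inference from usage, not the original module):

```python
# Hypothetical sketch of webui/context.py, reconstructed from usage above.
class Context:
    def __init__(self, max_rounds: int = 20):
        self.history = []  # list of (query, response) tuples fed back to model.chat
        self.max_rounds = max_rounds

    def limit_round(self):
        # Trim the oldest rounds so at most max_rounds remain; a smaller bound
        # reduces VRAM pressure at the cost of conversational context.
        while len(self.history) >= self.max_rounds:
            self.history.pop(0)

    def append(self, query, output):
        self.history.append((query, output))

    def clear(self):
        self.history = []


ctx = Context()
```

reload_javascript then takes care of injecting the local scripts into Gradio's pages: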
```python
def reload_javascript():
    scripts_list = [os.path.join(script_path, i) for i in os.listdir(script_path) if i.endswith(".js")]
    javascript = ""
    # with open("script.js", "r", encoding="utf8") as js_file:
    #     javascript = f'<script>{js_file.read()}</script>'
    for path in scripts_list:
        with open(path, "r", encoding="utf8") as js_file:
            javascript += f"\n<script>{js_file.read()}</script>"
    # todo: theme
    # if cmd_opts.theme is not None:
    #     javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"

    def template_response(*args, **kwargs):
        # Splice the collected <script> tags into the page just before </head>.
        res = _gradio_template_response_orig(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
```
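This works by monkey-patching Gradio's TemplateResponse: every page Gradio renders passes through template_response, which rewrites the HTML body to include the contents of each .js file under scripts/. Saving the original class in _gradio_template_response_orig at import time lets the patch wrap the real response instead of recursing into itself. Finally, main builds the interface: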
```python
def main():
    # Build the UI
    reload_javascript()
    with gr.Blocks(css=css, analytics_enabled=False) as chat_interface:
        prompt = "输入你的内容..."  # "Type your message..."
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("""<h2><center>ChatGLM WebUI</center></h2>""")
                with gr.Row():
                    with gr.Column(variant="panel"):
                        with gr.Row():
                            max_length = gr.Slider(minimum=4, maximum=4096, step=4, label='Max Length', value=2048)
                            top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Top P', value=0.7)
                        with gr.Row():
                            temperature = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Temperature',
                                                    value=0.95)
                        with gr.Row():
                            # Label: "Max conversation rounds (lowering this markedly
                            # reduces out-of-VRAM errors, but loses older context)"
                            max_rounds = gr.Slider(minimum=1, maximum=100, step=1,
                                                   label="最大对话轮数(调小可以显著改善爆显存,但是会丢失上下文)",
                                                   value=20)
                            apply_max_rounds = gr.Button("✔", elem_id="del-btn")
                with gr.Row():
                    with gr.Column(variant="panel"):
                        with gr.Row():
                            clear = gr.Button("清空对话(上下文)")  # "Clear conversation (context)"
                        with gr.Row():
                            save_his_btn = gr.Button("保存对话")  # "Save conversation"
                            load_his_btn = gr.UploadButton("读取对话", file_types=['file'], file_count='single')  # "Load conversation"
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
                with gr.Row():
                    input_message = gr.Textbox(placeholder=prompt, show_label=False, lines=2, elem_id="chat-input")
                    clear_input = gr.Button("
```
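The article's listing is cut off in the middle of that last statement. For completeness, the remaining wiring in this kind of Gradio app usually binds the widgets to the handlers defined earlier and launches the Blocks; the sketch below is an assumption based on those handler signatures, not the author's original ending:

```python
# Hypothetical completion -- the original listing is truncated above.
                    clear_input = gr.Button("清空输入")  # assumed label: "Clear input"

        # Bind widgets to the handlers defined earlier.
        input_message.submit(predict,
                             inputs=[input_message, max_length, top_p, temperature],
                             outputs=[chatbot, input_message])
        clear.click(clear_history, outputs=[chatbot])
        clear_input.click(lambda: "", outputs=[input_message])
        apply_max_rounds.click(apply_max_round_click, inputs=[max_rounds])

    chat_interface.queue().launch()


if __name__ == "__main__":
    main()
```

(The save/load buttons would be wired similarly, to history-serialization handlers the article does not show.)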