赞
踩
python cli_demo.py
- import os
- import platform
- import signal
- from transformers import AutoTokenizer, AutoModel
- import readline
-
- tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
- model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
- model = model.eval()
-
- os_name = platform.system()
- clear_command = 'cls' if os_name == 'Windows' else 'clear'
- stop_stream = False
-
-
- def build_prompt(history):
- prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
- for query, response in history:
- prompt += f"\n\n用户:{query}"
- prompt += f"\n\nChatGLM-6B:{response}"
- return prompt
-
-
- def signal_handler(signal, frame):
- global stop_stream
- stop_stream = True
-
-
- def main():
- history = []
- global stop_stream
- print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
- while True:
- query = input("\n用户:")
- if query.strip() == "stop":
- break
- if query.strip() == "clear":
- history = []
- os.system(clear_command)
- print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
- continue
- count = 0
- for response, history in model.stream_chat(tokenizer, query, history=history):
- if stop_stream:
- stop_stream = False
- break
- else:
- count += 1
- if count % 8 == 0:
- os.system(clear_command)
- print(build_prompt(history), flush=True)
- signal.signal(signal.SIGINT, signal_handler)
- os.system(clear_command)
- print(build_prompt(history), flush=True)
-
-
- if __name__ == "__main__":
- main()

具体使用了ChatGLM-6B模型进行聊天对话。下面逐行解释这段代码:
import os
: 导入os模块,用于访问操作系统功能。
import platform
: 导入platform模块,用于获取操作系统信息。
import signal
: 导入signal模块,用于处理信号。
from transformers import AutoTokenizer, AutoModel
: 从transformers库导入AutoTokenizer和AutoModel,用于加载预训练的模型和对应的tokenizer。
import readline
: 导入readline模块,用于Python控制台的输入。
tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
: 加载预训练模型的tokenizer,从相对路径"../ChatGLM-Tuning-master/chatglm-6b"。
model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
: 加载预训练模型,并将其转移到GPU上,同时使用半精度浮点数(half-precision floating point)来提高运算速度。
model = model.eval()
: 将模型设置为评估模式,通常在测试或验证阶段使用。
os_name = platform.system()
: 获取操作系统名称。
clear_command = 'cls' if os_name == 'Windows' else 'clear'
: 根据操作系统类型设置清屏命令。
stop_stream = False
: 定义一个全局变量stop_stream,用于控制是否停止模型的流式对话。
12-18. def build_prompt(history)
: 定义一个函数,根据历史对话构建提示文本。
19-22. def signal_handler(signal, frame)
: 定义一个信号处理函数,当接收到中断信号时,改变全局变量stop_stream的值。
23-47. def main()
: 定义主函数,处理用户输入和模型响应的交互,可以响应"stop"和"clear"命令,可以通过Ctrl+C来中断模型的响应。
48-50. if __name__ == "__main__": main()
: 如果该文件被直接运行(而不是作为模块导入),则调用main()函数。
def build_prompt(history):
这个函数接收一个历史对话的列表,然后将其格式化为一个提示文本,包括用户和ChatGLM-6B的所有对话内容。
def signal_handler(signal, frame):
这是一个信号处理函数,它会在接收到特定信号(如用户按下Ctrl+C)时被调用。函数的作用是将全局变量stop_stream
设为True,从而在主循环中用来停止模型的流式对话。
def main():
这是主函数,处理用户输入和模型响应的交互。它首先定义一个空的历史对话列表,然后进入一个无限循环,在循环中等待用户输入,并对用户输入做出响应。这里有几个关键的部分:
query = input("\n用户:")
: 等待用户输入。if query.strip() == "stop": break
: 如果用户输入"stop",则跳出循环,结束程序。if query.strip() == "clear":
如果用户输入"clear",则清空历史对话列表,清屏,并打印欢迎语句。for response, history in model.stream_chat(tokenizer, query, history=history):
使用模型进行流式聊天,模型将在每个步骤生成一个响应,并更新历史对话。这是一个阻塞操作,会等待模型生成响应。if stop_stream:
如果全局变量stop_stream
为True(表示收到了中断信号),则停止流式对话,并将stop_stream
设回False。if count % 8 == 0:
如果已经生成了8个响应,就清屏并打印历史对话,然后继续等待模型生成更多的响应。if __name__ == "__main__": main()
这是Python的常见模式,只有当脚本被直接运行时,才会执行main()
函数。如果脚本被作为模块导入,main()
函数就不会被执行。
tokenizer = AutoTokenizer.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True)
: 这一行从本地路径加载了一个预训练的tokenizer。这个tokenizer用于将原始文本输入转化为模型可以理解的形式。参数trust_remote_code=True
是说信任远程代码,通常与Hugging Face模型库中的自定义模型有关。
model = AutoModel.from_pretrained("../ChatGLM-Tuning-master/chatglm-6b", trust_remote_code=True).half().cuda()
: 这行代码从本地路径加载了一个预训练的模型,并且使用半精度(half-precision)和GPU(通过.cuda())来进行计算。半精度可以加速模型的计算速度,而牺牲一部分的精度。
for response, history in model.stream_chat(tokenizer, query, history=history):
这里开始了一个流式的聊天对话。模型根据给定的历史对话和用户的最新输入生成响应,生成一个响应后,就立即返回,然后继续生成下一个响应。在每个步骤中,都会返回新的响应和更新后的历史对话。
if count % 8 == 0:
这是一个简单的计数逻辑,每生成8个响应,就清屏并打印所有的历史对话。这样可以保证屏幕上不会有太多的文本。
signal.signal(signal.SIGINT, signal_handler)
: 这行代码设置了一个信号处理函数,当接收到中断信号(如用户按下Ctrl+C)时,会调用signal_handler
函数。这样用户可以通过按Ctrl+C来中断模型的流式对话。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。