赞
踩
提示:不废话直接干货
驱动:MLU370,建议使用 5.10.22 版本驱动
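镜像:(原文未给出镜像名称)所用镜像需自带 Cambricon PyTorch 环境,且包含下文用到的 /torch/src/catch/tools/torch_gpu2mlu/torch_gpu2mlu.py 转换工具。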
# transformers v4.33.0: clone, convert with the torch_gpu2mlu tool, then
# install the converted copy (written to transformers_mlu/) in editable mode.
git clone -b v4.33.0 https://githubfast.com/huggingface/transformers.git
python /torch/src/catch/tools/torch_gpu2mlu/torch_gpu2mlu.py -i transformers/
pip install -e ./transformers_mlu/

# accelerate v0.22.0: same procedure, converted copy lands in accelerate_mlu/.
git clone -b v0.22.0 https://githubfast.com/huggingface/accelerate.git
python /torch/src/catch/tools/torch_gpu2mlu/torch_gpu2mlu.py -i accelerate/
pip install -e ./accelerate_mlu/

# DeepSpeed is downloaded as a prebuilt MLU wheel (deepspeed_mlu) from the
# Cambricon SDK site and installed directly; no torch_gpu2mlu conversion is
# needed (re-running it on accelerate/ here would only repeat the step above).
wget https://sdk.cambricon.com/static/Basis/MLU370_X86_ubuntu20.04/deepspeed_mlu-0.9.0-py3-none-any.whl
pip install deepspeed_mlu-0.9.0-py3-none-any.whl

# modelscope is used to download the model; sentencepiece is presumably
# required by the ChatGLM3 tokenizer.
pip install modelscope sentencepiece
from modelscope import snapshot_download

# Fetch the ChatGLM3-6B checkpoint from ModelScope, pinned to tag v1.0.0.
# model_dir receives the local path of the downloaded snapshot.
model_dir = snapshot_download(
    "ZhipuAI/chatglm3-6b",
    revision="v1.0.0",
)
提示:下载好的模型记得 mv 到存储卷中,因为接下来需要修改模型目录中的部分源码。
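1. 将模型源码(应为模型目录下的 modeling_chatglm.py,请自行核对)的第 32-36 行注释掉;
2. 在第 279 行的 attention_mask.tril_() 前后各加入一行(下面以 ++ 标出的为新增行,++ 本身不要输入):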
# Patch fragment for the model's source around line 279: lines marked "++"
# are to be ADDED (the marker itself is not typed) immediately before and
# after the existing attention_mask.tril_() call. The mask is moved to the
# CPU for the in-place tril_() and then back to the MLU with .mlu().
# NOTE(review): presumably tril_() is unsupported or unreliable on the MLU
# backend in this toolchain — confirm before relying on this workaround.
++attention_mask = attention_mask.cpu()
attention_mask.tril_()
++attention_mask = attention_mask.mlu()
提示:把下面脚本中的模型路径改成自己的(也可以通过环境变量 MODEL_PATH 指定)。
"""Interactive command-line chat demo for ChatGLM3-6B.

Type a message to chat, ``clear`` to reset the conversation and ``stop`` to
quit. Ctrl+C while a reply is streaming aborts only that reply; Ctrl+C or
Ctrl+D at the input prompt exits the program.
"""
import os
import platform

from transformers import AutoModel, AutoTokenizer

# Model/tokenizer locations. Prefer setting the MODEL_PATH / TOKENIZER_PATH
# environment variables over editing the hard-coded default below.
MODEL_PATH = os.environ.get('MODEL_PATH', '/workspace/volume/gpt/zhouguojun/GLM3/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

# trust_remote_code=True lets transformers run the modeling code shipped in
# the model directory; device_map="auto" lets accelerate place the weights.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'

# External hook: setting this to True (e.g. from another thread) stops the
# reply currently being streamed; main() resets it once honoured.
stop_stream = False

welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"


def build_prompt(history):
    """Render the welcome banner followed by every exchange in *history*.

    Not called by main() in this script.

    Args:
        history: iterable of ``(query, response)`` pairs.

    Returns:
        The banner plus one "user" and one "model" paragraph per pair.
    """
    parts = [welcome_prompt]
    for query, response in history:
        parts.append(f"\n\n用户:{query}")
        parts.append(f"\n\nChatGLM3-6B:{response}")
    return "".join(parts)


def main():
    """Run the interactive chat loop until the user types ``stop`` or exits."""
    global stop_stream
    past_key_values, history = None, []
    print(welcome_prompt)
    while True:
        try:
            query = input("\n用户:")
        except (EOFError, KeyboardInterrupt):
            # Ctrl+D / closed stdin / Ctrl+C at the prompt: leave cleanly
            # instead of dying with a traceback.
            break
        command = query.strip()
        if command == "stop":
            break
        if command == "clear":
            past_key_values, history = None, []
            os.system(clear_command)
            print(welcome_prompt)
            continue

        print("\nChatGLM:", end="")
        current_length = 0
        try:
            # Each step yields the full reply so far plus the updated history
            # and KV cache, which are rebound here and carried to the next turn.
            for response, history, past_key_values in model.stream_chat(
                tokenizer,
                query,
                history=history,
                top_p=1,
                temperature=0.01,
                past_key_values=past_key_values,
                return_past_key_values=True,
            ):
                if stop_stream:
                    stop_stream = False
                    break
                # Print only the newly generated suffix of the growing reply.
                print(response[current_length:], end="", flush=True)
                current_length = len(response)
        except KeyboardInterrupt:
            # Nothing in this script sets stop_stream, so Ctrl+C is the way to
            # abort a reply. history/past_key_values keep the values from the
            # last streamed chunk, and the chat continues.
            pass
        print()


if __name__ == "__main__":
    main()
可以用 cnmon 查看显存占用。下期见,拜拜!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。