
Deploying Llama3 (Chinese Version) on Alibaba Cloud
1. Download the code and the model (put the model under the code root directory)

  git clone https://github.com/CrazyBoyM/llama3-Chinese-chat --depth 1
  git lfs clone https://www.modelscope.cn/baicai003/Llama3-Chinese_v2.git
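If git lfs is not installed or the large-file pull is slow, the same weights can usually be fetched with the ModelScope Python SDK instead. A minimal sketch; the cache_dir value here is just an illustration, point it at the directory the commands above would clone into:

  # pip install modelscope
  from modelscope import snapshot_download

  # Downloads baicai003/Llama3-Chinese_v2 and returns the local weights path.
  model_dir = snapshot_download('baicai003/Llama3-Chinese_v2',
                                cache_dir='/mnt/workspace/llama3-Chinese-chat')
  print(model_dir)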
2. Environment setup

1. Create a conda environment

  conda create -n llama3 python=3.10
  conda activate llama3

2. Create a requirements.txt file with the following contents:

  torch
  fairscale
  fire
  tiktoken==0.4.0
  blobfile
  transformers
  peft

3. Install the dependencies:

  pip install -r requirements.txt
  pip install -U streamlit
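Before moving on, it is worth confirming that PyTorch can actually see the GPU. A quick sanity check, nothing here is specific to this repo:

  import torch
  import transformers

  print('torch:', torch.__version__, '| transformers:', transformers.__version__)
  print('CUDA available:', torch.cuda.is_available())
  if torch.cuda.is_available():
      # Total VRAM in GiB; compare against the 16G (fp16) / 8G (int4) figures below.
      props = torch.cuda.get_device_properties(0)
      print(props.name, round(props.total_memory / 1024**3, 1), 'GiB')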
3. Run

  # /mnt/workspace/llama3-Chinese-chat/Llama3-Chinese_v2 is an example path; replace it with your own
  streamlit run deploy/web_streamlit_for_instruct.py /mnt/workspace/llama3-Chinese-chat/Llama3-Chinese_v2 --theme.base="dark" --server.address=127.0.0.1

  • fp16 mode uses roughly 16 GB of VRAM; a 24 GB GPU is recommended
  • int4 mode uses roughly 8 GB of VRAM; at least 10 GB is recommended. To enable it, search the script for load_in_4bit and set load_in_4bit=True (see the sketch after this list)
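The fp16 figure is easy to sanity-check: 8B parameters × 2 bytes each ≈ 16 GB of weights before activations and KV cache, while nf4 quantization stores roughly 0.5 bytes per parameter, which is where the ~8 GB int4 total comes from. The web script loads the model the same way as the inference script in section 4, so the 4-bit toggle looks roughly like this (a sketch of the pattern, not the exact deploy-script code; the path is your model directory):

  import torch
  from transformers import AutoModelForCausalLM, BitsAndBytesConfig

  # nf4 4-bit quantization: ~0.5 bytes/param for weights instead of 2 bytes in fp16
  quantization_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_compute_dtype=torch.float16,
      bnb_4bit_use_double_quant=True,
      bnb_4bit_quant_type="nf4",
  )
  model = AutoModelForCausalLM.from_pretrained(
      '/mnt/workspace/llama3-Chinese-chat/Llama3-Chinese_v2',  # your model path
      quantization_config=quantization_config,
      device_map='auto',
  )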
4. Inference

By default, you can simply run the following script to chat with Llama3 in Chinese; change model_name_or_path to the path of the model you downloaded.

  from transformers import AutoTokenizer, AutoConfig, AddedToken, AutoModelForCausalLM, BitsAndBytesConfig
  from peft import PeftModel
  from dataclasses import dataclass
  from typing import Dict
  import torch
  import copy

  ## Define the chat template
  @dataclass
  class Template:
      template_name: str
      system_format: str
      user_format: str
      assistant_format: str
      system: str
      stop_word: str

  template_dict: Dict[str, Template] = dict()

  def register_template(template_name, system_format, user_format, assistant_format, system, stop_word=None):
      template_dict[template_name] = Template(
          template_name=template_name,
          system_format=system_format,
          user_format=user_format,
          assistant_format=assistant_format,
          system=system,
          stop_word=stop_word,
      )

  # This system prompt is the one used during training; feel free to experiment with it at inference time
  register_template(
      template_name='llama3',
      system_format='<|begin_of_text|><<SYS>>\n{content}\n<</SYS>>\n\n',
      user_format='<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>',
      assistant_format='<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|end_of_text|>\n',
      system="You are a helpful, excellent and smart assistant. "
             "Please respond to the user using the language they input, ensuring the language is elegant and fluent. "
             "If you don't know the answer to a question, please don't share false information.",
      stop_word='<|end_of_text|>'
  )

  ## Load the model
  def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
      if load_in_4bit:
          quantization_config = BitsAndBytesConfig(
              load_in_4bit=True,
              bnb_4bit_compute_dtype=torch.float16,
              bnb_4bit_use_double_quant=True,
              bnb_4bit_quant_type="nf4",
              llm_int8_threshold=6.0,
              llm_int8_has_fp16_weight=False,
          )
      else:
          quantization_config = None

      # Load the base model
      model = AutoModelForCausalLM.from_pretrained(
          model_name_or_path,
          load_in_4bit=load_in_4bit,
          trust_remote_code=True,
          low_cpu_mem_usage=True,
          torch_dtype=torch.float16,
          device_map='auto',
          quantization_config=quantization_config
      )
      # Load the LoRA adapter, if one is given
      if adapter_name_or_path is not None:
          model = PeftModel.from_pretrained(model, adapter_name_or_path)
      return model

  ## Load the tokenizer
  def load_tokenizer(model_name_or_path):
      tokenizer = AutoTokenizer.from_pretrained(
          model_name_or_path,
          trust_remote_code=True,
          use_fast=False
      )
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token
      return tokenizer

  ## Build the prompt
  def build_prompt(tokenizer, template, query, history, system=None):
      template_name = template.template_name
      system_format = template.system_format
      user_format = template.user_format
      assistant_format = template.assistant_format
      system = system if system is not None else template.system

      history.append({"role": 'user', 'message': query})
      input_ids = []

      # Add the system message
      if system_format is not None:
          if system is not None:
              system_text = system_format.format(content=system)
              input_ids = tokenizer.encode(system_text, add_special_tokens=False)
      # Concatenate the conversation history
      for item in history:
          role, message = item['role'], item['message']
          if role == 'user':
              message = user_format.format(content=message, stop_token=tokenizer.eos_token)
          else:
              message = assistant_format.format(content=message, stop_token=tokenizer.eos_token)
          tokens = tokenizer.encode(message, add_special_tokens=False)
          input_ids += tokens
      input_ids = torch.tensor([input_ids], dtype=torch.long)
      return input_ids

  def main():
      model_name_or_path = 'shareAI/llama3-Chinese-chat-8b'  # model name or path; change this
      template_name = 'llama3'
      adapter_name_or_path = None
      template = template_dict[template_name]
      # 4-bit inference saves a lot of GPU memory, but quality may drop
      load_in_4bit = False

      # Generation hyperparameters; tune these for better results
      max_new_tokens = 500      # maximum number of tokens generated per reply
      top_p = 0.9
      temperature = 0.6         # higher = more creative, lower = more conservative
      repetition_penalty = 1.1  # higher values reduce repeated text

      # Load the model
      print(f'Loading model from: {model_name_or_path}')
      print(f'adapter_name_or_path: {adapter_name_or_path}')
      model = load_model(
          model_name_or_path,
          load_in_4bit=load_in_4bit,
          adapter_name_or_path=adapter_name_or_path
      ).eval()
      tokenizer = load_tokenizer(model_name_or_path if adapter_name_or_path is None else adapter_name_or_path)
      if template.stop_word is None:
          template.stop_word = tokenizer.eos_token
      stop_token_id = tokenizer.encode(template.stop_word, add_special_tokens=True)
      assert len(stop_token_id) == 1
      stop_token_id = stop_token_id[0]

      history = []
      query = input('# User:')
      while True:
          query = query.strip()
          input_ids = build_prompt(tokenizer, template, query, copy.deepcopy(history), system=None).to(model.device)
          outputs = model.generate(
              input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
              top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
              eos_token_id=stop_token_id
          )
          outputs = outputs.tolist()[0][len(input_ids[0]):]
          response = tokenizer.decode(outputs)
          response = response.strip().replace(template.stop_word, "").strip()

          # Save this turn into the conversation history
          history.append({"role": 'user', 'message': query})
          history.append({"role": 'assistant', 'message': response})
          # Keep only the most recent 6 turns (12 messages); adjust as you like
          if len(history) > 12:
              history = history[-12:]
          print("# Llama3-Chinese:{}".format(response))
          query = input('# User:')

  if __name__ == '__main__':
      main()
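The loop above only prints the reply after generation has fully finished. If you prefer token-by-token output, transformers ships a TextStreamer that can be passed to generate(). A small sketch reusing the model, tokenizer, and sampling parameters from the script above:

  from transformers import TextStreamer

  # skip_prompt=True avoids echoing the input prompt back to the console
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
  model.generate(
      input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
      top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
      eos_token_id=stop_token_id, streamer=streamer
  )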
