
Deploying Qwen with vLLM for Accelerated Inference


Contents

Step 1: Write vllm_wrapper.py

Step 2: Application: generating health advice for physical-exam indicators


Step 1: Write vllm_wrapper.py

The wrapper below reproduces Qwen's ChatML prompt construction (make_context, get_stop_words_ids) and exposes a chat() method compatible with the Hugging Face model.chat() interface, while running generation on vLLM's LLM engine.
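The wrapper imports vllm, transformers, torch, and packaging, and it only enables repetition_penalty when vLLM >= 0.2.2 is installed; the Qwen remote-code tokenizer additionally needs tiktoken. One way to install compatible dependencies (exact versions depend on your CUDA/PyTorch environment):

    pip install "vllm>=0.2.2" transformers tiktoken packaging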

from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from typing import Optional, Callable, List, Tuple, Union
import copy
import torch
from transformers import AutoTokenizer
from transformers.generation.logits_process import LogitsProcessorList
from packaging import version

_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""

IMEND = "<|im_end|>"
ENDOFTEXT = "<|endoftext|>"

HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]


def get_stop_words_ids(chat_format, tokenizer):
    if chat_format == "raw":
        stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
    elif chat_format == "chatml":
        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
    return stop_words_ids


def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    if history is None:
        history = []

    if chat_format == "chatml":
        im_start, im_end = "<|im_start|>", "<|im_end|>"
        im_start_tokens = [tokenizer.im_start_id]
        im_end_tokens = [tokenizer.im_end_id]
        nl_tokens = tokenizer.encode("\n")

        def _tokenize_str(role, content):
            return f"{role}\n{content}", tokenizer.encode(
                role, allowed_special=set()
            ) + nl_tokens + tokenizer.encode(content, allowed_special=set())

        system_text, system_tokens_part = _tokenize_str("system", system)
        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

        raw_text = ""
        context_tokens = []

        for turn_query, turn_response in reversed(history):
            query_text, query_tokens_part = _tokenize_str("user", turn_query)
            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
            response_text, response_tokens_part = _tokenize_str(
                "assistant", turn_response
            )
            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
            prev_chat = (
                f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
            )

            current_context_size = (
                len(system_tokens) + len(next_context_tokens) + len(context_tokens)
            )
            if current_context_size < max_window_size:
                context_tokens = next_context_tokens + context_tokens
                raw_text = prev_chat + raw_text
            else:
                break

        context_tokens = system_tokens + context_tokens
        raw_text = f"{im_start}{system_text}{im_end}" + raw_text
        context_tokens += (
            nl_tokens
            + im_start_tokens
            + _tokenize_str("user", query)[1]
            + im_end_tokens
            + nl_tokens
            + im_start_tokens
            + tokenizer.encode("assistant")
            + nl_tokens
        )
        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    elif chat_format == "raw":
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    return raw_text, context_tokens
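
# For reference: with system="You are a helpful assistant." and a single-turn query
# such as "你好", make_context() returns raw_text in ChatML form:
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   你好<|im_end|>
#   <|im_start|>assistant
#
# and context_tokens is the matching token-id sequence, which is what vLLMWrapper.chat()
# below feeds to vLLM as prompt_token_ids.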
class vLLMWrapper:
    def __init__(self,
                 model_dir: str,
                 trust_remote_code: bool = True,
                 tensor_parallel_size: int = 1,
                 gpu_memory_utilization: float = 0.98,
                 dtype: str = "bfloat16",
                 **kwargs):

        if dtype not in ("bfloat16", "float16", "float32"):
            raise ValueError("dtype {} is not supported; use bfloat16, float16 or float32".format(dtype))

        # build generation_config
        self.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)

        # build tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.tokenizer.eos_token_id = self.generation_config.eos_token_id

        self.stop_words_ids = []

        from vllm import LLM
        import vllm
        if version.parse(vllm.__version__) >= version.parse("0.2.2"):
            self.__vllm_support_repetition_penalty = True
        else:
            self.__vllm_support_repetition_penalty = False

        # kwargs is a plain dict, so use .get(); getattr() on a dict would always return None
        quantization = kwargs.get('quantization', None)

        self.model = LLM(model=model_dir,
                         tokenizer=model_dir,
                         tensor_parallel_size=tensor_parallel_size,
                         trust_remote_code=trust_remote_code,
                         quantization=quantization,
                         gpu_memory_utilization=gpu_memory_utilization,
                         dtype=dtype)

        for stop_id in get_stop_words_ids(self.generation_config.chat_format, self.tokenizer):
            self.stop_words_ids.extend(stop_id)
        self.stop_words_ids.extend([self.generation_config.eos_token_id])
    def chat(self,
             query: str,
             history: Optional[HistoryType],
             tokenizer: PreTrainedTokenizer = None,
             system: str = "You are a helpful assistant.",
             generation_config: Optional[GenerationConfig] = None,
             **kwargs):
        generation_config = generation_config if generation_config is not None else self.generation_config
        tokenizer = self.tokenizer if tokenizer is None else tokenizer

        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT

        if not self.__vllm_support_repetition_penalty and generation_config.repetition_penalty != 1:
            raise RuntimeError("The installed vLLM doesn't support repetition_penalty, please set ``model.generation_config.repetition_penalty = 1`` or install vllm>=0.2.2")

        if history is None:
            history = []
        else:
            # make a copy of the user's input so that it is left untouched
            history = copy.deepcopy(history)

        extra_stop_words_ids = kwargs.get('stop_words_ids', None)
        if extra_stop_words_ids is None:
            extra_stop_words_ids = []

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size

        from vllm.sampling_params import SamplingParams
        sampling_kwargs = {
            "stop_token_ids": self.stop_words_ids,
            "early_stopping": False,
            "top_p": generation_config.top_p,
            "top_k": -1 if generation_config.top_k == 0 else generation_config.top_k,
            "temperature": generation_config.temperature,
            "max_tokens": generation_config.max_new_tokens,
            "repetition_penalty": generation_config.repetition_penalty
        }
        if not self.__vllm_support_repetition_penalty:
            sampling_kwargs.pop("repetition_penalty")
        sampling_params = SamplingParams(**sampling_kwargs)

        raw_text, context_tokens = make_context(
            self.tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        req_outputs = self.model.generate([query],
                                          sampling_params=sampling_params,
                                          prompt_token_ids=[context_tokens])
        req_output = req_outputs[0]

        prompt_str = req_output.prompt
        prompt_ids = req_output.prompt_token_ids
        req_sample_output_ids = []
        req_sample_output_strs = []
        for sample in req_output.outputs:
            output_str = sample.text
            output_ids = sample.token_ids
            # strip the trailing stop markers if they were generated
            if IMEND in output_str:
                output_str = output_str[:-len(IMEND)]
            if ENDOFTEXT in output_str:
                output_str = output_str[:-len(ENDOFTEXT)]
            req_sample_output_ids.append(prompt_ids + output_ids)
            req_sample_output_strs.append(prompt_str + output_str)
        assert len(req_sample_output_strs) == 1
        response = req_sample_output_strs[0][len(prompt_str):]
        history.append((prompt_str, response))

        return response, history
if __name__ == '__main__':
    model_dir = 'Qwen/Qwen-72B-Chat'
    tensor_parallel_size = 2

    model = vLLMWrapper(model_dir,
                        tensor_parallel_size=tensor_parallel_size,
                        )

    response, history = model.chat(query="你好",
                                   history=None)
    print(response)
    response, history = model.chat(query="给我讲一个年轻人奋斗创业最终取得成功的故事。",
                                   history=history)
    print(response)
    response, history = model.chat(query="给这个故事起一个标题",
                                   history=history)
    print(response)
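In this demo, tensor_parallel_size=2 shards Qwen-72B-Chat across two GPUs, and the wrapper keeps everything in-process, which is convenient for scripts like the one in step 2. If you would rather serve the model over HTTP, vLLM also ships an OpenAI-compatible API server; a minimal launch command (adjust the model path, GPU count and flags to your vLLM version) looks roughly like:

    python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen-72B-Chat --trust-remote-code --tensor-parallel-size 2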

Step 2: Application: generating health advice for physical-exam indicators

The script below reads checkup-indicator metadata from a CSV file, builds natural-language descriptions of abnormal indicators, fills them into a prompt template, and asks the vLLM-served Qwen model to generate improvement advice; each (description, advice) pair is saved as a conversation record for later fine-tuning.

import pandas as pd
import uuid
import json
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from vllm_wrapper import vLLMWrapper
from tqdm import tqdm

sft_path = "/mnt/sdd/Qwen-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(sft_path, trust_remote_code=True)
model = vLLMWrapper(sft_path, tensor_parallel_size=1)
# model = AutoModelForCausalLM.from_pretrained(sft_path, device_map="auto", trust_remote_code=True).eval()


def data_load():
    zb = pd.read_csv('/home/wangyp/Big_Model/infectious_disease/data/zb.csv', header=0)
    # print(zb.head(10))
    nums = []
    yins = {}
    nulls = []
    for item in zb.itertuples(index=False):
        # indicators whose normal result is described textually (used later to build "阳性" descriptions)
        if item.normal_desc is not None and item.normal_desc.strip() != '':
            yins[item.NAME.strip()] = item.normal_desc.strip()
        # indicators with a numeric reference range; everything else goes to nulls
        if ((item.range_floor is not None and str(item.range_floor).strip() != '')
                and (item.range_ceil is not None and str(item.range_ceil).strip() != '')):
            nums.append({item.NAME.strip(): str(item.range_floor).strip() + "-" + str(item.range_ceil).strip()})
        else:
            nulls.append(item.NAME.strip())
    return yins, nums, nulls
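
# Note on the input data: zb.csv is not shown in the original post; from the code it is
# assumed to contain at least the columns NAME, normal_desc, range_floor and range_ceil,
# for example (illustrative rows only):
#
#   NAME,normal_desc,range_floor,range_ceil
#   血红蛋白,,115,150
#   乙肝表面抗原,阴性,,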
def zao_yins(yins):
    res = []
    for key in yins.keys():
        res.append(key + "这种体检指标阳性。")
        res.append(key + "的检测结果为阳性。")
        res.append(key + "的检测显示出阳性反应。")
        res.append(key + "检测结果为阳性。")
        res.append(key + "的阳性结果在体检中被检测到。")
        res.append(key + "试验为阳性状态")
        res.append(key + "的阳性结果在体检中得到了确认。")
        res.append(key + "的检测结果表明为阳性。")
        res.append("在进行检测" + key + "指标时,结果被判定为阳性。")
    return res


def zao_nums(nums):
    res = []
    keys_list = [key for d in nums for key in d.keys()]
    for key in keys_list:
        res.append({"name": key, "value": key + "检测结果显示异常。"})
        # res.append({"name": key, "value": key + "的检查值超出了正常范围。"})
        res.append({"name": key, "value": key + "的测量值与标准值不符。"})
        # res.append({"name": key, "value": key + "检测结果呈现异常状态。"})
        res.append({"name": key, "value": key + "的数值在体检中被标记为异常。"})
        # res.append({"name": key, "value": key + "检查结果显示了不正常的数值。"})
        res.append({"name": key, "value": key + "的检测结果不在正常参考值内。"})
        # res.append({"name": key, "value": key + "检查报告提示数值异常。"})
        # res.append({"name": key, "value": "体检报告指出" + key + "水平不正常。"})
        res.append({"name": key, "value": "体检中发现" + key + "水平异常。"})
        # res.append(key + "检测结果显示异常。")
    return res
# Variants drafted from the prompt "体检中发现尿酮水平异常,帮我生成10条描述,保持句子意思不变"
# (i.e. "an abnormal urine ketone level was found in the physical exam; generate 10
# descriptions that keep the meaning unchanged").
def z_nulls(nulls):
    res = []
    for key in nulls:
        res.append("体检结果显示" + key + "水平出现异常。")
        res.append("在进行体检时,发现" + key + "的数值不在正常范围内。")
        res.append("体检报告中指出" + key + "水平有异常情况。")
        res.append("体检时," + key + "水平的测定结果超出了预期的正常值。")
        res.append("体检中测得的" + key + "水平与正常值有所偏差。")
        res.append("体检数据中," + key + "的数值检测出异常。")
        res.append(key + "的检测结果表明存在异常。")
        res.append(key + "的检测值在体检中被标记为异常。")
        res.append(key + "水平的体检结果提示有异常。")
    return res


# A string template with named slots; yins_template is only a placeholder and is not
# used in the __main__ block below.
# yins_template = "Hello, {name}! You are {age} years old."
yins_template = """hhh******************"""

# Prompt template for numeric indicators: {disc} is the abnormal-indicator description
# and {name} is the indicator name.
nums_template = """
{disc}你是一名体检报告领域的专家,请生成一段关于该体检指标异常的改善建议。\n下面是生成体检指标相关的建议时的要求:健康建议严格包含如下几项:复检确认、营养评估、医疗咨询、健康饮食、生活方式调整、药物治疗、定期监测、记录症状这几项。生成建议请参考以下格式:\n体检结果提示您的{name}不在正常参考值内,这可能与多种因素有关,包括营养不良、维生素缺乏或某些疾病状态。以下是一些建议:\n复检确认:{name}相关的复检建议。\n营养评估:考虑针对{name}进行一次全面的营养评估。\n医疗咨询:咨询医生,以确定是否需要进一步的检查和{name}相关的其他检测。如血红蛋白电泳、血清铁蛋白、维生素B12和叶酸水平检测。\n健康饮食:饮食建议,这些食物富含补充{name}必要的营养素。\n生活方式调整:保持适度的体育活动,避免饮酒和吸烟,这些都可能影响{name}的健康。\n药物治疗:如果医生建议,可能需要服用补充剂或药物来纠正{name}异常。\n定期监测:根据医生的建议,定期监测{name}和其他{name}相关指标。\n记录症状:注意任何可能与{name}相关的症状,如疲劳、头晕或呼吸困难,并及时告知医生。\n请记住,{name}的异常可能是多种情况的指标,因此重要的是遵循医疗专业人员的指导进行进一步的评估和治疗。\n
"""

# filled_nums_template = nums_template.format(name=name)
def load_model():
    pass
if __name__ == '__main__':
    all = []
    yins, nums, nulls = data_load()
    # for each indicator, build several description variants, collect them in lists, write to file
    yins_tem = zao_yins(yins)
    nums_tem = zao_nums(nums)
    nulls_tem = z_nulls(nulls)
    # all = yins_tem + nums_tem + nulls_tem
    # print(len(all))

    nums_conversations = []
    for num in tqdm(nums_tem):
        filled_nums_template = nums_template.format(disc=num["value"], name=num["name"])
        response, history = model.chat(filled_nums_template, history=None)
        nums_conversations.append({"id": str(uuid.uuid4()),
                                   "conversations": [{"from": "user", "value": num["value"]},
                                                     {"from": "assistant", "value": response}]})
    with open("/home/wangyp/Big_Model/infectious_disease/data/zb_train.json", "w", encoding="utf-8") as f:
        # write the conversation records as valid JSON
        json.dump(nums_conversations, f, ensure_ascii=False, indent=2)
    print("nums_conversations数据处理完毕。")
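
The loop above sends one prompt per model.chat() call. Since vLLM batches requests internally, throughput is usually better when all prompts are submitted to the engine in one call. A minimal sketch of that variant is below; it reuses make_context and the stop-word ids from vllm_wrapper.py, assumes a vLLM version that (like the wrapper above) accepts prompt_token_ids in LLM.generate(), and uses illustrative sampling values.

from vllm.sampling_params import SamplingParams
from vllm_wrapper import make_context

prompts = [nums_template.format(disc=n["value"], name=n["name"]) for n in nums_tem]
# Build the ChatML token ids for every prompt up front.
batch_token_ids = [
    make_context(model.tokenizer, p, history=[],
                 system="You are a helpful assistant.", chat_format="chatml")[1]
    for p in prompts
]
params = SamplingParams(temperature=0.7, top_p=0.8, max_tokens=512,
                        stop_token_ids=model.stop_words_ids)
# model.model is the underlying vllm.LLM engine created in vLLMWrapper.__init__.
req_outputs = model.model.generate(prompt_token_ids=batch_token_ids,
                                   sampling_params=params)
for num, out in zip(nums_tem, req_outputs):
    print(num["name"], "->", out.outputs[0].text[:80])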
