当前位置:   article > 正文

[AI]如何让语言模型LLMs流式输出:HuggingFace Transformers实现_textiteratorstreamer

textiteratorstreamer

HuggingFace Transformers是一个非常方便的库,集成了非常多SOTA的模型,包括LLaMA、GPT、ChatGLM、MOSS等。目前主流的方案基本上都是基于HuggingFace Transformers这个框架实现的。以前如果要实现流式输出,需要自己去修改模型底层的推理逻辑。

例如ChatGLM就自己实现了流式输出,代码如下:

  # chatglm-6b: modeling_chatglm.py (excerpt; these functions take `self`, i.e. they are model methods)
@torch.no_grad()
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
                do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
    """Stream a chat reply for *query*, yielding after every decoding step.

    Builds the prompt (the bare query when *history* is empty, otherwise one
    "[Round i]" block per past turn followed by the new query), tokenizes it onto
    ``self.device`` and iterates ``self.stream_generate``.

    Args:
        tokenizer: Tokenizer used both to encode the prompt and decode the reply.
        query: The new user message.
        history: Past ``(query, response)`` pairs; not modified.
        max_length, do_sample, top_p, temperature, **kwargs: Forwarded to
            ``stream_generate`` as generation kwargs.
        logits_processor: Optional processors; a copy is extended with
            ``InvalidScoreLogitsProcessor`` so the caller's list is left untouched.

    Yields:
        (response, new_history): ``response`` is the decoded text of all tokens
        generated so far (after ``self.process_response``); ``new_history`` is
        ``history + [(query, response)]``.
    """
    if history is None:
        history = []
    # Copy instead of appending to the caller's list: otherwise every call would add
    # another InvalidScoreLogitsProcessor to a list the caller may reuse.
    logits_processor = LogitsProcessorList(logits_processor or [])
    logits_processor.append(InvalidScoreLogitsProcessor())
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    # NOTE(review): the text of this page appears to have had full-width punctuation
    # normalised to ASCII; upstream ChatGLM-6B may use full-width colons in this
    # template -- verify against modeling_chatglm.py before relying on it.
    if not history:
        prompt = query
    else:
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
        prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
    inputs = tokenizer([prompt], return_tensors="pt")
    inputs = inputs.to(self.device)
    prompt_len = len(inputs["input_ids"][0])
    for outputs in self.stream_generate(**inputs, **gen_kwargs):
        # Keep only the newly generated tokens of the first (only) sequence.
        outputs = outputs.tolist()[0][prompt_len:]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        new_history = history + [(query, response)]
        yield response, new_history
  # Streaming counterpart of generate(): yields the growing input_ids after each decoding step.
@torch.no_grad()
def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        **kwargs,
):
    """Generate tokens one step at a time, yielding ``input_ids`` after each step.

    Args:
        input_ids: Prompt token ids; ``shape[-1]`` is taken as the prompt length and
            ``shape[0]`` as the number of sequences.
        generation_config: Base settings, defaulting to ``self.generation_config``;
            it is deep-copied and then updated with ``kwargs``.
        logits_processor: Extra processors merged in by ``self._get_logits_processor``.
        stopping_criteria: Extra criteria merged in by ``self._get_stopping_criteria``.
        prefix_allowed_tokens_fn: Forwarded to ``self._get_logits_processor``.
        **kwargs: Generation-config overrides; keys that ``generation_config.update``
            does not consume are passed to the model as ``model_kwargs``.

    Yields:
        ``input_ids`` with every token generated so far concatenated on the last dim.
        The stop check runs before ``yield``, so the tokens of the step that ends
        generation (EOS or the max-length step) are not yielded.
    """
    input_ids_seq_length = input_ids.shape[-1]

    if generation_config is None:
        generation_config = self.generation_config
    generation_config = copy.deepcopy(generation_config)
    # update() returns the kwargs it did not use; those are model inputs.
    model_kwargs = generation_config.update(**kwargs)
    eos_token_id = generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            # Logger methods take %-style args; passing UserWarning positionally (as the
            # upstream code did) makes the logging call fail to format the message.
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = self._get_logits_warper(generation_config)

    # One flag per sequence: 1 while it has not produced an EOS token yet.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    # Scores are never collected here; stopping criteria receive None for them.
    scores = None
    while True:
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample (or greedy argmax when do_sample is off)
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step.
        # NOTE: rows that already finished keep being extended (nothing pads/masks them);
        # the loop only ends when every row is finished or a stopping criterion fires.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # A row is still unfinished only if its new token differs from *every* EOS id.
        # (Summing the per-id "!=" masks, as upstream did, stays non-zero whenever there
        # are two or more EOS ids, so such rows would never be marked finished.)
        not_eos = torch.stack([next_tokens != i for i in eos_token_id]).all(dim=0)
        unfinished_sequences = unfinished_sequences.mul(not_eos.long())

        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
        yield input_ids

HuggingFace Transformers实现

Hugging Face也注意到了这个需求,在v4.30.1中加入了两个流式输出的接口:

  • TextStreamer: 能够在stdout中流式输出结果
  • TextIteratorStreamer:能够在自定义loop中进行操作

详细介绍如下

TextStreamer

Text generation strategies: https://huggingface.co/docs/transformers/main/generation_strategies

The generate() method supports streaming through its streamer input. The streamer input is compatible with any instance of a class that has the following methods: put() and end(). Internally, put() is used to push new tokens and end() is used to flag the end of text generation.

The API for the streamer classes is still under development and may change in the future.

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes ready for you to use. For example, you can use the TextStreamer class to stream the output of generate() into your screen, one word at a time:

  # Example: print generate() output to stdout as it is produced, using TextStreamer.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextStreamer(tok)

# Despite returning the usual output, the streamer will also print the generated text to stdout.
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)

 TextIteratorStreamer

Utilities for Generation: https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.TextIteratorStreamer

Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive Gradio demo).

  # Example: consume generated text chunk by chunk with TextIteratorStreamer.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
# By default the streamer also emits the prompt; pass skip_prompt=True to stream only new text.
streamer = TextIteratorStreamer(tok)

# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    # Show each chunk as soon as it arrives (a bare `generated_text` expression
    # would only display something inside a notebook/REPL).
    print(new_text, end="", flush=True)
thread.join()  # make sure generate() has fully returned before moving on
print()

ChatGLM流式回复Demo 

以下是使用ChatGLM-6B加上TextIteratorStreamer和TextStreamer的一个简单的CLI demo:

  # ChatGLM-6B CLI demo: stream replies with TextIteratorStreamer (TextStreamer variant sketched at the end).
import os
import platform
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, AutoModel
from transformers import TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# Console-clear command for this OS (the original referenced an undefined `clear_command`).
clear_command = "cls" if platform.system() == "Windows" else "clear"


# Build the conversation text to display.
def build_prompt(history):
    """Return the welcome banner followed by every (query, response) pair in *history*."""
    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


# Maintain the multi-turn history shown on screen.
def build_history(history, query, response, index):
    """Overwrite entry *index* of *history* with ``[query, response]`` and return *history*."""
    history[index] = [query, response]
    return history


if __name__ == "__main__":
    # TextIteratorStreamer implementation.
    # skip_prompt=True keeps the prompt tokens out of the stream; without it the
    # "ChatGLM-6B:" line would start with an echo of the user's own input.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    history = []
    turn_count = 0
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            turn_count = 0
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        history.append([query, ""])
        # NOTE: only the current query is sent to the model; `history` is used for
        # display only, so earlier turns do not influence the answer.
        inputs = tokenizer([query], return_tensors="pt").to('cuda')
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        count = 0
        # Streamed output: redraw the whole transcript every 8 chunks.
        for new_text in streamer:
            generated_text += new_text
            history = build_history(history, query, generated_text, turn_count)
            count += 1
            if count % 8 == 0:
                os.system(clear_command)
                print(build_prompt(history), flush=True)
        thread.join()  # generate() has finished; safe to start the next turn
        os.system(clear_command)
        print(build_prompt(history), flush=True)
        turn_count += 1
    # TextStreamer implementation
    # streamer = TextStreamer(tokenizer)
    # _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/679104
推荐阅读
相关标签
  

闽ICP备14008679号