
Measuring a Language Model's Generation Speed in Tokens per Second (it/s)


In the streaming loop of the main() function, we can count how many tokens are generated per second and report it as it/s. Python's time module is enough to measure elapsed wall-clock time during streaming generation. Keep in mind that the measured speed depends on several factors, including device performance, model size, and input length, so numbers are only comparable under the same conditions.
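The measurement pattern itself is model-independent: record a start time, consume the stream, tokenize the final text once, and divide the token count by the elapsed time. Below is a minimal sketch of just that pattern; fake_stream and the whitespace-based token count are stand-ins for a real model stream and a real tokenizer.

import time

def fake_stream():
    # Stand-in for model.chat(..., stream=True):
    # yields the cumulative text generated so far.
    text = ""
    for word in "hello world this is a streaming demo".split():
        text += word + " "
        time.sleep(0.1)  # simulate per-chunk decoding latency
        yield text

start = time.time()
final_text = ""
for final_text in fake_stream():
    pass  # in a real loop we would print the newly added suffix here

elapsed = time.time() - start
# With a real tokenizer: len(tokenizer(final_text)["input_ids"])
num_tokens = len(final_text.split())
print(f"{num_tokens / elapsed:.2f} tokens/s")

Applied to the Baichuan chat demo, the full program looks like this: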

import os
import time
import platform

import torch
from colorama import Fore, Style
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


def init_model():
    print("init model ...")
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.generation_config = GenerationConfig.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer


def clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")
    print(Fore.YELLOW + Style.BRIGHT
          + "Welcome to the Baichuan model. Type to chat, 'clear' to clear "
            "history, Ctrl+C to interrupt generation, 'stream' to toggle "
            "streaming, 'exit' to quit.")
    return []


def main(stream=True):
    model, tokenizer = init_model()
    messages = clear_screen()
    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\nUser: " + Style.NORMAL)
        if prompt.strip() == "exit":
            break
        if prompt.strip() == "clear":
            messages = clear_screen()
            continue
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan: " + Style.NORMAL, end='')
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "(streaming {})\n".format(
                "on" if stream else "off"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            position = 0
            response = ""  # keep defined even if interrupted before the first chunk
            try:
                start_time = time.time()
                for response in model.chat(tokenizer, messages, stream=True):
                    # Each chunk is the cumulative response; print only the new suffix.
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    if torch.backends.mps.is_available():
                        torch.mps.empty_cache()
                elapsed_time = time.time() - start_time
                # Tokenize the final response once. Re-tokenizing the cumulative
                # text on every chunk would count most tokens many times over.
                total_tokens = len(
                    tokenizer(response, return_tensors='pt')['input_ids'][0]
                )
                tokens_per_second = total_tokens / elapsed_time
                print(f"\n\nGeneration speed: {tokens_per_second:.2f} tokens/s")
            except KeyboardInterrupt:
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(Style.RESET_ALL)


if __name__ == "__main__":
    main()
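For models that do not expose a chat(..., stream=True) helper like Baichuan's, the same measurement can be built on transformers' generic TextIteratorStreamer. The sketch below is an illustration rather than part of the Baichuan demo; "gpt2" is just a small placeholder model, and any causal LM works the same way.

import time
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until done, so run it in a thread
# and consume the streamer on the main thread.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
start = time.time()
thread.start()

text = ""
for chunk in streamer:  # yields decoded text chunks as they are produced
    print(chunk, end="", flush=True)
    text += chunk
thread.join()

elapsed = time.time() - start
num_tokens = len(tokenizer(text)["input_ids"])
print(f"\n{num_tokens / elapsed:.2f} tokens/s")

Unlike the Baichuan loop above, TextIteratorStreamer yields incremental chunks rather than the cumulative response, so the chunks are concatenated before the final token count.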
