当前位置:   article > 正文

本地运行Gemma的pytorch集成

运行gemma

Gemma是Google在2024年2月21日发布的一款轻量的开源大模型,采用了和Google Gemini模型一样的技术。有猜测Google在毫无预告的情况下急忙发布Gemma是对Meta的Llama3的截胡,但不管怎么说作为名厂名牌的大模型,自然要上手尝试尝试。

这次发布的Gemma有2B参数和7B参数两个版本,两个版本又分别提供了预训练(Pretrained)和指令调试(Instruction tuned)两个版本。预训练版本做了基础训练,而指令调试版本做了根据人类语言交互的特定训练调整,所以如果直接拿来做会话使用可以下载it版本。2B和7B在于参数量的多少,7B需要更多的资源去运行。

好了,前面啰嗦了一堆背景,为了引出这里介绍2b-it版本地部署的原因——耗资源少且可以本地使用会话。

准备环境

  • 安装python venv,命名gemma-torch
conda env create -n "gemma-torch"
  • 激活虚拟环境
conda activate gemma-torch
  • 安装依赖的库
pip install torch immutabledict sentencepiece numpy packaging

 后面两个库不是官方文档里要求的,但是根据我执行报错,需要安装。另外上面命令也取消了-q -U简单粗暴也方便观察。

为了后续用代码连接kaggle下载模型,还需要安装kagglehub包:

pip install kagglehub

连接kaggle

这一步的目的是从kaggle上面下载模型。

  • 首先获取kaggle的访问权限

登录kaggle,在设置页面(https://www.kaggle.com/settings)的API一节点击“Create New Token”,会触发下载kaggle.json。

  • 配置环境

将kaggle.json文件拷贝到~/.kaggle/目录下。并在~/.bash_profile中设置环境变量KAGGLE_CONFIG_DIR为~/.kaggle。

这样就可以通过下面代码访问(后面的代码写到一块,不需要此处执行)。

  1. import kagglehub
  2. kagglehub.login()

运行代码

经过前面的配置后,可以代码本地运行2b-it模型了。不过加载模型还需要gemma_pytorch包。

从github仓库clone到本地:

  1. # NOTE: The "installation" is just cloning the repo.
  2. git clone https://github.com/google/gemma_pytorch.git

将下载好的gemma_pytorch文件夹放到下面脚本文件同一级目录下 ,并在~/.bash_profile中设置PYTHONPATH环境变量包含该文件夹路径。

最后运行脚本(gemma_torch.py):

  1. # Choose variant and machine type
  2. import kagglehub
  3. import os
  4. import sys
  5. from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
  6. from gemma_pytorch.gemma.model import GemmaForCausalLM
  7. import torch
  8. VARIANT = '2b-it'
  9. #如果是cpu运行,将下面cuda改成cpu,不过巨慢
  10. MACHINE_TYPE = 'cuda'
  11. # Load model weights
  12. # 模型下载到了~/.cache目录下
  13. weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')
  14. # Ensure that the tokenizer is present
  15. tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
  16. assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
  17. # Ensure that the checkpoint is present
  18. ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  19. assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
  20. # Set up model config.
  21. model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
  22. model_config.tokenizer = tokenizer_path
  23. model_config.quant = 'quant' in VARIANT
  24. # Instantiate the model and load the weights.
  25. torch.set_default_dtype(model_config.get_dtype())
  26. device = torch.device(MACHINE_TYPE)
  27. model = GemmaForCausalLM(model_config)
  28. model.load_weights(ckpt_path)
  29. model = model.to(device).eval()
  30. # Generate with one request in chat mode
  31. # Chat templates
  32. USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
  33. MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
  34. # Sample formatted prompt
  35. prompt = (
  36. USER_CHAT_TEMPLATE.format(
  37. prompt='What is a good place for travel in the US?'
  38. )
  39. + MODEL_CHAT_TEMPLATE.format(prompt='California.')
  40. + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
  41. + '<start_of_turn>model\n'
  42. )
  43. print('Chat prompt:\n', prompt)
  44. model.generate(
  45. USER_CHAT_TEMPLATE.format(prompt=prompt),
  46. device=device,
  47. output_len=100,
  48. )
  49. # Generate sample
  50. model.generate(
  51. 'Write a poem about an llm writing a poem.',
  52. device=device,
  53. output_len=60,
  54. )

一点后话

能用GPU还是上GPU吧,我本地用的CPU笔记本跑的巨慢。

可以在线使用colab,具体步骤参考这个帖子(昨天Google发布了最新的开源模型Gemma,今天我来体验一下_gemma_lm.generate-CSDN博客)。

不过我在使用过程中发现T4经常在预测执行时报OOM,导致无法产出结果。

参考资料:

pytorch中使用Gemma: https://ai.google.dev/gemma/docs/pytorch_gemma

官方文档地址:https://ai.google.dev/gemma/docs 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/378982
推荐阅读
相关标签
  

闽ICP备14008679号