当前位置:   article > 正文

Pytorch框架下使用Gemma

Pytorch框架下使用Gemma

Gemma介绍

Gemma是谷歌发布的一款开源大语言模型,并且Gemma是开源大模型的SOTA,超越了Meta的LLaMa2。综合来说,Gemma有以下几个特点:

坚实的模型基础:Gemma使用与谷歌更强大的人工智能模型Gemini相同的研究和技术。这一共同的基础确保了Gemma建立在一个强大的基础上,并具有强大的能力潜力。

轻便易用:与体型较大的Gemini不同,Gemma的设计重量轻,所需资源较少。这使得它可以被更广泛的用户访问,包括研究人员、开发人员,甚至那些计算资源有限的用户。

开放定制:Gemma模型可以进行微调,这意味着它们可以针对特定任务或应用程序进行调整和定制。这允许用户根据自己的特定需求定制模型。

“Gemma”这个名字来源于拉丁语中“宝石”的意思,反映了谷歌将这些模型视为推进人工智能研发的宝贵工具。总的来说,谷歌的Gemma有望实现强大人工智能工具的公众化,并为人工智能开发创造一个更具包容性和协作性的环境。

所以Google将Gemma放到了Github上:google/gemma_pytorch: The official PyTorch implementation of Google's Gemma models (github.com)

这里使用Kaggle这个数据科学竞赛网站提供的环境。

Kaggle申请Gemma

首先你得有个Kaggle账号,这个简单注册即可,就是有时候需要科学上网才能注册成功。

接着打开Gemma的模型页面申请:Gemma | Kaggle

申请过程就是很简单的填写同意书并接受条款和条件

最后选择Pytorch版本,点击new notebook就可以开始使用模型了。

基本使用代码讲解

Gemma模型页面中给出了Pytorch版本使用Gemma的基本代码,只需要复制进新建的notebook就可以运行。

运行示例代码会得到一个结果:
 

'What is a popular area in California? €)\n tanong:\nSneakyThrows model\nSneakyThrows user\nSneakyThrows user (more than two lines)SneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrows\nSneakyThrows answerSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows userSneakyThrows modelSneakyThrows\nSneakyThrows userSneakyThrows modelSneakyThrows modelSneakyThrowsSneakyThrows modelSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrowsSneakyThrows'

就是除了第一句话,其他根本毫无逻辑。我们尝试解读一下代码,去了解为什么会生成一堆这种东西。

配置环境

  1. # Setup the environment
  2. # 安装必要的库
  3. !pip install -q -U immutabledict sentencepiece
  4. # 克隆 gemma_pytorch 代码库
  5. !git clone https://github.com/google/gemma_pytorch.git
  6. # 将 gemma_pytorch/gemma 目录下的所有文件移动到当前目录
  7. !mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

安装的两个陌生的库——immutabledict 和 sentencepiece。

1. ImmutableDict

简介:

immutabledict 是一个 Python 库,提供不可变字典类型。与普通的字典不同,不可变字典一旦创建后就不能被修改。这使得不可变字典在多线程环境下更加安全,并且可以提高代码的性能。

主要功能:

  • 创建不可变字典
  • 访问不可变字典中的元素
  • 遍历不可变字典
  • 检查不可变字典中是否存在某个键
  • 将不可变字典转换为其他类型

使用场景:

  • 多线程环境
  • 需要提高代码性能的场景
  • 需要保证数据一致性的场景

2. SentencePiece

简介:

sentencepiece 是一个 Python 库,提供文本分词功能。它可以将文本分割成单个字符、子词或词语,以便于后续的处理。

主要功能:

  • 文本分词
  • 词汇表生成
  • 模型训练
  • 模型预测

使用场景:

  • 机器翻译
  • 文本摘要
  • 信息抽取
  • 问答系统

immutabledictsentencepiece 都是非常有用的 Python 库。immutabledict 可以提高代码的安全性 and 性能,sentencepiece 可以提高文本处理的效率。

加载模型

  1. # 导入必要的类库
  2. import sys # 系统模块
  3. sys.path.append("/kaggle/working/gemma_pytorch/") # 添加自定义库路径
  4. from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b # 加载配置类
  5. from gemma.model import GemmaForCausalLM # 加载模型类
  6. from gemma.tokenizer import Tokenizer # 加载分词器类
  7. import contextlib # 上下文管理工具
  8. import os # 操作系统模块
  9. import torch # 深度学习框架
  10. # 设置模型类型和设备
  11. VARIANT = "2b" # 模型类型(2b 或 7b)
  12. MACHINE_TYPE = "cpu" # 运行设备(cpu 或 cuda)
  13. weights_dir = '/kaggle/input/gemma/pytorch/2b/2' # 模型权重所在目录
  14. # 定义上下文管理器,设置默认张量类型
  15. @contextlib.contextmanager
  16. def _set_default_tensor_type(dtype: torch.dtype):
  17. """
  18. 设置默认 torch dtype 为指定值,并在上下文结束后恢复为 float。
  19. """
  20. torch.set_default_dtype(dtype)
  21. yield
  22. torch.set_default_dtype(torch.float)
  23. # 加载模型配置
  24. model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
  25. model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model") # 设置分词器路径
  26. # 设置设备
  27. device = torch.device(MACHINE_TYPE)
  28. # 使用上下文管理器设置默认张量类型,并加载模型
  29. with _set_default_tensor_type(model_config.get_dtype()):
  30. model = GemmaForCausalLM(model_config) # 创建模型对象
  31. ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt') # 获取模型权重路径
  32. model.load_weights(ckpt_path) # 加载模型权重
  33. model = model.to(device).eval() # 将模型移动到指定设备并设置为评估模式

这段代码首先导入必要的类库,然后设置模型类型、设备和权重目录。接着定义了一个上下文管理器,用于设置默认张量类型。最后,加载模型配置,创建模型对象,加载模型权重,并将其移动到指定设备并设置为评估模式。

使用模型

  1. # 用户聊天模板
  2. USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
  3. # 解释:
  4. # - `<start_of_turn>` 表示对话开始
  5. # - `user` 表示用户
  6. # - `{prompt}` 表示用户输入
  7. # - `<end_of_turn>` 表示对话结束
  8. # 模型聊天模板
  9. MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"
  10. # 解释:
  11. # - `<start_of_turn>` 表示对话开始
  12. # - `model` 表示模型
  13. # - `{prompt}` 表示模型输出
  14. # - `<end_of_turn>` 表示对话结束
  15. # 生成对话
  16. prompt = (
  17. USER_CHAT_TEMPLATE.format(prompt="中国哪个大学最好?")
  18. + MODEL_CHAT_TEMPLATE.format(prompt="北京大学")
  19. + USER_CHAT_TEMPLATE.format(prompt="北京大学在QS的排名是多少")
  20. + "<start_of_turn>model\n"
  21. )
  22. # 解释:
  23. # - `prompt` 变量包含了完整的对话文本,包括用户输入和模型输出
  24. # 生成模型回复
  25. model.generate(
  26. prompt,
  27. device=device,
  28. output_len=100,
  29. )
  30. # 解释:
  31. # - `model.generate` 函数用于生成模型回复
  32. # - `prompt` 是模型输入
  33. # - `device` 是模型运行设备
  34. # - `output_len` 是模型输出的最大长度

这里我对原代码做出了改变,即最后的生成的输入使用prompt代替

USER_CHAT_TEMPLATE.format(prompt=prompt)

这是因为我觉得prompt已经遵循user和model一问一答的格式,并且最后并没有加end_of_turn这样符合模型继续生成下面文字的逻辑,而如果在格式化原来的prompt则不合逻辑。

另一个改变是我将prompt内容做了改变,测试模型对中文问题的回答怎么样。

当然结果比较惨:

"2005-2006内考满170分以上,2007-2008年170分以上,2009年1分子到160以上就真的不会考啦 prospetvi\n conquête+ de l'argent\n北京大学在QS的排名是209! RequiresApi to the moon! (๑•̀ㅂ•́)و✧\n conquête+ de l'argent\n北京大学在QS"

生成的答案更为离谱,能从生成的文字中看出答案,不过堂堂北大竟被认为QS排名209的水校。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/143375
推荐阅读
相关标签
  

闽ICP备14008679号

        
cppcmd=keepalive&