赞
踩
python>=3.8,这里使用3.9版本
重要的python包及版本如下:
torch == 2.2.1(若版本<=2.1,则需要删除后续github项目中model.py里的mmap=True参数)
torchaudio == 2.2.1
torchvision == 0.17.1
transformers == 4.38.1
fairscale == 0.4.13
numpy == 1.24.4
immutabledict == 4.1.0
sentencepiece == 0.1.99
protobuf == 4.25.3
(1)去github上下载模型Pytorch代码,链接如下:
https://github.com/google/gemma_pytorch/tree/main
在github项目中添加文件夹google/gemma-7b-it
(2)去kaggle上下载Pytorch版的7b-it模型,链接如下:https://www.kaggle.com/models/google/gemma/frameworks/pyTorch/variations/7b-it
下载得到的是一个压缩文件archive.tar.gz,用tar指令解压(例如 tar -xzf archive.tar.gz)后得到三个文件,分别是config.json(需要手动修改使其符合json格式)、gemma-7b-it.ckpt(模型权重文件)、tokenizer.model(分词文件)。将这三个文件放入文件夹google/gemma-7b-it中。
在github项目中添加python代码测试模型。下面代码根据github项目中的scripts/run.py进行修改。
"""Run a single prompt through a locally stored Gemma PyTorch checkpoint.

Adapted from ``scripts/run.py`` in the gemma_pytorch repository.
"""
import argparse
import contextlib
import random

import numpy as np
import torch

from gemma import config
from gemma import model as gemma_model

# Default test question used when --prompt is not given.
_DEFAULT_PROMPT = (
    "Extract time expressions, locations, directions and distances from "
    "following sentences: An attack on Sunday destroyed the train station, "
    "shops and homes in the town of Kostiantynivka near the frontline of "
    "Ukraine’s two-year-old war against Russia. Police said a guided aerial "
    "bomb hit the station and four S-300 missiles followed in the pre-dawn "
    "hours. Kostiantynivka lies 30km (18 miles) west of Bakhmut, which fell "
    "to Russian forces in May 2023, and north of Avdiivka, captured by the "
    "Russians last week."
)


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The default dtype is reset to ``torch.float`` when the ``with`` block
    exits, including when the block raises.
    """
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(torch.float)


def main(args):
    """Load the model described by ``args``, generate a reply, and print it.

    Args:
        args: Parsed command-line namespace providing ``variant``, ``quant``,
            ``seed``, ``device``, ``ckpt``, ``prompt`` and ``output_len``.
    """
    # Construct the model config.
    model_config = config.get_model_config(args.variant)
    # float16 may overflow the half-precision range, so float32 is used.
    model_config.dtype = "float32"
    model_config.quant = args.quant

    # Seed every random number generator in use for reproducible sampling.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Create the model and load the weights.
    device = torch.device(args.device)
    with _set_default_tensor_type(model_config.get_dtype()):
        model = gemma_model.GemmaForCausalLM(model_config)
        model.load_weights(args.ckpt)
        model = model.to(device).eval()
    print("Model loading done")

    # Generate the response. --output_len is forwarded so that the flag
    # actually controls the generation length.
    result = model.generate(args.prompt, device, output_len=args.output_len)

    # NOTE: the officially provided chat prompt templates are listed below
    # for reference; they are not applied to the prompt here.
    # USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
    # MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

    # Print the prompt and the result.
    print('======================================')
    print(f'PROMPT: {args.prompt}\n')
    print(f'RESULT: {result}')
    print('======================================')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str,
                        default="google/gemma-7b-it/gemma-7b-it.ckpt")
    parser.add_argument("--variant", type=str, default="7b",
                        choices=["2b", "7b"])
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cpu", "cuda"])
    parser.add_argument("--output_len", type=int, default=256)
    parser.add_argument("--seed", type=int, default=12345)
    parser.add_argument("--quant", action='store_true')
    parser.add_argument("--prompt", type=str, default=_DEFAULT_PROMPT)
    args = parser.parse_args()
    main(args)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。