
Llama-3 Source Code Walkthrough: Inference (infer)



Preface

This series reads through the code of an open-source GitHub project built on Llama-3, Meta's newly released generation of open-source large models. The project is the third phase of the Chinese-LLaMA-Alpaca series (following phases one and two) and open-sources a Chinese Llama-3 base model as well as a Chinese Llama-3-Instruct instruction-tuned model. Starting from the original Llama-3, these models are incrementally pre-trained on large-scale Chinese data and then fine-tuned on curated instruction data, which further strengthens their Chinese semantic understanding and instruction-following ability and yields a clear performance gain over the second-generation models. Building on this project, I explain the principles behind training and inference and walk through the code step by step so that readers can understand how the large language model actually runs. This post covers the Llama-3 inference source code, following the flow of the code itself.


I. Overall source code walkthrough

1. The complete main source code

I first give the complete source code. The parts that inference actually uses are analyzed in depth in the following sections, while the simpler parts are not discussed further.

if __name__ == '__main__':
   load_type = torch.float16
   
   # Move the model to the MPS device if available
   if torch.backends.mps.is_available():
       device = torch.device("mps")
   else:
       if torch.cuda.is_available():
           device = torch.device(0)
       else:
           device = torch.device('cpu')
   print(f"Using device: {
     device}")

   if args.tokenizer_path is None:
       args.tokenizer_path = args.base_model
   tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
   terminators = [
               tokenizer.eos_token_id,
               tokenizer.convert_tokens_to_ids("<|eot_id|>")
           ]
   if args.use_vllm:
       model = LLM(model=args.base_model,
           tokenizer=args.tokenizer_path,
           tensor_parallel_size=len(args.gpus.split(',')),
           dtype=load_type
           )
       generation_config["stop_token_ids"] = terminators
       generation_config["stop"] = ["<|eot_id|>", "<|end_of_text|>"]
   else:
       if args.load_in_4bit or args.load_in_8bit:
           quantization_config = BitsAndBytesConfig(
               load_in_4bit=args.load_in_4bit,
               load_in_8bit=args.load_in_8bit,
               bnb_4bit_compute_dtype=load_type,
               bnb_4bit_use_double_quant=True,
               bnb_4bit_quant_type="nf4"
           )

       model = AutoModelForCausalLM.from_pretrained(
           args.base_model,
           torch_dtype=load_type,
           low_cpu_mem_usage=True,
           device_map='auto',
           quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
           attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
       )
       if device==torch.device('cpu'):
           model.float()
       model.eval()
   # test data
   if args.data_file is None:
       examples = sample_data
   else:
       with open(args.data_file, 'r') as f:
           examples = [line.strip() for line in f.readlines()]
       print("first 10 examples:")
       for example in examples[:10]:
           print(example)

   with torch.no_grad():
       if args.interactive:
           print("Start inference with instruction mode.")
           print('='*85)
           print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n"
                 "+ 如要进行多轮对话,请使用llama.cpp")
           print('-'*85)
           print("+ This mode only supports single-turn QA.\n"
                 "+ If you want to experience multi-turn dialogue, please use llama.cpp")
           print('='*85)

           while True:
               raw_input_text = input("Input:")
               if len(raw_input_text.strip())==0:
                   break
               if args.with_prompt:
                   input_text = generate_prompt(instruction=raw_input_text)
               else:
                   input_text = raw_input_text

               if args.use_vllm:
                   output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)
                   response = output[0].outputs[0].text
               else:
                   inputs = tokenizer(input_text,return_tensors="pt")  #add_special_tokens=False ?
                   generation_output = model.generate(
                       input_ids = inputs["input_ids"].to(device),
                       attention_mask = inputs['attention_mask'].to(device),
                       eos_token_id=terminators,
                       pad_token_id=tokenizer.eos_token_id,
                       generation_config = generation_config
                   )
                   s = generation_output[0]
                   output = tokenizer.decode(s, skip_special_tokens=True)
                   if args.with_prompt:
                       response = output.split("assistant\n\n")[-1].strip()
                   else:
                       response = output
               print("Response: ",response)
               print("\n")
       else:
           print("Start inference.")
           results = []
           if args.use_vllm:
               if args.with_prompt is True:
                   inputs = [generate_prompt(example) for example in examples]
               else:
                   inputs = examples
               outputs = model.generate(inputs, SamplingParams(**generation_config))

               for index, (example, output) in enumerate(zip(examples, outputs)):
                   response = output.outputs[0].text
                   print(f"======={
     index}=======")
                   print(f"Input: {
     example}\n")
                   print(f"Output: {
     response}\n")
                   results.append({
   "Input":example,"Output":response})
           else:
               for index, example in enumerate(examples):
                   if args.with_prompt:
                       input_text = generate_prompt(instruction=example)
                   else:
                       input_text = example
                   inputs = tokenizer(input_text,return_tensors="pt")  #add_special_tokens=False ?
                   generation_output = model.generate(
                       input_ids = inputs["input_ids"].to(device),
                       attention_mask = inputs['attention_mask'].to(device),
                       eos_token_id=terminators,
                       pad_token_id=tokenizer.eos_token_id,
                       generation_config = generation_config
                   )
                   s = generation_output[0]
                   output = tokenizer.decode(s,skip_special_tokens=True)
                   if args.with_prompt:
                       response = output.split("assistant\n\n")[1].strip()
                   else:
                       response = output
                   print(f"======={
     index}=======")
                   print(f"Input: {
     example}\n")
                   print(f"Output: {
     response}\n")

                   results.append({"Input":input_text,"Output":response})

           dirname = os.path.dirname(args.predictions_file)
           os.makedirs(dirname,exist_ok=True)
           with open(args.predictions_file,'w') as f:
               json.dump(results,f,ensure_ascii=False,indent=2)
           if args.use_vllm:
               with open(dirname+'/generation_config.json','w') as f:
                   json.dump(generation_config,f,ensure_ascii=False,indent=2)
           else:
               generation_config.save_pretrained('./')
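
The main block above references a generation_config that is defined elsewhere in the script and is not part of this excerpt. Note that it is used in two different ways: the vLLM branch treats it like a plain dict (item assignment for stop tokens, then SamplingParams(**generation_config)), while the Hugging Face branch passes it to model.generate() and later calls save_pretrained() on it, which expects a GenerationConfig object. The following is only a rough sketch of how such an object could be set up; the parameter values are illustrative assumptions, not the project's actual defaults.

# Hypothetical sketch of generation_config for the two code paths.
# Values are illustrative assumptions, not copied from the repository.
from transformers import GenerationConfig

use_vllm = False  # stands in for args.use_vllm

if use_vllm:
    # vLLM: a plain dict, later unpacked into SamplingParams(**generation_config)
    generation_config = dict(
        temperature=0.2,
        top_p=0.9,
        top_k=40,
        max_tokens=400,
    )
else:
    # Hugging Face: model.generate() and save_pretrained() expect a GenerationConfig
    generation_config = GenerationConfig(
        temperature=0.2,
        top_p=0.9,
        top_k=40,
        do_sample=True,
        max_new_tokens=400,
    )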

2. Tokenizer loading

For details on how the tokenizer is loaded, see my earlier blog post on the topic. Here I simply quote the relevant source code:

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
terminators = [
           tokenizer.eos_token_id,
           tokenizer.convert_tokens_to_ids("<|eot_id|>")
       ]

Here tokenizer.eos_token_id = 128009, so terminators = [128009, 128009].
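
As a quick sanity check, you can print these special-token IDs yourself. This is a minimal sketch; "path/to/llama-3-checkpoint" is a placeholder for your local model or tokenizer directory, not a path taken from the project.

# Minimal sketch: inspect the terminator token IDs of a Llama-3 tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-3-checkpoint")  # placeholder path
print(tokenizer.eos_token_id)                              # 128009 for the Instruct-style tokenizer used here
print(tokenizer.convert_tokens_to_ids("<|eot_id|>"))       # 128009
print(tokenizer.convert_tokens_to_ids("<|end_of_text|>"))  # 128001 in the Llama-3 vocabulary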

3. Llama-3 model loading

For general Hugging Face model loading, refer to my earlier blog post; I will not repeat the details of loading Llama-3 here. The source code is as follows:

model = AutoModelForCausalLM.from_pretrained(
    args.base_model,  # folder containing the model weights
    torch_dtype=load_type,
    low_cpu_mem_usage=True,
    device_map='auto',
    quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
    attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
)
if device==torch.device('cpu'):
    model.float()
model.eval()

Note: model.eval() puts the model into evaluation mode (disabling dropout and other training-only behavior), which is the standard PyTorch step before evaluation or inference; since generation also runs under torch.no_grad(), no weights are updated.
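
For reference, when --load_in_4bit is passed the script goes through the quantized branch shown in the full main block. The following is a minimal self-contained sketch of that path, assuming a local checkpoint directory ("path/to/llama-3-checkpoint" is a placeholder) and the same bitsandbytes settings as in the script.

# Minimal sketch: 4-bit NF4 quantized loading, mirroring the script's quantization branch.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matches load_type in the script
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-3-checkpoint",          # placeholder path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    quantization_config=quantization_config,
)
model.eval()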

4. Loading the test data

 # test data
 if args.data_file is None:
     examples = sample_data  #  ["为什么要减少污染,保护环境?","你有什么建议?"]
 else:
     with open(args.data_file, 'r') as f:
         examples = [line.strip() for line in f.readlines()]
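
When --with_prompt is set, each example loaded here is wrapped by the generate_prompt helper before tokenization; that helper is defined elsewhere in the script and is not shown in this excerpt. As a purely hypothetical illustration (not the project's implementation), a similar Llama-3-Instruct style prompt, ending with the assistant header that the script later splits on, can be built from the tokenizer's chat template:

# Hypothetical stand-in for generate_prompt, built on the tokenizer's chat template.
# This is NOT the project's implementation, only an equivalent way to produce a prompt
# that ends with the assistant header so generation continues from there.
def generate_prompt(instruction, tokenizer, system_prompt="You are a helpful assistant."):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": instruction},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return the prompt string, not token IDs
        add_generation_prompt=True,  # append the assistant header
    )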