
SFT in Practice: Fine-tuning Gemma 7B

1. Setting up the runtime environment

Virtual environment requirements:

  • Python 3.10 or later
  • PyTorch 1.12 or later; 2.0 or later recommended
  • CUDA 11.4 or later recommended
  • transformers>=4.38.0
    Be sure to use the environment above, otherwise the code will not run. If your Python version is below 3.10, you also need to install bitsandbytes separately (pip install bitsandbytes). The examples below additionally use the trl, peft, datasets and modelscope packages, plus accelerate/bitsandbytes for 8-bit training.

Gemma model link and download:

The model repo can be downloaded directly (7b-it is used as the example here; on a low-spec server, the 2b model is recommended for the demo):

from modelscope import snapshot_download
model_dir = snapshot_download("AI-ModelScope/gemma-7b-it")
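
snapshot_download returns the local directory of the downloaded weights, which can then be passed to the transformers loaders. A minimal sketch (the dtype and device settings here are just one reasonable choice, not prescribed by the original post):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# model_dir is the local path returned by snapshot_download above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, device_map="auto")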

2. SFT Fine-tuning

SFTTrainer is a subclass of transformers.Trainer that adds the logic for handling a PeftConfig.
Different requirements call for different training strategies; several examples follow:

2.1 Continued pre-training on a dataset, fine-tuning over the entire sequence

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()

Note: dataset_text_field="text". The dataset_text_field argument tells the trainer which field of the dataset contains the text used as model input. It lets the datasets library automatically build a ConstantLengthDataset from that field, simplifying data preparation.

2.2 Fine-tuning on the response portion only

You need to set the response template: response_template = " ### Answer:"

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

trainer.train()
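
DataCollatorForCompletionOnlyLM masks out every token before the response template, so only the answer tokens contribute to the loss. As a quick sanity check (an addition of mine, not part of the original recipe), you can push one formatted example through the collator and inspect the labels:

sample_text = formatting_prompts_func(dataset[:1])[0]
batch = collator([tokenizer(sample_text)])
print(batch["labels"][0])  # positions before " ### Answer:" should be -100; answer tokens keep their ids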

2.3 Fine-tuning on a conversational dataset

You need to set both an instruction template and a response template:

instruction_template = "### Human:"

response_template = "### Assistant:"

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
)

trainer.train()
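
The text field of this dataset is expected to already contain the "### Human:" / "### Assistant:" markers the collator searches for; printing one raw example is an easy way to confirm the templates match your data (an optional check, not in the original post):

print(dataset[0]["text"][:300])  # should show alternating '### Human: ...' and '### Assistant: ...' turns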

2.4 Using the Alpaca data format

from datasets import load_dataset
from trl import SFTTrainer
import transformers

dataset = load_dataset("tatsu-lab/alpaca", split="train")

model = transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompts_func(examples):
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i]
        input_text = examples["input"][i]
        response = examples["output"][i]

        if len(input_text) >= 2:
            text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
            
            ### Instruction:
            {instruction}
            
            ### Input:
            {input_text}
            
            ### Response:
            {response}
            '''
        else:
            text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
            
            ### Instruction:
            {instruction}
            
            ### Response:
            {response}
            '''
        output_text.append(text)

    return output_text

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    max_seq_length=256,
    packing=False,
)

trainer.train()
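
One common tweak, not shown in the original snippet, is to append the tokenizer's EOS token to every formatted sample so the model learns where a response ends. A minimal variant of the formatting function (pass it as formatting_func instead of the one above):

def formatting_prompts_func_with_eos(examples):
    # Same Alpaca-style formatting as above, plus an explicit end-of-sequence marker.
    texts = formatting_prompts_func(examples)
    return [t + tokenizer.eos_token for t in texts]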

2.5 Dataset packing

The ConstantLengthDataset class packs multiple training examples together into fixed-length sequences.

Set packing=True when constructing the SFTTrainer.

The prompts are concatenated with a formatting function like this:

def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    packing=True,
    formatting_func=formatting_func
)

trainer.train()
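
Under the hood, packing builds a ConstantLengthDataset that tokenizes the formatted texts, concatenates them, and slices the stream into fixed-length chunks. A rough sketch of doing this by hand (the import path and defaults may differ slightly between trl versions):

from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset  # location may vary across trl versions

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

packed_dataset = ConstantLengthDataset(
    tokenizer,
    dataset,                          # the same dataset used above
    formatting_func=formatting_func,  # the Question/Answer template above
    seq_length=512,                   # every yielded sample has exactly this many tokens
)

sample = next(iter(packed_dataset))
print(sample["input_ids"].shape)      # -> torch.Size([512])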

2.6 Using adapters

Use an adapter (PEFT/LoRA) to train only a subset of the parameters:

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("imdb", split="train")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config
)

trainer.train()
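
When a peft_config is supplied, SFTTrainer wraps the model in a PeftModel, so only the LoRA parameters are trainable. A small check (my addition) to confirm this:

trainer.model.print_trainable_parameters()  # reports how many parameters are trainable vs. frozen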

2.7 Training with int8 precision

Load the model in int8 when it is created:

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    load_in_8bit=True,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
)

trainer.train()
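
In recent transformers releases the bare load_in_8bit flag is routed through BitsAndBytesConfig, and newer versions may warn about it. If so, the equivalent loading call looks roughly like this:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)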

Usage summary:
1. By default, SFTTrainer extends sequences up to max_seq_length;
2. When training the model in 8-bit, it is best to load the model externally and then pass it to SFTTrainer;
3. When you create the model outside the trainer, do not also pass from_pretrained()-related arguments to SFTTrainer.

3. Gemma Model Inference Demo

This example uses gemma-2b:

from modelscope import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("gemma-2b")
model = AutoModelForCausalLM.from_pretrained("gemma-2b", torch_dtype = torch.bfloat16, device_map="auto")

input_text = "hello."
messages = [
    {"role": "user", "content": input_text}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
input_ids = tokenizer([text], return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids,max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
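
For interactive testing it is often nicer to stream tokens as they are generated rather than wait for the full output. transformers ships a TextStreamer that plugs into generate (an optional addition to the demo above):

from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
outputs = model.generate(**input_ids, max_new_tokens=256, streamer=streamer)  # prints tokens as they arrive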

4. Hands-on: SFT Fine-tuning Gemma

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig
import transformers

dataset = load_dataset("json", data_files="./traffic_intent.json", split="train")


model = AutoModelForCausalLM.from_pretrained("gemma-2b",load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained("gemma-2b")

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)


trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    peft_config=lora_config,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        learning_rate=2e-4,
        num_train_epochs=3,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
)

trainer.train() 
trainer.save_model("outputs")
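
formatting_prompts_func assumes that every record in traffic_intent.json has instruction and output fields (the file itself is not shown in the post, so the field names are inferred from the code). A quick way to confirm your own data matches that schema:

print(dataset[0].keys())                                      # expected to include 'instruction' and 'output'
print(dataset[0]["instruction"], "->", dataset[0]["output"])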

With better GPU resources, you can run LoRA fine-tuning on the 7B model at full precision.

Model testing:

import os
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
LORA_WEIGHTS = "./outputs/"
model_id ="gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_id,load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()
model = PeftModel.from_pretrained(model, LORA_WEIGHTS)
print(model)
model = model.to("cuda")
prompt = "查看市区交通拥堵指数"
inp = tokenizer(prompt, max_length=512, return_tensors="pt").to("cuda")
outputs = model.generate(input_ids=inp["input_ids"], max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
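
If you want a single standalone checkpoint instead of base weights plus a separate LoRA adapter, the adapter can be merged into the base model with peft's merge_and_unload. Merging generally requires the base model in full or half precision rather than 8-bit; a sketch (the output directory is just a placeholder):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("gemma-2b", torch_dtype=torch.bfloat16)
merged = PeftModel.from_pretrained(base, "./outputs/").merge_and_unload()  # folds LoRA weights into the base model

merged.save_pretrained("./gemma-2b-traffic-merged")                        # placeholder output directory
AutoTokenizer.from_pretrained("gemma-2b").save_pretrained("./gemma-2b-traffic-merged")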