当前位置:   article > 正文

Qwen2大模型微调_qwen2 微调

qwen2 微调

项目介绍

本项目主要关注在大模型的微调上,所以使用Lora技术对Qwen2大模型进行微调,打造了一个医疗问答助手,相关模型文件已在魔搭平台上发布。

模型介绍

Doctor-Qwen2是一个为医疗健康对话场景而打造的领域大模型,该模型基于Qwen2-1.5B-Instruct进行微调得来,使用的数据集是复旦大学数据智能与社会计算实验室开源的DISC-Med-SFT数据集。

模型推理范例

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("模型路径")
model = AutoModelForCausalLM.from_pretrained("模型路径", device_map={"":0})

prompt = "医生您好,我最近睡眠质量很差,晚上经常醒来,而且早上起来也疲惫不堪,我该如何才能改善睡眠质量?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
device = 'cuda'
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate( model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

数据集预处理

假设原数据集里包括许多一轮对话,每一段对话的格式如下:

{   
    "instruction": "*********",
    "input": "*********", 
    "output": "*********"
}
  • 1
  • 2
  • 3
  • 4
  • 5
from datasets import Dataset
import pandas as pd

def process_func(example):
    """
    将数据集进行预处理
    """
    MAX_LENGTH = 384 
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|im_start|>system\n你是一个医疗领域的专家,你会接收到病人的提问,请输出合适的答案<|im_end|>\n<|im_start|>user\n{example['input']}<|im_end|>\n<|im_start|>assistant\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
        instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}  

train_df = pd.read_json("原数据集路径", lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

模型微调范例

使用LoRA技术对模型进行微调,以满足特定任务需求。

from trl import SFTTrainer
from peft import LoraConfig, TaskType
from transformers import AutoModelForCausalLM, TrainingArguments, DataCollatorForSeq2Seq, AutoTokenizer
from datasets import load_from_disk

train_dataset = load_from_disk("处理过的数据集路径")
tokenizer = AutoTokenizer.from_pretrained("模型路径")
model = AutoModelForCausalLM.from_pretrained("模型路径", device_map={"":0})

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,  
    r=8, 
    lora_alpha=32, 
    lora_dropout=0.1, 
)

args = TrainingArguments(
    output_dir="自定义输出路径",
    per_device_train_batch_size=5,
    gradient_accumulation_steps=5,
    logging_steps=10,
    num_train_epochs=3,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True, 
    gradient_checkpointing=True,
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    peft_config=lora_config,
)

trainer.train()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号