Environment: Miniconda (conda3), Python 3.8, Ubuntu 18.04, CUDA 11.3
!pip install peft
!pip install cpm_kernels
!pip install icetk
!git clone https://github.com/THUDM/ChatGLM-6B.git
cd ChatGLM-6B
!pip install -r requirements.txt
import torch
device = torch.device('cuda:0')
import os
import torch
import numpy as np
import json
from transformers import AutoTokenizer, AutoModel, AutoConfig, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import Dataset
checkpoint = "THUDM/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
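The post does not show how the 6B-parameter model is placed on the GPU. Since fp16 training is enabled further down, a common pattern (an assumption here, mirroring the official ChatGLM-6B examples rather than anything stated in this post) is to cast the model to half precision and move it to the device declared earlier:

# Assumption: run the base model in half precision on the GPU defined above;
# at full precision the 6B-parameter model will not fit in typical consumer VRAM.
model = model.half().to(device)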
def load_lora_config(model):
    # LoRA fine-tuning settings: rank-8 adapters (alpha=32, dropout=0.1)
    # injected into the query_key_value projection of each attention block
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query_key_value"]
    )
    return get_peft_model(model, config)
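The excerpt defines load_lora_config but never shows it being applied. Presumably the base model is wrapped before the Trainer is built; a minimal usage sketch (reusing the model variable is an assumption, not something shown in the post) would be:

# Wrap the base model with the LoRA adapters configured above; only the
# injected low-rank matrices remain trainable.
model = load_lora_config(model)
model.print_trainable_parameters()  # sanity check: trainable vs. total parameter counts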
training_args = TrainingArguments(
    output_dir="./",
    fp16=True,
    save_steps=1000,
    save_total_limit=5,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    max_steps=3000,
    logging_steps=100,
    remove_unused_columns=False,
    seed=500,
    data_seed=500,
    group_by_length=False,
    dataloader_pin_memory=False
)

# Build the QADataset and Trainer instances
# (train_data: the question/answer records loaded beforehand, not shown in this excerpt)
train_dataset = QADataset(train_data, tokenizer=tokenizer)
trainer = ModifiedTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=collate_fn,
    tokenizer=tokenizer
)
The construction of QADataset and ModifiedTrainer can follow the referenced article; a minimal sketch of what they might look like is given below. With both in place, training is launched with:

trainer.train()
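QADataset, collate_fn, and ModifiedTrainer are not defined in this excerpt, so the following is only a rough sketch of what such pieces might look like: a dataset that concatenates the tokenized question and answer while masking the prompt out of the loss, a padding collator, and a Trainer subclass that saves just the LoRA weights. The field names ("question"/"answer"), length limits, loss masking, and saved file name are assumptions, and the prompt layout is simplified rather than ChatGLM's exact special-token format.

# Hypothetical minimal sketch -- not the referenced article's implementation.
class QADataset(Dataset):
    def __init__(self, data, tokenizer, max_source_len=128, max_target_len=256):
        self.data = data                      # list of {"question": ..., "answer": ...}
        self.tokenizer = tokenizer
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt_ids = self.tokenizer.encode(item["question"],
                                           max_length=self.max_source_len, truncation=True)
        answer_ids = self.tokenizer.encode(item["answer"], add_special_tokens=False,
                                           max_length=self.max_target_len, truncation=True)
        input_ids = prompt_ids + answer_ids + [self.tokenizer.eos_token_id]
        # Only the answer tokens contribute to the loss; prompt positions are masked with -100
        labels = [-100] * len(prompt_ids) + answer_ids + [self.tokenizer.eos_token_id]
        return {"input_ids": input_ids, "labels": labels}

def collate_fn(batch):
    # Right-pad every example in the batch to the longest sequence
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, labels = [], []
    for item in batch:
        pad = max_len - len(item["input_ids"])
        input_ids.append(item["input_ids"] + [tokenizer.pad_token_id] * pad)
        labels.append(item["labels"] + [-100] * pad)
    return {"input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long)}

class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Use the language-modelling loss returned by the (PEFT-wrapped) model
        outputs = model(input_ids=inputs["input_ids"], labels=inputs["labels"])
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def save_model(self, output_dir=None, _internal_call=False):
        # Save only the LoRA adapter weights instead of the full 6B model
        output_dir = output_dir or self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        lora_state = {k: v for k, v in self.model.state_dict().items() if "lora_" in k}
        torch.save(lora_state, os.path.join(output_dir, "chatglm-6b-lora.pt"))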