赞
踩
- import os
- import sys
- import torch
- import hashlib
- from types import MethodType
- from typing import List, Literal, Optional, Tuple
-
- import transformers
- from transformers import (
- AutoConfig,
- AutoModel,
- AutoTokenizer,
- HfArgumentParser,
- Seq2SeqTrainingArguments,
- BitsAndBytesConfig
- )
- from transformers.utils import check_min_version
- from transformers.utils.versions import require_version
- from transformers.modeling_utils import PreTrainedModel
- from transformers.tokenization_utils import PreTrainedTokenizer
-
- import datasets
- from datasets import Dataset, concatenate_datasets, load_dataset
-
- from peft import (
- PeftModel,
- TaskType,
- LoraConfig,
- get_peft_model
- )
-
- from peft.utils import CONFIG_NAME, WEIGHTS_NAME
-
- from trl import AutoModelForCausalLMWithValueHead
-
- from .config import (
- ModelArguments,
- DataTrainingArguments,
- FinetuningArguments,
- GeneratingArguments
- )
-
- from .other import (
- get_logger,
- load_trainable_params,
- load_valuehead_params,
- print_trainable_params,
- prepare_model_for_training,
- IGNORE_INDEX
- )
-
# Fail fast (with an install hint) when a dependency is older than this module requires.
check_min_version("4.27.4")
for _requirement in ("datasets>=2.10.0", "accelerate>=0.19.0", "peft>=0.3.0", "trl>=0.4.4"):
    require_version(_requirement, "To fix: pip install {}".format(_requirement))
del _requirement


logger = get_logger(__name__)
-
-
def _check_lora_checkpoint(checkpoint_dir: str) -> None:
    r"""
    Asserts that ``checkpoint_dir`` contains both the PEFT adapter weights and the adapter config.

    Raises:
        AssertionError: if either file is missing.
    """
    assert os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_NAME)), \
        "Provided path ({}) does not contain a LoRA weight.".format(checkpoint_dir)
    assert os.path.exists(os.path.join(checkpoint_dir, CONFIG_NAME)), \
        "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/p_tuning/freeze` instead."


def init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Supports full-parameter, freeze, P-Tuning v2 and LoRA training.
    Note that the trainable parameters must be cast to float32.

    Args:
        model: the base model to prepare.
        model_args: supplies ``checkpoint_dir`` (a list of paths, or None) and ``resume_lora_training``.
        finetuning_args: supplies ``finetuning_type`` and the method-specific settings
            (``trainable_layers`` for freeze, ``lora_*`` for LoRA).
        is_trainable: whether the model is being prepared for training.

    Returns:
        The prepared model. For LoRA this may be a ``PeftModel`` wrapping the base model.

    Raises:
        ValueError: if ``finetuning_type == "none"`` while training.
        AssertionError: if a checkpoint is missing or is not correctly loaded.
    """
    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")

        for name, param in model.named_parameters():
            # A parameter stays trainable iff its name contains one of the configured substrings.
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "p_tuning":
        logger.info("Fine-tuning method: P-Tuning v2")

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        latest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            # Validate every checkpoint before touching the model, so a bad path in the middle of
            # the list fails cleanly instead of after earlier weights were already merged.
            for checkpoint in model_args.checkpoint_dir:
                _check_lora_checkpoint(checkpoint)

            if is_trainable and model_args.resume_lora_training:  # continually train on the lora weights
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            # Fold each LoRA checkpoint into the base weights, one after another.
            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if latest_checkpoint is not None:  # resume lora training
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)

        if is_trainable and latest_checkpoint is None:  # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,  # we should regard ChatGLM as a causal LM
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model
-
-
def _get_quantization_method(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool
) -> Optional[str]:
    r"""
    Chooses how ``model_args.quantization_bit`` is applied.

    Returns:
        ``None`` when no quantization is requested.
        ``"cpm"`` (ChatGLM's built-in ``quantize``) for inference and P-Tuning v2 training.
        ``"bnb"`` (bitsandbytes) for the other trainable methods.

    Raises:
        ValueError: for full-parameter training, or for a bit width the bitsandbytes path cannot load.
    """
    if model_args.quantization_bit is None:
        return None

    if not is_trainable:
        return "cpm"

    if finetuning_args.finetuning_type == "full":
        raise ValueError("Full-parameter fine-tuning does not support quantization.")

    if finetuning_args.finetuning_type == "p_tuning":
        return "cpm"  # use cpm's quantization

    # Only 4-bit and 8-bit loading are configured below. Any other value used to be ignored
    # silently: the model stayed unquantized while the log still reported it as quantized.
    if model_args.quantization_bit not in (4, 8):
        raise ValueError(
            "Quantization with bitsandbytes only supports 4 or 8 bits, got {}.".format(model_args.quantization_bit)
        )

    return "bnb"  # use bnb's quantization


def _add_bnb_kwargs(config_kwargs: dict, model_args: ModelArguments) -> None:
    r"""
    Adds the bitsandbytes loading options for ``model_args.quantization_bit`` to ``config_kwargs`` in place.

    Also sets ``device_map`` so the whole model is placed on device ``LOCAL_RANK`` (default 0).
    """
    if model_args.quantization_bit == 8:
        require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
        config_kwargs["load_in_8bit"] = True
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0
        )
    elif model_args.quantization_bit == 4:
        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
        require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
        require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
        require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
        config_kwargs["load_in_4bit"] = True
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=model_args.compute_dtype,
            bnb_4bit_use_double_quant=model_args.double_quantization,
            bnb_4bit_quant_type=model_args.quantization_type
        )
    config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}


def _register_auto_classes(config, tokenizer, model) -> None:
    r"""Registers remote-code classes listed in ``config.auto_map`` so their code files are saved with the model."""
    auto_map = getattr(config, "auto_map", None) or {}  # tolerate a missing or None auto_map
    if "AutoConfig" in auto_map:
        config.__class__.register_for_auto_class()
    if "AutoTokenizer" in auto_map:
        tokenizer.__class__.register_for_auto_class()
    if "AutoModel" in auto_map:
        model.__class__.register_for_auto_class()


def _add_value_head(model, model_args: ModelArguments, is_trainable: bool, stage: str):
    r"""
    Wraps ``model`` with a value head for the RM/PPO stages and loads the related weights.

    Raises:
        AssertionError: for PPO at evaluation, or when the reward model is missing or not loaded.
    """
    model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

    if stage == "rm" and model_args.checkpoint_dir is not None:  # load valuehead weights to evaluate reward model
        logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
        if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
            model.v_head.load_state_dict({
                "summary.weight": getattr(model, "reward_head_weight"),
                "summary.bias": getattr(model, "reward_head_bias")
            })

    if stage == "ppo":  # load reward model
        assert is_trainable, "PPO stage cannot be performed at evaluation."
        assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
        logger.info("Load reward model from {}".format(model_args.reward_model))
        model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
        assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    return model


def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.

    Supports both training and inference.

    Args:
        model_args: model location, checkpoints, quantization and version (``use_v2``) settings.
        finetuning_args: fine-tuning method and its settings. At evaluation without a
            checkpoint they are replaced by ``finetuning_type="none"``.
        is_trainable: prepare the model for training (True) or inference (False).
        stage: ``"sft"``, or ``"rm"``/``"ppo"``, which add a value head.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: for an unsupported quantization setting.
        AssertionError: for invalid stage/method combinations or mismatched ChatGLM versions.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with LoRA method."

    quantization = _get_quantization_method(model_args, finetuning_args, is_trainable)

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        **config_kwargs
    )

    # P-Tuning v2 configurations. Use the built-in p-tuning method of ChatGLM.
    if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len  # enable this will fix other parameters automatically
        config.prefix_projection = finetuning_args.prefix_projection

    # Quantization configurations for Freeze and LoRA in training (using bitsandbytes library).
    # Added only after the tokenizer/config are loaded, so they only reach the model loading call.
    if quantization == "bnb":
        _add_bnb_kwargs(config_kwargs, model_args)

    # A full-parameter checkpoint is a complete model, so it replaces the base model path.
    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs)

    # Register auto class to save the custom code files.
    _register_auto_classes(config, tokenizer, model)

    if model_args.use_v2:
        assert tokenizer.eos_token_id is not None, "Please update the *.json and *.py files of ChatGLM2-6B from HuggingFace."
        model.lm_head = model.transformer.output_layer
        output_embedding_base_layer = model.transformer
        output_embedding_layer_name = "output_layer"
    else:
        assert tokenizer.eos_token_id == 130005, "Please specify `use_v2` argument while using ChatGLM2-6B."
        output_embedding_base_layer = model
        output_embedding_layer_name = "lm_head"

    # Initialize adapters
    if is_trainable:
        model = prepare_model_for_training(
            model,
            finetuning_args.finetuning_type,
            output_embedding_base_layer,
            output_embedding_layer_name
        )
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False)  # fix all model params
        model = model.half()  # cast all params to float16 for inference

    # Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in quantized P-Tuning setting.
    if quantization == "cpm":
        if is_trainable:  # convert all params into half precision except prefix_encoder in training
            for name, param in model.named_parameters():
                if "prefix_encoder" not in name:
                    param.data = param.data.to(torch.float16)

        model.quantize(model_args.quantization_bit)  # built-in method in ChatGLM-6B, also an in-place operation

    if quantization is not None:
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    if stage == "rm" or stage == "ppo":  # add value head
        model = _add_value_head(model, model_args, is_trainable, stage)

    print_trainable_params(model)

    return model, tokenizer
-
-
def prepare_args(
    stage: Literal["sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
    r"""
    Parses, validates and post-processes the arguments of a training/evaluation run.

    Arguments come from a ``*.json`` file when it is the only command-line argument,
    otherwise from the command line. The function also configures ``datasets`` and
    ``transformers`` logging and calls ``transformers.set_seed`` with ``training_args.seed``.

    Args:
        stage: the training stage; only ``"sft"`` may enable ``predict_with_generate``.

    Returns:
        ``(model_args, data_args, training_args, finetuning_args)``.

    Raises:
        AssertionError: when incompatible arguments are combined.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    from_json_file = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    if from_json_file:  # Provide arguments with a json file.
        parsed = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, data_args, training_args, finetuning_args = parsed

    # Logging: training_args.log_level defaults to "passive", so raise verbosity to info first to
    # make info the effective default, then apply the per-process level.
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    hf_logging = transformers.utils.logging
    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    hf_logging.set_verbosity(log_level)
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()

    # Argument checks (finetuning_args is not checked since it may be loaded from checkpoints).
    assert not (stage != "sft" and training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."

    assert not training_args.do_train or not training_args.predict_with_generate, \
        "`predict_with_generate` cannot be set as True while training."

    assert training_args.predict_with_generate or not training_args.do_predict, \
        "Please enable `predict_with_generate` to save model predictions."

    if model_args.quantization_bit is not None:
        assert finetuning_args.finetuning_type not in ("full", "freeze"), \
            "Quantization is incompatible with the full-parameter and freeze tuning."

        assert finetuning_args.finetuning_type != "p_tuning" or not training_args.fp16, \
            "FP16 training conflicts with quantized P-Tuning."

        if not training_args.do_train:
            logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
        or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."

    if training_args.do_train and not training_args.fp16:
        logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    # Swap `adamw_hf` for `adamw_torch` to suppress the warning the former emits.
    if training_args.optim == "adamw_hf":
        optimizer_name = "adamw_torch"
    else:
        optimizer_name = training_args.optim
    training_args.optim = optimizer_name

    # Compute dtype used by quantized layers follows the mixed-precision mode.
    if model_args.quantization_bit is not None:
        if training_args.fp16:
            compute_dtype = torch.float16
        elif training_args.bf16:
            compute_dtype = torch.bfloat16
        else:
            compute_dtype = torch.float32
        model_args.compute_dtype = compute_dtype

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args
-
-
def prepare_infer_args() -> Tuple[ModelArguments, FinetuningArguments, GeneratingArguments]:
    r"""
    Parses the arguments of an inference run.

    Arguments come from a ``*.json`` file when it is the only command-line argument,
    otherwise from the command line.

    Returns:
        ``(model_args, finetuning_args, generating_args)``.

    Raises:
        AssertionError: if several checkpoints are given for a non-LoRA method.
    """
    parser = HfArgumentParser((ModelArguments, FinetuningArguments, GeneratingArguments))

    from_json_file = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    if from_json_file:  # Provide arguments with a json file.
        parsed = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, finetuning_args, generating_args = parsed

    assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
        or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."

    return model_args, finetuning_args, generating_args
-
-
def _verify_sha1(file_path: str, expected_sha1: str, chunk_size: int = 1 << 20) -> None:
    r"""
    Logs a warning (does not raise) when the SHA-1 of ``file_path`` differs from ``expected_sha1``.

    The file is hashed in ``chunk_size`` pieces, so large datasets are never read into memory at once.
    """
    sha1 = hashlib.sha1()
    with open(file_path, "rb") as datafile:
        for chunk in iter(lambda: datafile.read(chunk_size), b""):
            sha1.update(chunk)
    if sha1.hexdigest() != expected_sha1:
        logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))


def _collect_local_files(local_path: str) -> Tuple[str, List[str]]:
    r"""
    Lists the data file(s) at ``local_path`` and infers the ``load_dataset`` builder from their extension.

    ``local_path`` may be a single file or a directory. Directory entries are sorted, so the file
    order (and therefore which rows ``max_samples`` keeps) does not depend on the file system.

    Returns:
        ``(builder_name, data_files)``, where ``builder_name`` is ``"csv"`` or ``"json"``.

    Raises:
        ValueError: if ``local_path`` does not exist.
        AssertionError: if an extension is unsupported, or the files do not share one type.
    """
    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json"
    }

    if os.path.isdir(local_path):
        data_files = [os.path.join(local_path, file_name) for file_name in sorted(os.listdir(local_path))]
    elif os.path.isfile(local_path):
        data_files = [local_path]
    else:
        raise ValueError("File not found: {}".format(local_path))

    # Every file must have a supported extension, and all of them must map to the same builder.
    file_types = [ext2type.get(file_path.split(".")[-1], None) for file_path in data_files]
    assert all(file_types), "File extension must be csv, json or jsonl."
    assert len(set(file_types)) <= 1, "file type does not match."

    data_path = file_types[0] if file_types else None
    assert data_path, "File extension must be csv, json or jsonl."  # also rejects an empty directory
    return data_path, data_files


def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:
    r"""
    Loads every dataset in ``data_args.dataset_list`` and normalizes their columns.

    Each dataset comes from the HF hub, a local loading script, or local csv/json/jsonl file(s),
    depending on its ``load_from``. Its ``data_args.split`` is optionally truncated to
    ``data_args.max_samples`` rows. Configured columns are then renamed to
    ``prompt``/``query``/``response``/``history``, and an all-None column is added for any
    column configured as None or empty.

    Returns:
        The single dataset, or the concatenation of all datasets.

    Raises:
        NotImplementedError: for an unknown ``load_from``.
        ValueError / AssertionError: for missing or unsupported local files.
    """
    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = []  # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path, data_files = _collect_local_files(
                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            )

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                _verify_sha1(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            dataset = dataset.select(range(min(len(dataset), max_samples)))

        dummy_data = [None] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]:  # every dataset will have 4 columns same as each other
            source_name = getattr(dataset_attr, column_name)
            if source_name == target_name:
                continue
            if source_name:
                dataset = dataset.rename_column(source_name, target_name)
            else:  # None or empty string
                dataset = dataset.add_column(target_name, dummy_data)
        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        return all_datasets[0]
    return concatenate_datasets(all_datasets)
-
-
def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Literal["sft", "rm", "ppo"]
) -> Dataset:
    r"""
    Tokenizes ``dataset`` into the format required by ``stage``.

    Rows whose ``prompt`` or ``response`` is empty are skipped. Prompts are rendered in the
    ``[Round N]`` chat template, with ``data_args.source_prefix`` prepended.

    Per stage, the output columns are:
      - "sft": ``input_ids``/``labels`` for training. When ``predict_with_generate`` is set,
        the generation-style evaluation format is used instead.
      - "rm": ``accept_ids``/``reject_ids``. Here ``response`` is expected to be a pair
        ``(accepted, rejected)``.
      - "ppo": the generation-style evaluation format.

    Raises:
        ValueError: for an unknown ``stage``, or when no valid sample remains after preprocessing.
    """
    column_names = list(dataset.column_names)
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    def format_example(examples):  # support question with a single answer or multiple answers
        for instruction, answer, query, history in zip(
            examples["prompt"], examples["response"], examples["query"], examples["history"]
        ):
            if not (instruction and answer):
                continue  # skip incomplete samples
            if query:
                instruction = instruction + query
            history = history if history else []
            prompt = "".join(
                "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(j + 1, old_query, response)
                for j, (old_query, response) in enumerate(history)
            )
            prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, instruction)
            yield prefix + prompt, answer

    def preprocess_supervised_dataset(examples):
        # v1: build inputs with format `X [gMASK] <sop> Y <eop>` and labels with format `[IGNORE] ... [IGNORE] Y <eop>`
        # v2: build inputs with format `[gMASK] sop X Y </s>` and labels with format `[IGNORE] ... [IGNORE] Y </s>`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            # Truncate so the special tokens still fit within the configured lengths.
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)[:data_args.max_source_length - 2]  # gmask and sop tokens
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)[:data_args.max_target_length - 1]  # eos token

            context_length = len(source_ids) + 2  # gmask and sop tokens
            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)
            labels = [IGNORE_INDEX] * context_length + input_ids[context_length:]  # only the answer is supervised

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs

    def preprocess_evaluation_dataset(examples):
        # v1: build inputs with format `X [gMASK] <sop>` and labels with format `Y [gMASK] <sop>`
        # v2: build inputs with format `[gMASK] sop X` and labels with format `[gMASK] sop Y`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)[:data_args.max_source_length - 2]  # gmask and sop tokens
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)[:data_args.max_target_length - 2]  # gmask and sop tokens

            model_inputs["input_ids"].append(tokenizer.build_inputs_with_special_tokens(source_ids))
            model_inputs["labels"].append(tokenizer.build_inputs_with_special_tokens(target_ids))
        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # v1: build input pairs with format `X [gMASK] <sop> Y1 <eop>` and `X [gMASK] <sop> Y2 <eop>`
        # v2: build input pairs with format `[gMASK] sop X Y1 </s>` and `[gMASK] sop X Y2 </s>`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)[:data_args.max_source_length - 2]  # gmask and sop tokens
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)[:data_args.max_target_length - 1]  # eos token
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)[:data_args.max_target_length - 1]  # eos token

            # Pass a fresh copy of the source ids to each call, so the two sequences never share one list.
            accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids)
            reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_sft_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_ppo_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "sft":
        preprocess_function = preprocess_evaluation_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
        print_function = print_sft_dataset_example
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
        print_function = print_pairwise_dataset_example
    elif stage == "ppo":
        preprocess_function = preprocess_evaluation_dataset
        print_function = print_ppo_dataset_example
    else:
        # Previously this fell through to an UnboundLocalError on `preprocess_function`.
        raise ValueError("Unknown stage: {}.".format(stage))

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

    # format_example drops incomplete rows, so the result may be empty; fail with a clear message
    # instead of an IndexError from the debug print below.
    if len(dataset) == 0:
        raise ValueError("Cannot find valid samples, please check the format of the dataset.")

    print_function(dataset[0])

    return dataset
这段代码用于加载 ChatGLM 模型与分词器、初始化适配器(Full、Freeze、P-Tuning v2 或 LoRA)、解析训练/推理参数,以及加载和预处理数据集。下面对导入部分和 init_adapter 函数逐行进行解释:
from transformers.utils import check_min_version
:从 transformers 库导入 check_min_version
函数,用于检查 transformers 库的版本是否满足最低要求。
from transformers.utils.versions import require_version
:从 transformers 库导入 require_version
函数,用于检查特定库的版本是否满足要求。
from transformers.modeling_utils import PreTrainedModel
:从 transformers 库导入 PreTrainedModel
类,该类是所有预训练模型的基类。
from transformers.tokenization_utils import PreTrainedTokenizer
:从 transformers 库导入 PreTrainedTokenizer
类,该类是所有慢速(纯 Python 实现)分词器的基类。
import datasets
:导入 Hugging Face 的 datasets 库,用于加载和处理数据集;代码中还用它来设置日志级别。
from datasets import Dataset, concatenate_datasets, load_dataset
:从 datasets 库导入 Dataset
类、concatenate_datasets
函数和 load_dataset
函数。这些都用于处理数据集。
from peft import (PeftModel, TaskType, LoraConfig, get_peft_model)
:从 peft 库导入 PeftModel、TaskType、LoraConfig 和 get_peft_model。PEFT(Parameter-Efficient Fine-Tuning)是 Hugging Face 提供的参数高效微调库,LoRA 是其支持的方法之一。
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
:从 peft 库的 utils 模块导入 CONFIG_NAME
和 WEIGHTS_NAME
,这两个常量分别是 PEFT 适配器配置文件和权重文件的文件名;代码用它们检查检查点目录是否为 LoRA 检查点。
from trl import AutoModelForCausalLMWithValueHead
:从 trl 库导入 AutoModelForCausalLMWithValueHead
类,用于在因果语言模型之上添加一个值头(value head);代码在奖励模型(RM)和 PPO 阶段使用它。
from .config import (ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments)
:从本地的 config 模块导入四个参数类,分别定义模型、数据、微调和生成相关的参数。
from .other import (get_logger, load_trainable_params, load_valuehead_params, print_trainable_params, prepare_model_for_training, IGNORE_INDEX)
:从本地的 other 模块导入多个函数和常量,这些函数和常量用于日志记录、加载参数、打印参数、准备训练模型等。
check_min_version("4.27.4")
:检查 transformers 库的版本是否至少为 4.27.4。
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
:检查 datasets 库的版本是否至少为 2.10.0,如果不满足,则提示如何修复。
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
:检查 accelerate 库的版本是否至少为 0.19.0,如果不满足,则提示如何修复。
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
:检查 peft 库的版本是否至少为 0.3.0,如果不满足,则提示如何修复。
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
:检查 trl 库的版本是否至少为 0.4.4,如果不满足,则提示如何修复。
logger = get_logger(__name__)
:通过 get_logger 函数获取一个日志记录器,__name__ 是当前模块的名称。
接下来是 init_adapter
函数的定义。它根据微调类型设置可训练参数,并在需要时加载检查点或适配器权重,本身并不执行训练。输入参数包括预训练模型、模型参数、微调参数和一个指示模型是否可训练的布尔值,返回值是配置好适配器的模型。函数的主要步骤包括:
def init_adapter(model: PreTrainedModel, model_args: ModelArguments, finetuning_args: FinetuningArguments, is_trainable: bool) -> PreTrainedModel
:定义函数 init_adapter
,输入是模型、模型参数、微调参数和一个布尔值,输出是微调后的模型。
通过 finetuning_args.finetuning_type
判断微调的类型,并根据微调的类型进行不同的处理:如果微调类型为 "full",则将模型的所有参数转换为 float32;如果微调类型为 "freeze",则冻结名称不匹配 trainable_layers 的参数;如果微调类型为 "p_tuning",则从检查点加载可训练参数;如果微调类型为 "lora",则合并或加载 LoRA 权重,必要时创建新的 LoRA 权重。
在微调过程中,根据 model_args.checkpoint_dir
检查并加载模型检查点。
如果微调类型为 "lora",则进一步创建或加载 LoRA 权重。
最后,如果存在模型检查点,那么在日志中记录已加载的模型检查点信息,并返回微调后的模型。
注意:本段代码是用于模型微调的一部分,其中包含了许多特定的函数和类,如 PeftModel
、get_peft_model
、LoraConfig
、load_trainable_params
等,对这些函数和类的具体功能和实现,需要结合相关的库和代码进行理解。
init_adapter
函数:18-24. if finetuning_args.finetuning_type == "none" and is_trainable: raise ValueError("You cannot use finetuning_type=none while training.")
:如果微调类型为 "none" 并且模型是可训练的,那么抛出 ValueError 异常。这是因为在可训练的情况下,不应设置微调类型为 "none"。
25-29. if finetuning_args.finetuning_type == "full": logger.info("Fine-tuning method: Full") model = model.float()
:如果微调类型为 "full",则在日志中记录微调方法为 "Full",并将模型的所有参数转换为浮点类型。
30-39. if finetuning_args.finetuning_type == "freeze": ...
:如果微调类型为 "freeze",则在日志中记录微调方法为 "Freeze",然后遍历模型的所有参数:如果参数名不包含 finetuning_args.trainable_layers
中的任何一个子串,那么将参数的 requires_grad
属性设置为 False,即冻结参数;否则,将参数转换为 float32。然后,如果 model_args.checkpoint_dir
不为空,那么从第一个检查点目录加载模型的可训练参数。
40-44. if finetuning_args.finetuning_type == "p_tuning": ...
:如果微调类型为 "p_tuning",则在日志中记录微调方法为 "P-Tuning v2"。然后,如果 model_args.checkpoint_dir
不为空,那么从检查点目录加载模型的可训练参数。
45-76. if finetuning_args.finetuning_type == "lora": ...
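:如果微调类型为 "lora",则在日志中记录微调方法为 "LoRA",然后依次执行以下步骤:
- 检查检查点目录是否包含 LoRA 权重和配置文件。
- 将需要合并的检查点依次加载,并通过 merge_and_unload 合并进基础模型。
- 若处于训练且设置了 resume_lora_training,则最后一个检查点不合并,而是以可训练方式加载以继续训练。
- 若处于训练且没有可继续训练的检查点,则根据 lora_rank、lora_alpha、lora_dropout 和 lora_target 创建新的 LoRA 权重。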
77-78. if model_args.checkpoint_dir is not None: logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
:如果 model_args.checkpoint_dir
不为空,那么在日志中记录已经从检查点加载了微调后的模型。
return model
:返回经过微调的模型。总结:此函数主要用于初始化适配器并进行模型的微调。首先根据微调的类型("full"、"freeze"、"p_tuning" 或 "lora")对模型进行不同的处理,然后如果提供了模型检查点,那么从检查点中加载模型的可训练参数或者 LoRA 权重。最后,返回经过微调的模型。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。