
Source-code walkthrough: ChatGLM Efficient Tuning, utils/common.py

```python
import os
import sys
import torch
import hashlib
from types import MethodType
from typing import List, Literal, Optional, Tuple

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

import datasets
from datasets import Dataset, concatenate_datasets, load_dataset

from peft import (
    PeftModel,
    TaskType,
    LoraConfig,
    get_peft_model
)
from peft.utils import CONFIG_NAME, WEIGHTS_NAME

from trl import AutoModelForCausalLMWithValueHead

from .config import (
    ModelArguments,
    DataTrainingArguments,
    FinetuningArguments,
    GeneratingArguments
)
from .other import (
    get_logger,
    load_trainable_params,
    load_valuehead_params,
    print_trainable_params,
    prepare_model_for_training,
    IGNORE_INDEX
)

check_min_version("4.27.4")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")

logger = get_logger(__name__)


def init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Support full-parameter, freeze, P-Tuning v2 and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """

    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")

        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "p_tuning":
        logger.info("Fine-tuning method: P-Tuning v2")

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        lastest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/p_tuning/freeze` instead."

            if is_trainable and model_args.resume_lora_training: # continually train on the lora weights
                checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if lastest_checkpoint is not None: # resume lora training
                model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=True)

        if is_trainable and lastest_checkpoint is None: # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, # we should regard ChatGLM as a causal LM
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model


def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.

    Support both training and inference.
    """

    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with LoRA method."

    quantization = None
    if model_args.quantization_bit is not None:
        if is_trainable:
            if finetuning_args.finetuning_type == "full":
                raise ValueError("Full-parameter fine-tuning does not support quantization.")
            elif finetuning_args.finetuning_type == "p_tuning":
                quantization = "cpm" # use cpm's quantization
            else:
                quantization = "bnb" # use bnb's quantization
        else:
            quantization = "cpm"

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        **config_kwargs
    )

    # P-Tuning v2 configurations. Use the built-in p-tuning method of ChatGLM.
    if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len # enable this will fix other parameters automatically
        config.prefix_projection = finetuning_args.prefix_projection

    # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
    if quantization == "bnb":
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs)

    # Register auto class to save the custom code files.
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModel" in config.auto_map:
        model.__class__.register_for_auto_class()

    if model_args.use_v2:
        assert tokenizer.eos_token_id is not None, "Please update the *.json and *.py files of ChatGLM2-6B from HuggingFace."
        model.lm_head = model.transformer.output_layer
        output_embedding_base_layer = model.transformer
        output_embedding_layer_name = "output_layer"
    else:
        assert tokenizer.eos_token_id == 130005, "Please specify `use_v2` argument while using ChatGLM2-6B."
        output_embedding_base_layer = model
        output_embedding_layer_name = "lm_head"

    # Initialize adapters
    model = prepare_model_for_training(
        model,
        finetuning_args.finetuning_type,
        output_embedding_base_layer,
        output_embedding_layer_name
    ) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        model = model.half() # cast all params to float16 for inference

    # Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in quantized P-Tuning setting.
    if quantization == "cpm":
        if is_trainable: # convert all params into half precision except prefix_encoder in training
            for name, param in model.named_parameters():
                if "prefix_encoder" not in name:
                    param.data = param.data.to(torch.float16)

        model.quantize(model_args.quantization_bit) # built-in method in ChatGLM-6B, also an in-place operation

    if quantization is not None:
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    if stage == "rm" or stage == "ppo": # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo": # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    print_trainable_params(model)

    return model, tokenizer


def prepare_args(
    stage: Literal["sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    assert stage == "sft" or (not training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."

    assert not (training_args.do_train and training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True while training."

    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

    if model_args.quantization_bit is not None:
        assert finetuning_args.finetuning_type != "full" and finetuning_args.finetuning_type != "freeze", \
            "Quantization is incompatible with the full-parameter and freeze tuning."

        assert not (finetuning_args.finetuning_type == "p_tuning" and training_args.fp16), \
            "FP16 training conflicts with quantized P-Tuning."

        if not training_args.do_train:
            logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
        or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args


def prepare_infer_args() -> Tuple[ModelArguments, FinetuningArguments, GeneratingArguments]:

    parser = HfArgumentParser((ModelArguments, FinetuningArguments, GeneratingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

    assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
        or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."

    return model_args, finetuning_args, generating_args


def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json"
    }

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert ext2type.get(data_files[-1].split(".")[-1], None) == data_path, "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets


def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Literal["sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    def format_example(examples): # support question with a single answer or multiple answers
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + examples["query"][i] if examples["query"][i] else query
                history = examples["history"][i] if examples["history"][i] else []
                prompt = ""
                for j, (old_query, response) in enumerate(history):
                    prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(j+1, old_query, response)
                prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query)
                prompt = prefix + prompt
                yield prompt, answer

    def preprocess_supervised_dataset(examples):
        # v1: build inputs with format `X [gMASK] <sop> Y <eop>` and labels with format `[IGNORE] ... [IGNORE] Y <eop>`
        # v2: build inputs with format `[gMASK] sop X Y </s>` and labels with format `[IGNORE] ... [IGNORE] Y </s>`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(target_ids) > data_args.max_target_length - 1: # eos token
                target_ids = target_ids[:data_args.max_target_length - 1]

            context_length = len(source_ids) + 2 # gmask and sop tokens
            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)
            labels = [IGNORE_INDEX] * context_length + input_ids[context_length:]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs

    def preprocess_evaluation_dataset(examples):
        # v1: build inputs with format `X [gMASK] <sop>` and labels with format `Y [gMASK] <sop>`
        # v2: build inputs with format `[gMASK] sop X` and labels with format `[gMASK] sop Y`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(target_ids) > data_args.max_target_length - 2: # gmask and sop tokens
                target_ids = target_ids[:data_args.max_target_length - 2]

            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids)
            labels = tokenizer.build_inputs_with_special_tokens(target_ids)

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # v1: build input pairs with format `X [gMASK] <sop> Y1 <eop>` and `X [gMASK] <sop> Y2 <eop>`
        # v2: build input pairs with format `[gMASK] sop X Y1 </s>` and `[gMASK] sop X Y2 </s>`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids) # avoid copying error
            reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_sft_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_ppo_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "sft":
        preprocess_function = preprocess_evaluation_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_evaluation_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

    if stage == "sft":
        print_sft_dataset_example(dataset[0])
    elif stage == "rm":
        print_pairwise_dataset_example(dataset[0])
    elif stage == "ppo":
        print_ppo_dataset_example(dataset[0])

    return dataset
```

This file initializes the model and its adapters and prepares the model for fine-tuning according to the configured arguments. A detailed explanation follows, starting with the imports and then the init_adapter function:

  1. `from transformers.utils import check_min_version`: imports the check_min_version function from the transformers library, used to verify that the installed transformers version meets a minimum requirement.
  2. `from transformers.utils.versions import require_version`: imports the require_version function, used to verify that a given package's version satisfies a requirement.
  3. `from transformers.modeling_utils import PreTrainedModel`: imports the PreTrainedModel class, the base class of all pretrained models.
  4. `from transformers.tokenization_utils import PreTrainedTokenizer`: imports the PreTrainedTokenizer class, the base class of all pretrained tokenizers.
  5. `import datasets`: imports the datasets library, which provides a large collection of public datasets and evaluation metrics.
  6. `from datasets import Dataset, concatenate_datasets, load_dataset`: imports the Dataset class and the concatenate_datasets and load_dataset functions, all used for dataset handling.
  7. `from peft import PeftModel, TaskType, LoraConfig, get_peft_model`: imports from peft, Hugging Face's Parameter-Efficient Fine-Tuning library, which implements adapter methods such as LoRA.
  8. `from peft.utils import CONFIG_NAME, WEIGHTS_NAME`: imports two constants that hold the standard file names of a PEFT adapter's configuration and weight files.
  9. `from trl import AutoModelForCausalLMWithValueHead`: imports the AutoModelForCausalLMWithValueHead class, which wraps a causal language model with a value head, as used in reinforcement-learning (RLHF) training.
  10. `from .config import (ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments)`: imports four dataclasses from the local config module that define the model, data, fine-tuning, and generation arguments.
  11. `from .other import (get_logger, load_trainable_params, load_valuehead_params, print_trainable_params, prepare_model_for_training, IGNORE_INDEX)`: imports helper functions and a constant from the local other module, used for logging, loading and printing parameters, and preparing the model for training.
  12. `check_min_version("4.27.4")`: checks that the transformers version is at least 4.27.4.
  13. `require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")`: checks that datasets is at least 2.10.0 and, if not, tells the user how to fix it.
  14. `require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")`: the same check for accelerate 0.19.0 or newer.
  15. `require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")`: the same check for peft 0.3.0 or newer.
  16. `require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")`: the same check for trl 0.4.4 or newer.
  17. `logger = get_logger(__name__)`: obtains a logger via get_logger, named after the current module (a minimal sketch of such a helper follows this list).
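
The implementation of get_logger lives in the local other module and is not shown in this article. Below is a minimal sketch of what such a helper plausibly looks like; the formatter string and the stdout handler are assumptions, not the repository's actual code:

```python
import logging
import sys

def get_logger(name: str) -> logging.Logger:
    # Hypothetical stand-in for the helper imported from .other:
    # a named logger that writes formatted records to stdout.
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    return logger
```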

Next comes the definition of init_adapter, which initializes the adapters and sets the model up for the chosen fine-tuning method. Its inputs are the pretrained model, the model arguments, the fine-tuning arguments, and a boolean indicating whether the model is trainable; it returns the prepared model. The main steps are:

  1. `def init_adapter(model: PreTrainedModel, model_args: ModelArguments, finetuning_args: FinetuningArguments, is_trainable: bool) -> PreTrainedModel`: defines the function with the signature described above.

  2. Branch on `finetuning_args.finetuning_type`. For "full", all model parameters are cast to float32; for "freeze", most layers are frozen; for "p_tuning" and "lora", parameters are loaded in method-specific ways.

  3. During fine-tuning, checkpoints are validated and loaded according to `model_args.checkpoint_dir`.

  4. For "lora", new LoRA weights are created or existing ones are loaded and merged.

  5. Finally, if checkpoints were provided, their paths are logged, and the prepared model is returned.

Note: this code relies on several functions and classes, such as PeftModel, get_peft_model, LoraConfig, and load_trainable_params; understanding them fully requires reading the corresponding libraries and modules. A self-contained example of the basic peft LoRA workflow follows.
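
Below is a minimal, runnable sketch of the peft workflow the note refers to, on a small model. The choice of gpt2 and its c_attn attention module is purely illustrative and not part of this repository; for ChatGLM the target module would typically be query_key_value:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Load a small base model (illustrative; the article's code uses ChatGLM).
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Configure LoRA: rank-8 adapters injected into the attention projection.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["c_attn"]  # gpt2's fused QKV projection
)

# Wrap the model; only the injected LoRA matrices remain trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

This mirrors the `if is_trainable and lastest_checkpoint is None` branch above, where the rank, alpha, dropout, and target modules come from finetuning_args instead of being hard-coded.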

The following goes through the init_adapter function branch by branch:

  1. `if finetuning_args.finetuning_type == "none" and is_trainable: raise ValueError("You cannot use finetuning_type=none while training.")`: if the fine-tuning type is "none" while the model is trainable, a ValueError is raised, since "none" must not be used during training.

  2. `if finetuning_args.finetuning_type == "full": ...`: for full fine-tuning, log "Fine-tuning method: Full" and cast all model parameters to float32 via `model.float()`.

  3. `if finetuning_args.finetuning_type == "freeze": ...`: for freeze tuning, log "Fine-tuning method: Freeze", then iterate over all named parameters: parameters whose names do not match any entry in `finetuning_args.trainable_layers` get requires_grad set to False (frozen), while the remaining parameters are cast to float32. If `model_args.checkpoint_dir` is set, the trainable parameters are then loaded from the first checkpoint directory.

  4. `if finetuning_args.finetuning_type == "p_tuning": ...`: for P-Tuning v2, log "Fine-tuning method: P-Tuning v2" and, if `model_args.checkpoint_dir` is set, load the trainable parameters from the first checkpoint directory.

  5. `if finetuning_args.finetuning_type == "lora": ...`: for LoRA, log "Fine-tuning method: LoRA" and then perform a series of operations: verify that the checkpoint directory actually contains a LoRA weight file and config, merge all but the latest checkpoint into the base weights via merge_and_unload, resume training on the latest checkpoint when `resume_lora_training` is set, or create fresh LoRA weights with LoraConfig and get_peft_model when training from scratch. A worked example of how the checkpoints are split between merging and resuming follows this list.

  6. `if model_args.checkpoint_dir is not None: logger.info("Loaded fine-tuned model from checkpoint(s): ...")`: if checkpoints were provided, log that the fine-tuned model was loaded from them.

  7. `return model`: return the prepared model.
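
The split between checkpoints_to_merge and lastest_checkpoint in the LoRA branch is the subtlest step, so here is a worked example with three hypothetical checkpoint directories (the directory names are made up for illustration):

```python
# Reproduces the checkpoint split from the LoRA branch of init_adapter.
checkpoint_dir = ["ckpt-1000", "ckpt-2000", "ckpt-3000"]
is_trainable, resume_lora_training = True, True

if is_trainable and resume_lora_training:
    # The first two checkpoints are merged into the base weights;
    # training resumes on the last one.
    checkpoints_to_merge, lastest_checkpoint = checkpoint_dir[:-1], checkpoint_dir[-1]
else:
    # All checkpoints are merged; if training, a fresh adapter is created.
    checkpoints_to_merge, lastest_checkpoint = checkpoint_dir, None

print(checkpoints_to_merge)  # ['ckpt-1000', 'ckpt-2000']
print(lastest_checkpoint)    # ckpt-3000
```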

In summary, init_adapter initializes the adapters and prepares the model for fine-tuning: it first handles the model according to the fine-tuning type ("full", "freeze", "p_tuning", or "lora"), then, if checkpoints are provided, loads the trainable parameters or LoRA weights from them, and finally returns the prepared model. The sketch below shows how this file's helpers fit together in a training entry script.
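
The following is a hedged sketch of a supervised fine-tuning entry point in the style of this repository; the import path utils.common, the run_sft name, and the omitted trainer wiring are assumptions, not the repository's actual script:

```python
# Hypothetical SFT entry point, assuming this file is importable as utils.common.
from utils.common import prepare_args, load_pretrained, prepare_data, preprocess_data

def run_sft():
    # Parse CLI or JSON arguments into the four dataclasses.
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")

    # Load ChatGLM and its tokenizer, applying quantization and adapters.
    model, tokenizer = load_pretrained(
        model_args, finetuning_args,
        is_trainable=training_args.do_train, stage="sft"
    )

    # Load the raw datasets and tokenize them into input_ids/labels.
    dataset = prepare_data(model_args, data_args)
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")

    # ... build a Seq2SeqTrainer over (model, dataset) and call trainer.train()

if __name__ == "__main__":
    run_sft()
```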
