```python
def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.
    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with LoRA method."

    quantization = None
    if model_args.quantization_bit is not None:
        if is_trainable:
            if finetuning_args.finetuning_type == "full":
                raise ValueError("Full-parameter fine-tuning does not support quantization.")
            elif finetuning_args.finetuning_type == "p_tuning":
                quantization = "cpm"  # use cpm's quantization
            else:
                quantization = "bnb"  # use bnb's quantization
        else:
            quantization = "cpm"

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        **config_kwargs
    )

    # P-Tuning v2 configurations. Use the built-in p-tuning method of ChatGLM.
    if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len  # enable this will fix other parameters automatically
        config.prefix_projection = finetuning_args.prefix_projection

    # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
    if quantization == "bnb":
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs)

    # Register auto class to save the custom code files.
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModel" in config.auto_map:
        model.__class__.register_for_auto_class()

    if model_args.use_v2:
        assert tokenizer.eos_token_id is not None, "Please update the *.json and *.py files of ChatGLM2-6B from HuggingFace."
        model.lm_head = model.transformer.output_layer
        output_embedding_base_layer = model.transformer
        output_embedding_layer_name = "output_layer"
    else:
        assert tokenizer.eos_token_id == 130005, "Please specify `use_v2` argument while using ChatGLM2-6B."
        output_embedding_base_layer = model
        output_embedding_layer_name = "lm_head"

    # Initialize adapters
    model = prepare_model_for_training(
        model,
        finetuning_args.finetuning_type,
        output_embedding_base_layer,
        output_embedding_layer_name
    ) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False)  # fix all model params
        model = model.half()  # cast all params to float16 for inference

    # Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in quantized P-Tuning setting.
    if quantization == "cpm":
        if is_trainable:  # convert all params into half precision except prefix_encoder in training
            for name, param in model.named_parameters():
                if "prefix_encoder" not in name:
                    param.data = param.data.to(torch.float16)

        model.quantize(model_args.quantization_bit)  # built-in method in ChatGLM-6B, also an in-place operation

    if quantization is not None:
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    if stage == "rm" or stage == "ppo":  # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None:  # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo":  # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    print_trainable_params(model)

    return model, tokenizer
```
The `load_pretrained` function loads the pretrained model and tokenizer and supports both training and inference. Let's walk through it step by step:

Function signature and docstring: the function takes the model arguments, the fine-tuning arguments, an `is_trainable` flag, and a `stage` argument ("sft", "rm", or "ppo"), and returns a `(model, tokenizer)` tuple.

Checkpoint check: if the model is not trainable (i.e. we are evaluating or running inference) and no checkpoint directory is given, a warning is logged and the fine-tuning type is reset to "none", so the original pretrained weights are used.

Stage assertion: the `stage` argument must be "sft", or the fine-tuning type must be "lora", because RM and PPO training can only be performed with the LoRA method.

Quantization strategy: if `quantization_bit` is set, a quantization backend is chosen. During training, full-parameter fine-tuning with quantization raises an error, P-Tuning v2 uses ChatGLM's built-in "cpm" quantization, and the remaining methods use bitsandbytes ("bnb"). At inference time, "cpm" quantization is always used.

`config_kwargs`: a dictionary of shared loading options, including `trust_remote_code`, the cache directory, the model revision, and the auth token.

Tokenizer loading: `AutoTokenizer.from_pretrained` loads the pretrained tokenizer with these options and `padding_side="left"`, which decoder-only models need for batched generation.
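A minimal, self-contained sketch of this tokenizer call (the hub id is assumed here for illustration):

```python
from transformers import AutoTokenizer

# Assumed hub id; any local or remote ChatGLM checkpoint works the same way.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm2-6b",
    trust_remote_code=True,   # ChatGLM ships custom tokenizer code on the Hub
    padding_side="left",      # pad on the left so generation continues right after the prompt
)
print(tokenizer("你好", return_tensors="pt"))
```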
Config loading: `AutoConfig.from_pretrained` loads the model configuration with the same options.

P-Tuning v2: if the fine-tuning type is "p_tuning", `pre_seq_len` and `prefix_projection` are written into the config, which switches on ChatGLM's built-in P-Tuning v2 implementation (a trainable prefix encoder) and, as the code comment notes, effectively fixes the other parameters.
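A rough sketch of what these two config fields do (the model id and values are illustrative, not the project's defaults):

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
config.pre_seq_len = 128          # length of the trainable prefix (illustrative value)
config.prefix_projection = False  # True inserts an MLP projection over the prefix

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)

# With P-Tuning v2 only the prefix encoder needs gradients:
for name, param in model.named_parameters():
    param.requires_grad_("prefix_encoder" in name)
```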
bitsandbytes quantization: when the "bnb" backend is selected, the required library versions are checked and a `BitsAndBytesConfig` is added to `config_kwargs`. 8-bit loading sets `load_in_8bit` with an int8 threshold; 4-bit loading (QLoRA-style) additionally sets the compute dtype, double quantization, and the quantization type. `device_map` then pins the whole model to the local GPU rank.
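For reference, a standalone sketch of the 4-bit configuration built above (the model id is an assumption, and a CUDA device with bitsandbytes installed is required):

```python
import torch
from transformers import AutoModel, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # corresponds to model_args.compute_dtype
    bnb_4bit_use_double_quant=True,         # corresponds to model_args.double_quantization
    bnb_4bit_quant_type="nf4",              # corresponds to model_args.quantization_type
)
model = AutoModel.from_pretrained(
    "THUDM/chatglm2-6b",                    # assumed hub id
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map={"": 0},                     # pin all weights to the local GPU
)
```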
Model path selection: if a checkpoint directory is given and the fine-tuning type is "full", the first checkpoint directory is loaded; otherwise the original `model_name_or_path` is used.

Model loading: `AutoModel.from_pretrained` loads the pretrained model (without a value head).

Auto-class registration: the config, tokenizer, and model classes are registered as auto classes so that ChatGLM's custom code files are saved along with the checkpoint.

Version-specific wiring: for ChatGLM2-6B (`use_v2`), the transformer's `output_layer` is exposed as `lm_head` and recorded as the output-embedding layer; for ChatGLM-6B, the `eos_token_id` is checked (130005) and the existing `lm_head` is used.

Training preparation and adapters: if the model is trainable, `prepare_model_for_training` prepares it for training (e.g. enabling gradients where needed and handling the output-embedding layer); then `init_adapter` attaches the adapters for the chosen fine-tuning method (freeze, P-Tuning, LoRA, or full).
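`init_adapter` is project-internal; on the LoRA path it is roughly equivalent to the following peft sketch (the rank, alpha, and target module name are assumptions for illustration, not the project's exact values):

```python
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                 # LoRA rank (illustrative value)
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query_key_value"],  # ChatGLM's fused attention projection
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()       # only the LoRA matrices require gradients
```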
Inference mode: if the model is not trainable, all parameters are frozen with `requires_grad_(False)` and cast to float16.

CPM quantization: for quantized P-Tuning v2 training or evaluation, all parameters except the `prefix_encoder` are first cast to float16 during training, then `model.quantize()` (ChatGLM's built-in, in-place method) quantizes the weights to the requested bit width.

Logging: the chosen quantization bit width is logged.

RM / PPO stages: for the "rm" and "ppo" stages, the model is wrapped with `AutoModelForCausalLMWithValueHead` from TRL to add a value head. In the RM stage, if checkpoint directories are given, the value-head weights from the last checkpoint are loaded so the reward model can be evaluated. In the PPO stage, the function requires training mode and a reward model; the reward model is attached as a LoRA adapter named "reward" and its value-head parameters are loaded.
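A small sketch of what the value-head wrapping adds (attribute names follow the snippet above; the exact module layout depends on the installed trl version):

```python
from trl import AutoModelForCausalLMWithValueHead

# Wrap an already-loaded causal LM; `model` is the variable from the code above.
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

# The value head is a small linear layer mapping each hidden state to a scalar,
# and its state dict holds exactly the two tensors restored in the RM branch.
print(model.v_head)
print(list(model.v_head.state_dict().keys()))  # ['summary.weight', 'summary.bias']
```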
Finally, `print_trainable_params` reports how many parameters are trainable, and the model and tokenizer are returned.

Overall, this function loads the pretrained model and tokenizer and configures them for the task at hand: training versus inference, the chosen fine-tuning method, and the quantization scheme.
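A hypothetical usage sketch (the argument-class field names are taken from the snippet above and may differ slightly in the actual project):

```python
# Training path: a trainable model with LoRA adapters attached.
model_args = ModelArguments(model_name_or_path="THUDM/chatglm2-6b")
finetuning_args = FinetuningArguments(finetuning_type="lora")
model, tokenizer = load_pretrained(model_args, finetuning_args, is_trainable=True, stage="sft")

# Inference path: a frozen float16 model; without checkpoint_dir the original
# pretrained weights are used, as the warning at the top of the function notes.
model, tokenizer = load_pretrained(model_args, FinetuningArguments(), is_trainable=False, stage="sft")
```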