
ChatGLM Efficient Tuning Source Code Walkthrough, src/utils/common.py (Part 2): load_pretrained and AutoModelForCausalLMWithValueHead

def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.
    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with LoRA method."

    quantization = None
    if model_args.quantization_bit is not None:
        if is_trainable:
            if finetuning_args.finetuning_type == "full":
                raise ValueError("Full-parameter fine-tuning does not support quantization.")
            elif finetuning_args.finetuning_type == "p_tuning":
                quantization = "cpm"  # use cpm's quantization
            else:
                quantization = "bnb"  # use bnb's quantization
        else:
            quantization = "cpm"

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        **config_kwargs
    )

    # P-Tuning v2 configurations. Use the built-in p-tuning method of ChatGLM.
    if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len  # enable this will fix other parameters automatically
        config.prefix_projection = finetuning_args.prefix_projection

    # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
    if quantization == "bnb":
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs)

    # Register auto class to save the custom code files.
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModel" in config.auto_map:
        model.__class__.register_for_auto_class()

    if model_args.use_v2:
        assert tokenizer.eos_token_id is not None, "Please update the *.json and *.py files of ChatGLM2-6B from HuggingFace."
        model.lm_head = model.transformer.output_layer
        output_embedding_base_layer = model.transformer
        output_embedding_layer_name = "output_layer"
    else:
        assert tokenizer.eos_token_id == 130005, "Please specify `use_v2` argument while using ChatGLM2-6B."
        output_embedding_base_layer = model
        output_embedding_layer_name = "lm_head"

    # Initialize adapters
    model = prepare_model_for_training(
        model,
        finetuning_args.finetuning_type,
        output_embedding_base_layer,
        output_embedding_layer_name
    ) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False)  # fix all model params
        model = model.half()  # cast all params to float16 for inference

    # Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in quantized P-Tuning setting.
    if quantization == "cpm":
        if is_trainable:  # convert all params into half precision except prefix_encoder in training
            for name, param in model.named_parameters():
                if "prefix_encoder" not in name:
                    param.data = param.data.to(torch.float16)

        model.quantize(model_args.quantization_bit)  # built-in method in ChatGLM-6B, also an in-place operation

    if quantization is not None:
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    if stage == "rm" or stage == "ppo":  # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None:  # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo":  # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    print_trainable_params(model)

    return model, tokenizer

The load_pretrained function loads the pretrained model and tokenizer and supports both training and inference. Let's walk through it block by block:

Signature and docstring. The function takes the model arguments, the fine-tuning arguments, an is_trainable flag, and the training stage ("sft", "rm" or "ppo"), and returns a (model, tokenizer) tuple.

Checkpoint check. If the model is not trainable and no checkpoint directory is given, a warning is logged and the fine-tuning type is reset to "none", so the original pretrained weights are evaluated as-is.

Stage assertion. Either stage must be "sft" or the fine-tuning type must be "lora", because RM and PPO training are only supported with the LoRA method.

Quantization strategy. If a quantization bit width is set, a backend is chosen: full-parameter fine-tuning does not support quantization and raises an error; P-Tuning during training uses ChatGLM's built-in "cpm" quantization; the remaining trainable setups use "bnb" (bitsandbytes); a non-trainable model always uses "cpm".

Shared keyword arguments. config_kwargs collects the options passed to every from_pretrained call: trust_remote_code, the cache directory, the model revision, and the auth token.

Tokenizer. AutoTokenizer.from_pretrained loads the tokenizer from the tokenizer name if one is given, otherwise from the model path, with padding_side="left", which decoder-only models need for batched generation. A standalone sketch follows.
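
As a hedged illustration outside the repository code, the tokenizer call boils down to roughly the following; the checkpoint name "THUDM/chatglm-6b" is an assumed example, not something fixed by this function:

from transformers import AutoTokenizer

# trust_remote_code=True is needed because ChatGLM ships its own tokenizer code.
# padding_side="left" matters for decoder-only models: generation continues from
# the last real token, so padding has to sit on the left side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b",      # assumed example checkpoint
    trust_remote_code=True,
    use_fast=False,          # the original ChatGLM tokenizer is a slow (Python) tokenizer
    padding_side="left",
)
print(tokenizer.eos_token_id)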

Config. AutoConfig.from_pretrained loads the model configuration with the same keyword arguments, from the config name if given, otherwise from the model path.

P-Tuning v2. If the fine-tuning type is "p_tuning", the config's pre_seq_len and prefix_projection fields are set, which activates ChatGLM's built-in P-Tuning v2 prefix encoder and implicitly fixes the other parameters. A small configuration sketch follows.
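
A hedged sketch of this configuration step in isolation; the concrete values (128 virtual tokens, no projection) are assumptions for illustration:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)  # assumed checkpoint
config.pre_seq_len = 128           # number of trainable prefix (virtual) tokens
config.prefix_projection = False   # True would add an MLP that re-projects the prefix embeddings
# When a ChatGLM model is built from this config, its modeling code creates a
# prefix encoder, and P-Tuning v2 trains only those prefix parameters.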

bitsandbytes quantization. If the "bnb" backend was chosen, the required library versions are checked and config_kwargs is filled with load_in_8bit or load_in_4bit plus a matching BitsAndBytesConfig; the 4-bit path additionally requires newer transformers, accelerate and peft. device_map is pinned to the local rank so each process places the quantized model on its own GPU. A standalone 4-bit example is sketched below.
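
For reference, here is a minimal, hedged sketch of a 4-bit bitsandbytes load outside this repository; the quantization choices (nf4, bfloat16 compute dtype, double quantization) mirror common defaults and are assumptions, not values dictated by load_pretrained:

import torch
from transformers import AutoModel, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,   # dtype used for matmuls on top of the 4-bit weights
    bnb_4bit_use_double_quant=True,          # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",               # normal-float 4-bit, usually better than plain fp4
)

model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",                      # assumed example checkpoint
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map={"": 0},                      # place the whole model on GPU 0
)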

Model path. If a checkpoint directory is given and the fine-tuning type is "full", the first checkpoint directory is loaded; otherwise the original model path is used.

Model. AutoModel.from_pretrained loads the pretrained ChatGLM model (without a value head), passing the config and all collected config_kwargs.

Auto-class registration. If the config's auto_map declares custom AutoConfig / AutoTokenizer / AutoModel classes, they are registered with register_for_auto_class() so that save_pretrained later copies the custom code files alongside the weights. The sketch below shows why that matters.
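
A hedged illustration of the effect; the output directory is an arbitrary example:

from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)       # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# ChatGLM's classes live in remote code files (configuration_chatglm.py,
# tokenization_chatglm.py, modeling_chatglm.py). Registering them as auto classes
# lets save_pretrained copy those .py files next to the weights, so the saved
# checkpoint can later be re-loaded on its own with trust_remote_code=True.
if hasattr(config, "auto_map") and "AutoModel" in config.auto_map:
    model.__class__.register_for_auto_class()

model.save_pretrained("./my-finetuned-chatglm")       # example output path
tokenizer.save_pretrained("./my-finetuned-chatglm")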

Version handling. With use_v2 (ChatGLM2-6B), the output projection lives at model.transformer.output_layer, so it is also exposed as model.lm_head, and model.transformer with the name "output_layer" is recorded as the output embedding location; otherwise the code asserts that eos_token_id is 130005 (ChatGLM-6B) and records the model itself with "lm_head". The two asserts catch mismatches between the model files and the use_v2 argument.

Adapter initialization. If the model is trainable, prepare_model_for_training is applied first (it receives the output embedding layer located above so that layer can be handled specially during training); then init_adapter attaches the selected fine-tuning method (none/freeze, p_tuning, LoRA or full). A rough sketch of the LoRA case is shown below.
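
init_adapter is defined elsewhere in src/utils; as a hedged sketch, its LoRA branch conceptually does something like the following with the peft library (the rank, alpha and target module name are illustrative assumptions, not the repository's exact defaults):

from transformers import AutoModel
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)   # assumed base model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                   # LoRA rank (assumed example value)
    lora_alpha=32,                         # LoRA scaling factor (assumed example value)
    lora_dropout=0.1,
    target_modules=["query_key_value"],    # ChatGLM packs Q/K/V into one linear layer
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()         # only the LoRA matrices remain trainable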

Inference mode. If the model is not trainable, all parameters are frozen with requires_grad_(False) and cast to float16 for inference.

Built-in quantization. With the "cpm" backend, training first casts every parameter except the prefix_encoder to float16, and then model.quantize(bits) applies ChatGLM's built-in in-place quantization; this path covers P-Tuning v2 training as well as quantized evaluation.

Logging. If any quantization was applied, the bit width is logged.

RM / PPO stages. For the "rm" and "ppo" stages the model is wrapped with trl's AutoModelForCausalLMWithValueHead, which adds a scalar value head on top of the language model. In the RM stage with a checkpoint directory, the value-head weights (reward_head_weight / reward_head_bias) from the last checkpoint are loaded into model.v_head to evaluate the reward model. In the PPO stage, the code asserts that training is enabled and that a reward model path is given, loads the reward model as a frozen LoRA adapter named "reward", and checks that its value-head parameters load correctly. A standalone sketch of the value-head wrapper follows.
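
A hedged, self-contained sketch of the trl value-head wrapper; gpt2 is used only because it is small, and the printed shapes depend on the model:

import torch
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

base = AutoModelForCausalLM.from_pretrained("gpt2")          # assumed small example model
model = AutoModelForCausalLMWithValueHead.from_pretrained(base)

# v_head is a small linear layer that maps hidden states to one scalar per token;
# RM training fits it as a reward score and PPO uses it as the critic value.
print(model.v_head)

input_ids = torch.tensor([[15496, 11, 995]])                 # example token ids
logits, _, values = model(input_ids)
print(logits.shape, values.shape)                            # (1, 3, vocab_size) and (1, 3)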

Finally, print_trainable_params reports how many parameters are trainable versus the total, and the function returns the (model, tokenizer) pair. A minimal sketch of such a helper is shown after this paragraph.
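
The real helper lives elsewhere in the repository; a minimal, hedged sketch of what such a function typically computes:

def print_trainable_params(model) -> None:
    # Count trainable vs. total parameters to verify the fine-tuning setup
    # (for example, LoRA should leave only a tiny fraction trainable).
    trainable, total = 0, 0
    for param in model.parameters():
        num = param.numel()
        total += num
        if param.requires_grad:
            trainable += num
    print("trainable params: {} || all params: {} || trainable%: {:.4f}".format(
        trainable, total, 100 * trainable / total))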

In short, load_pretrained loads the pretrained model and tokenizer and configures them for the use case at hand: training versus inference, which fine-tuning method is used, which quantization backend to apply, and which stage of the RLHF pipeline (SFT, RM or PPO) is being run.
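
To make the calling convention concrete, here is a hedged usage sketch; ModelArguments and FinetuningArguments are the dataclasses referenced above, and every field value below is purely illustrative:

# Hypothetical call site inside this repository (field values are illustrative only).
model_args = ModelArguments(model_name_or_path="THUDM/chatglm-6b")
finetuning_args = FinetuningArguments(finetuning_type="lora")

model, tokenizer = load_pretrained(
    model_args,
    finetuning_args,
    is_trainable=True,   # prepare the model for training
    stage="sft",         # supervised fine-tuning stage
)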
