
[peft] Fine-tuning a large model with PEFT: resuming from a checkpoint fails with ValueError: Can't find a valid checkpoint at


Following up on the previous post: after training bloomz with PEFT's LoRA for one epoch, the result seemed undertrained, so I tried to resume training from the checkpoint with the following code:

trainer.train(resume_from_checkpoint='path/to/checkpoint')

This raised an error:

raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
ValueError: Can't find a valid checkpoint at path/to/checkpoint

See Peft Model not resuming from Checkpoint · Issue #24252 · huggingface/transformers · GitHub

The problem is in Trainer._load_from_checkpoint: it only checks for full model weight files such as pytorch_model.bin or model.safetensors, but a PEFT checkpoint directory contains only the adapter files (adapter_model.bin / adapter_model.safetensors plus adapter_config.json), so the check fails and the ValueError above is raised.
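
You can confirm this by listing the checkpoint directory. A minimal sketch, assuming a hypothetical output path output/checkpoint-500; substitute your own directory:

import os

# Hypothetical checkpoint path; replace with your actual output directory.
ckpt = "output/checkpoint-500"
print(sorted(os.listdir(ckpt)))
# A PEFT checkpoint typically contains adapter_model.bin (or adapter_model.safetensors)
# and adapter_config.json, but no pytorch_model.bin, which is the file the stock
# Trainer._load_from_checkpoint looks for.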

The fix is to create a Trainer subclass that overrides the checkpoint-loading methods, and to build the trainer from this subclass:

from transformers import Trainer
import os
from peft import PeftModel
from transformers.utils import (
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    is_sagemaker_mp_enabled,
    is_peft_available,
    logging,
)

logger = logging.get_logger(__name__)

class PeftTrainer(Trainer):
    def _load_from_peft_checkpoint(self, resume_from_checkpoint, model):
        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
        # The checkpoint is valid if either adapter weights file exists.
        if not any(os.path.isfile(f) for f in [adapter_weights_file, adapter_safe_weights_file]):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        # Load adapters following PR #24096
        if is_peft_available() and isinstance(model, PeftModel):
            # If the model was trained with PEFT & LoRA, assume the adapter was saved properly.
            if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                if os.path.exists(resume_from_checkpoint):
                    model.load_adapter(resume_from_checkpoint, model.active_adapter)
                    # load_adapter has no return value, so fake an empty load result;
                    # modify this when appropriate.
                    from torch.nn.modules.module import _IncompatibleKeys
                    load_result = _IncompatibleKeys([], [])
                else:
                    logger.warning(
                        "The intermediate checkpoints of PEFT may not be saved correctly, "
                        f"consider using a `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                        "here are some examples https://github.com/huggingface/peft/issues/96"
                    )
            else:
                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if is_peft_available() and isinstance(model, PeftModel):
            # Try to load adapters before falling back to loading a full torch model.
            try:
                return self._load_from_peft_checkpoint(resume_from_checkpoint, model=model)
            except Exception:
                return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
        else:
            # Not a PeftModel: use the original _load_from_checkpoint.
            return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
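
With the subclass in place, resuming works the same way as before. A minimal usage sketch, where model, training_args, and train_dataset stand in for the PeftModel, TrainingArguments, and dataset from the existing training script:

# Minimal usage sketch; `model`, `training_args`, and `train_dataset` are
# assumed to come from your existing LoRA training setup.
trainer = PeftTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train(resume_from_checkpoint="path/to/checkpoint")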
