当前位置:   article > 正文

华为mindspore-如何训练一个gpt一样的文本生成模型_mindformers 可以在windows上训练吗

mindformers 可以在windows上训练吗

CausalLanguageModelingTrainer Task For Trainer.
    Args:
        model_name (str): The model name of Task-Trainer. Default: None
    Examples:
        >>> from mindformers import CausalLanguageModelingTrainer
        >>> gen_trainer = CausalLanguageModelingTrainer(model_name="gpt2")
        >>> gen_trainer.train()
        >>> res = gen_trainer.predict(input_data = "hello world [MASK]")
    Raises:
        NotImplementedError: If train method or evaluate method or predict method not implemented.

初始化

  1. def __init__(self, model_name: str = None):
  2. super(CausalLanguageModelingTrainer, self).__init__("text_generation", model_name)

很简单的代码,用的是父类BaseTrainer,并传入两个参数:"text_generation", model_name

训练

  1. def train(self,
  2. config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None,
  3. network: Optional[Union[Cell, BaseModel]] = None,
  4. dataset: Optional[Union[BaseDataset, GeneratorDataset]] = None,
  5. wrapper: Optional[TrainOneStepCell] = None,
  6. optimizer: Optional[Optimizer] = None,
  7. callbacks: Optional[Union[Callback, List[Callback]]] = None,
  8. **kwargs):
  9. r"""Train task for CausalLanguageModeling Trainer.
  10. This function is used to train or fine-tune the network.
  11. """
  12. self.training_process(
  13. config=config,
  14. network=network,
  15. callbacks=callbacks,
  16. dataset=dataset,
  17. wrapper=wrapper,
  18. optimizer=optimizer,
  19. **kwargs)

调用的basetrianer(mindformers/trainer/base_trainer.py · MindSpore/mindformers - Gitee.com) 的training_process方法。该方法用于训练或微调MindFormers中的模型。它需要几个参数,包括配置、网络、数据集、优化器、包装器和回调。

training_process方法首先设置配置参数,然后构建数据集,构建网络,并设置模型包装器。然后,它构建优化器并创建用于在训练期间进行评估的计算度量。该函数初始化模型并开始训练,同时定期进行日志记录。如果需要,可以从检查点恢复培训过程。最后,当训练完成时,函数会记录日志。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号