CausalLanguageModelingTrainer Task For Trainer.
Args:
    model_name (str): The model name of Task-Trainer. Default: None
Examples:
    >>> from mindformers import CausalLanguageModelingTrainer
    >>> gen_trainer = CausalLanguageModelingTrainer(model_name="gpt2")
    >>> gen_trainer.train()
    >>> res = gen_trainer.predict(input_data = "hello world [MASK]")
Raises:
    NotImplementedError: If train method or evaluate method or predict method not implemented.
    def __init__(self, model_name: str = None):
        super(CausalLanguageModelingTrainer, self).__init__("text_generation", model_name)
The code is simple: it calls the constructor of the parent class BaseTrainer and passes it two arguments, the task name "text_generation" and model_name.
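The delegation pattern described above can be sketched in plain Python. This is a minimal illustration, not the MindFormers implementation: SketchBaseTrainer and SketchCausalLMTrainer are hypothetical stand-ins, and the assumption that the parent simply stores the two arguments is made only for the example.

```python
# Hypothetical stand-ins for illustration; these are NOT the MindFormers classes.
from typing import Optional


class SketchBaseTrainer:
    """Assumed behaviour: store the task name and model name passed up by subclasses."""

    def __init__(self, task: str, model_name: Optional[str] = None):
        self.task = task
        self.model_name = model_name


class SketchCausalLMTrainer(SketchBaseTrainer):
    """Fixes the task name so callers only have to choose the model."""

    def __init__(self, model_name: Optional[str] = None):
        # Same shape as the original call: a fixed task string plus model_name.
        super().__init__("text_generation", model_name)


trainer = SketchCausalLMTrainer(model_name="gpt2")
print(trainer.task, trainer.model_name)  # prints: text_generation gpt2
```

Hard-coding the task string in the subclass means every task-specific trainer can share one parent constructor while exposing only model_name to the user.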
    def train(self,
              config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None,
              network: Optional[Union[Cell, BaseModel]] = None,
              dataset: Optional[Union[BaseDataset, GeneratorDataset]] = None,
              wrapper: Optional[TrainOneStepCell] = None,
              optimizer: Optional[Optimizer] = None,
              callbacks: Optional[Union[Callback, List[Callback]]] = None,
              **kwargs):
        r"""Train task for CausalLanguageModeling Trainer.
        This function is used to train or fine-tune the network.
        """
        self.training_process(
            config=config,
            network=network,
            callbacks=callbacks,
            dataset=dataset,
            wrapper=wrapper,
            optimizer=optimizer,
            **kwargs)
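train simply calls the training_process method of BaseTrainer (mindformers/trainer/base_trainer.py · MindSpore/mindformers - Gitee.com). That method trains or fine-tunes a model in MindFormers. It takes several arguments, including the config, network, dataset, optimizer, wrapper, and callbacks.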
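training_process first sets up the configuration parameters, then builds the dataset, builds the network, and sets up the model wrapper. It then builds the optimizer and creates the compute metrics used for evaluation during training. Next, it initializes the model and starts training, logging periodically. If needed, training can be resumed from a checkpoint. Finally, it writes a log entry when training completes.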
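The sequence of stages described above can be sketched as a small pure-Python function. This is only an outline under stated assumptions: sketch_training_process, its resume_checkpoint parameter, the string placeholders, and the "use the caller's component if given, otherwise build a default" rule are all hypothetical and are not taken from base_trainer.py.

```python
# Illustrative sketch only: every name here is a hypothetical stand-in,
# not the MindFormers API.
from typing import Any, Callable, Dict, List, Optional


def sketch_training_process(
    config: Optional[Dict[str, Any]] = None,
    network: Optional[Any] = None,
    dataset: Optional[List[Any]] = None,
    wrapper: Optional[Any] = None,
    optimizer: Optional[Any] = None,
    callbacks: Optional[List[Callable[[str], None]]] = None,
    resume_checkpoint: Optional[str] = None,
) -> List[str]:
    """Walk through the stages in the order the text lists them; return the log."""
    steps: List[str] = []

    def log(message: str) -> None:
        # Record the stage and forward it to any callbacks.
        steps.append(message)
        for callback in callbacks or []:
            callback(message)

    # 1. Configuration: merge caller overrides onto an assumed default.
    config = {"epochs": 1, **(config or {})}
    log("set config")

    # 2-5. Assumption: a component is built only when the caller passed None.
    dataset = dataset if dataset is not None else [0, 1, 2]
    log("build dataset")
    network = network if network is not None else "default-network"
    log("build network")
    wrapper = wrapper if wrapper is not None else "default-wrapper"
    log("build wrapper")
    optimizer = optimizer if optimizer is not None else "default-optimizer"
    log("build optimizer")

    # 6-7. Metrics for evaluation during training, then model initialization.
    log("build metrics")
    log("init model")

    # Optional: resume from a checkpoint before the loop starts.
    if resume_checkpoint is not None:
        log(f"resume from {resume_checkpoint}")

    # 8. Training loop with periodic logging (once per epoch here).
    for epoch in range(config["epochs"]):
        for _batch in dataset:
            pass  # a real trainer would run one optimization step here
        log(f"epoch {epoch + 1} done")

    # 9. Final log entry.
    log("training finished")
    return steps
```

Calling sketch_training_process(config={"epochs": 2}) returns the list of stage names in order, which makes it easy to see where a custom dataset, optimizer, or callback would slot in.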