The argument definitions are grouped into several dataclasses. DataTrainingArguments holds the arguments pertaining to what data we feed the model for training and evaluation, e.g. max_seq_length and pad_to_max_length (if the latter is False, samples are padded dynamically to the longest sequence in each batch). ModelArguments holds the arguments pertaining to which model/config/tokenizer we fine-tune from, and QuestionAnwseringArguments holds QA-specific options such as null_score_diff_threshold, which is only useful when version_2_with_negative=True. get_args() parses all of the arguments in P-Tuning V2:
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, TrainingArguments
# ModelArguments, DataTrainingArguments and QuestionAnwseringArguments are the
# dataclasses described above.

def get_args():
    """Parse all the arguments used in P-Tuning V2."""
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, QuestionAnwseringArguments))
    args = parser.parse_args_into_dataclasses()
    return args

args = get_args()
# model_args is needed below to load the tokenizer, config and model;
# the QA-specific arguments are not used in this excerpt.
model_args, data_args, training_args, _ = args
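For reference, here is a minimal sketch of what these dataclasses might look like. Only max_seq_length, pad_to_max_length, null_score_diff_threshold and dataset_name are taken from the text above; the defaults and any other details are assumptions, not the repo's exact definitions.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DataTrainingArguments:
    # Arguments pertaining to what data we feed the model for training and eval.
    dataset_name: Optional[str] = field(default=None)
    max_seq_length: int = field(default=128)
    # If False, pad samples dynamically to the longest sequence in each batch.
    pad_to_max_length: bool = field(default=True)

@dataclass
class QuestionAnwseringArguments:
    # Only useful when version_2_with_negative=True.
    null_score_diff_threshold: float = field(default=0.0)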
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
)
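Before the config can be built, a task-specific dataset wrapper has to be constructed, since it supplies the label metadata consumed below. A hedged sketch, assuming one of the repo's task wrappers (the exact constructor signature may differ):

# Task-specific dataset wrapper; it exposes num_labels, label2id, id2label,
# train_dataset / eval_dataset, compute_metrics, data_collator and test_key,
# all of which are used in the rest of this excerpt.
dataset = SuperGlueDataset(tokenizer, data_args, training_args)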
config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
num_labels=dataset.num_labels,
label2id=dataset.label2id,
id2label=dataset.id2label,
finetuning_task=data_args.dataset_name,
revision=model_args.model_revision,
)
Three different training modes can be selected through the model arguments: prefix-tuning (P-Tuning v2), prompt-tuning, or regular fine-tuning:
if model_args.prefix:
    # Prefix-tuning (P-Tuning v2): trainable continuous prompts are injected
    # into every transformer layer.
    config.hidden_dropout_prob = model_args.hidden_dropout_prob
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection
    config.prefix_hidden_size = model_args.prefix_hidden_size
    model_class = PREFIX_MODELS[config.model_type][task_type]
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        revision=model_args.model_revision,
    )
elif model_args.prompt:
    # Prompt-tuning: trainable prompt embeddings at the input layer only.
    config.pre_seq_len = model_args.pre_seq_len
    model_class = PROMPT_MODELS[config.model_type][task_type]
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        revision=model_args.model_revision,
    )
else:
    # Regular fine-tuning with a standard Hugging Face model class.
    model_class = AUTO_MODELS[task_type]
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        revision=model_args.model_revision,
    )

    # fix_bert is a parameter of get_model() in the repo (default False); when
    # True, the pretrained encoder is frozen so only the task head is trained.
    fix_bert = False

    bert_param = 0
    if fix_bert:
        if config.model_type == "bert":
            for param in model.bert.parameters():
                param.requires_grad = False
            for _, param in model.bert.named_parameters():
                bert_param += param.numel()
        elif config.model_type == "roberta":
            for param in model.roberta.parameters():
                param.requires_grad = False
            for _, param in model.roberta.named_parameters():
                bert_param += param.numel()
        elif config.model_type == "deberta":
            for param in model.deberta.parameters():
                param.requires_grad = False
            for _, param in model.deberta.named_parameters():
                bert_param += param.numel()

    # Report the number of trainable parameters (total minus frozen encoder).
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
    total_param = all_param - bert_param
    print('***** total param is {} *****'.format(total_param))
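A hedged example of how a mode is typically selected on the command line. The flag names mirror the model arguments above; the script name, checkpoint and the remaining task/data flags are placeholders, not the repo's exact invocation:

# Prefix-tuning (P-Tuning v2):
python run.py --model_name_or_path bert-base-uncased --prefix --pre_seq_len 20 --do_train --do_eval
# Prompt-tuning: pass --prompt instead of --prefix.
# Regular fine-tuning: pass neither flag.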
# Initialize our Trainer
trainer = BaseTrainer(
model=model,
args=training_args,
train_dataset=dataset.train_dataset if training_args.do_train else None,
eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
compute_metrics=dataset.compute_metrics,
tokenizer=tokenizer,
data_collator=dataset.data_collator,
test_key=dataset.test_key
)
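BaseTrainer is the project's subclass of transformers.Trainer; judging from the arguments above, test_key names the metric used to pick out the best evaluation result. A rough sketch of the idea, assumed rather than copied from the repo:

from transformers import Trainer

class BaseTrainer(Trainer):
    def __init__(self, *args, test_key="accuracy", **kwargs):
        super().__init__(*args, **kwargs)
        self.test_key = test_key
        self.best_metric = 0.0

    def record_best(self, eval_metrics):
        # Keep the best value of the selected metric seen across evaluations.
        value = eval_metrics.get("eval_" + self.test_key, 0.0)
        self.best_metric = max(self.best_metric, value)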
# In the repo this is the tail of a task-specific get_trainer(args) function,
# which ends with "return trainer, None"; the second element is a predict
# dataset (None here, non-None for the QA tasks). The main script unpacks it:
#     trainer, predict_dataset = get_trainer(args)
predict_dataset = None

# last_checkpoint is resolved earlier in the main script in the standard
# Hugging Face way, e.g. with
# transformers.trainer_utils.get_last_checkpoint(training_args.output_dir).
last_checkpoint = None
if training_args.do_train:
train(trainer, training_args.resume_from_checkpoint, last_checkpoint)
if training_args.do_eval:
evaluate(trainer)
if training_args.do_predict:
predict(trainer, predict_dataset)
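train, evaluate and predict are thin wrappers in the main script. A minimal sketch of what they do, assuming the standard Hugging Face Trainer API; the exact bodies in the repo may differ:

def train(trainer, resume_from_checkpoint=None, last_checkpoint=None):
    # Prefer an explicitly requested checkpoint, else the auto-detected one.
    checkpoint = resume_from_checkpoint if resume_from_checkpoint is not None else last_checkpoint
    result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()
    trainer.log_metrics("train", result.metrics)
    trainer.save_metrics("train", result.metrics)
    trainer.save_state()

def evaluate(trainer):
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

def predict(trainer, predict_dataset=None):
    if predict_dataset is None:
        return
    predictions = trainer.predict(predict_dataset)
    trainer.log_metrics("test", predictions.metrics)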