当前位置:   article > 正文

Pytorch框架下的transformers的使用_pytorch transformers

pytorch transformers

huggingface团队在pytorch框架下开发了transformers工具包:https://github.com/huggingface/transformers,工具包实现了大量基于transformer的模型,如albert,bert,roberta等。工具包的代码结构如图所示:

transformers工具包的包结构

其中比较重要的是src/transformers以及example这两个文件夹。其中,src/transformers文件夹下是各类transformer模型的实现代码;而examples下主要是各类下游任务的微调代码。我们以文本分类任务为例来说明微调过程具体是如何实现的,在官方的例子中,使用GLUE数据集。

一、run_glue.sh文件解析

按照官方文档的指引,首先需要构建用于启动微调程序的脚本文件,脚本为微调程序提供参数。

  1. export GLUE_DIR=/path/to/glue
  2. export TASK_NAME=MRPC
  3. python ./examples/text-classification/run_glue.py \
  4. --model_name_or_path bert-base-uncased \
  5. --task_name $TASK_NAME \
  6. --do_train \
  7. --do_eval \
  8. --data_dir $GLUE_DIR/$TASK_NAME \
  9. --max_seq_length 128 \
  10. --per_device_eval_batch_size=8 \
  11. --per_device_train_batch_size=8 \
  12. --learning_rate 2e-5 \
  13. --num_train_epochs 3.0 \
  14. --output_dir /tmp/$TASK_NAME/

其中几个主要参数的意义如下:

  • model_name_or_path:用于指定进行微调的预训练模型。参数可以是模型名称,在第一次执行微调程序时,会自动下载对应的模型;参数也可以是模型路径,此时需要提前下载对应的模型到设定的路径中。
  • task_name:用于指定具体的下游任务,微调程序需要根据任务名称选择相应的processor以实现数据加载。
  • data_dir:用于指定微调数据的存储路径。
  • output_dir:用于指定微调好的模型的存放路径

二、run_glue.py文件解析

启动脚本会调用run_glue.py文件来执行微调程序。程序主要有三部分功能:加载模型,加载数据,进行微调(训练,验证,预测)。

1、加载预训练模型

(1)加载用于构建模型以及用于微调过程的参数

  1. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
  2. if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
  3. # If we pass only one argument to the script and it's the path to a json file,
  4. # let's parse it to get our arguments.
  5. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
  6. else:
  7. model_args, data_args, training_args = parser.parse_args_into_dataclasses()

其中,类ModelArguments中包含的是关于模型的属性,如model_name,config_name,tokenizer_name等,类在run.py文件中定义;类DataTrainingArguments中包含的是关于微调数据的属性,如task_name,data_dir等,类在transformers/data/datasets/glue.py文件中定义;TrainingArguments中包含的是关于微调过程的参数,如batch_size,learning_rate等参数,类在transformers/training_args.py中定义。

(2)生成model,config,tokenizer

其中,config用于加载配置信息,model根据config加载模型,tokenize用于在加载数据时提供编码信息。

  1. config = AutoConfig.from_pretrained(
  2. model_args.config_name if model_args.config_name else model_args.model_name_or_path,
  3. num_labels=num_labels,
  4. finetuning_task=data_args.task_name,
  5. cache_dir=model_args.cache_dir,
  6. )
  7. tokenizer = AutoTokenizer.from_pretrained(
  8. model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
  9. cache_dir=model_args.cache_dir,
  10. )
  11. model = AutoModelForSequenceClassification.from_pretrained(
  12. model_args.model_name_or_path,
  13. from_tf=bool(".ckpt" in model_args.model_name_or_path),
  14. config=config,
  15. cache_dir=model_args.cache_dir,
  16. )

2、加载数据

需要使用GlueDataset类构建数据,类定义在transformers/data/datasets/glue.py中,是对Dataset类的继承。

  1. train_dataset = (
  2. GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
  3. )
  4. eval_dataset = (
  5. GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
  6. if training_args.do_eval
  7. else None
  8. )
  9. test_dataset = (
  10. GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
  11. if training_args.do_predict
  12. else None
  13. )

在GlueDataset类中,需要利用glue_processors类来加载数据内容。glue_processors类定义在transformers/data/processors/glue.py中。

  1. self.processor = glue_processors[args.task_name]()
  2. if mode == Split.dev:
  3. examples = self.processor.get_dev_examples(args.data_dir)
  4. elif mode == Split.test:
  5. examples = self.processor.get_test_examples(args.data_dir)
  6. else:
  7. examples = self.processor.get_train_examples(args.data_dir)

3、微调(训练,验证,预测)

(1)构建训练器

训练器Trainer类:主要用于指定使用的模型,数据,微调过程所用参数的信息。类中包含用于训练,验证,预测的方法:trainer.train(train_dataset),trainer.evaluate(eval_dataset),trainer.predicate(test_dataset)。

  1. trainer = Trainer(
  2. model=model,
  3. args=training_args,
  4. train_dataset=train_dataset,
  5. eval_dataset=eval_dataset,
  6. compute_metrics=build_compute_metrics_fn(data_args.task_name),
  7. )

(2)进行微调(训练,验证,预测)

三、如何定义自己的微调方法

有时候,我们的数据可能与官方所用的数据形式不同,这时候需要对方法进行重写以定义自己的微调方法,重写的内容主要包括:

  1. 重写dataset类
  2. 重写processor类

所有用到的参数都以属性的形式存在于ModelArguments,DataTrainingArguments,TrainingArguments这三个类中,若要改变某个参数,只需要在启动脚本中设置即可。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号