赞
踩
- # Get the column names for input/target.
- prompt_column = data_args.prompt_column
- response_column = data_args.response_column
- history_column = data_args.history_column
-
- # Temporarily set max_target_length for training.
- max_target_length = data_args.max_target_length
-
- def preprocess_function_eval(examples):
- inputs, targets = [], []
- for i in range(len(examples[prompt_column])):
- if examples[prompt_column][i] and examples[response_column][i]:
- query = examples[prompt_column][i]
- history = examples[history_column][i] if history_column is not None else None
- prompt = tokenizer.build_prompt(query, history)
- inputs.append(prompt)
- targets.append(examples[response_column][i])
-
- inputs = [prefix + inp for inp in inputs]
- model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
- labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
-
- if data_args.ignore_pad_token_for_loss:
- labels["input_ids"] = [
- [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
- ]
- model_inputs["labels"] = labels["input_ids"]
-
- return model_inputs
-
- def preprocess_function_train(examples):
- max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
-
- model_inputs = {
- "input_ids": [],
- "labels": [],
- }
- for i in range(len(examples[prompt_column])):
- if examples[prompt_column][i] and examples[response_column][i]:
- query, answer = examples[prompt_column][i], examples[response_column][i]
-
- history = examples[history_column][i] if history_column is not None else None
- prompt = tokenizer.build_prompt(query, history)
-
- prompt = prefix + prompt
- a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
- max_length=data_args.max_source_length)
- b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
- max_length=data_args.max_target_length)
-
- context_length = len(a_ids)
- input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
- labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
-
- pad_len = max_seq_length - len(input_ids)
- input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
- labels = labels + [tokenizer.pad_token_id] * pad_len
- if data_args.ignore_pad_token_for_loss:
- labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
-
- model_inputs["input_ids"].append(input_ids)
- model_inputs["labels"].append(labels)
-
- return model_inputs
-
- def print_dataset_example(example):
- print("input_ids", example["input_ids"])
- print("inputs", tokenizer.decode(example["input_ids"]))
- print("label_ids", example["labels"])
- print("labels", tokenizer.decode(example["labels"]))
-
- if training_args.do_train:
- if "train" not in raw_datasets:
- raise ValueError("--do_train requires a train dataset")
- train_dataset = raw_datasets["train"]
- if data_args.max_train_samples is not None:
- max_train_samples = min(len(train_dataset), data_args.max_train_samples)
- train_dataset = train_dataset.select(range(max_train_samples))
- with training_args.main_process_first(desc="train dataset map pre-processing"):
- train_dataset = train_dataset.map(
- preprocess_function_train,
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- remove_columns=column_names,
- load_from_cache_file=not data_args.overwrite_cache,
- desc="Running tokenizer on train dataset",
- )
- print_dataset_example(train_dataset[0])
-
- if training_args.do_eval:
- max_target_length = data_args.val_max_target_length
- if "validation" not in raw_datasets:
- raise ValueError("--do_eval requires a validation dataset")
- eval_dataset = raw_datasets["validation"]
- if data_args.max_eval_samples is not None:
- max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
- eval_dataset = eval_dataset.select(range(max_eval_samples))
- with training_args.main_process_first(desc="validation dataset map pre-processing"):
- eval_dataset = eval_dataset.map(
- preprocess_function_eval,
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- remove_columns=column_names,
- load_from_cache_file=not data_args.overwrite_cache,
- desc="Running tokenizer on validation dataset",
- )
- print_dataset_example(eval_dataset[0])
-
- if training_args.do_predict:
- max_target_length = data_args.val_max_target_length
- if "test" not in raw_datasets:
- raise ValueError("--do_predict requires a test dataset")
- predict_dataset = raw_datasets["test"]
- if data_args.max_predict_samples is not None:
- max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
- predict_dataset = predict_dataset.select(range(max_predict_samples))
- with training_args.main_process_first(desc="prediction dataset map pre-processing"):
- predict_dataset = predict_dataset.map(
- preprocess_function_eval,
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- remove_columns=column_names,
- load_from_cache_file=not data_args.overwrite_cache,
- desc="Running tokenizer on prediction dataset",
- )
- print_dataset_example(predict_dataset[0])
-
- # Data collator
- label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
- data_collator = DataCollatorForSeq2Seq(
- tokenizer,
- model=model,
- label_pad_token_id=label_pad_token_id,
- pad_to_multiple_of=None,
- padding=False
- )
-
- # Metric
- def compute_metrics(eval_preds):
- preds, labels = eval_preds
- if isinstance(preds, tuple):
- preds = preds[0]
- decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
- if data_args.ignore_pad_token_for_loss:
- # Replace -100 in the labels as we can't decode them.
- labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
- score_dict = {
- "rouge-1": [],
- "rouge-2": [],
- "rouge-l": [],
- "bleu-4": []
- }
- for pred, label in zip(decoded_preds, decoded_labels):
- hypothesis = list(jieba.cut(pred))
- reference = list(jieba.cut(label))
- rouge = Rouge()
- scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
- result = scores[0]
-
- for k, v in result.items():
- score_dict[k].append(round(v["f"] * 100, 4))
- bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
- score_dict["bleu-4"].append(round(bleu_score * 100, 4))
- for k, v in score_dict.items():
- score_dict[k] = float(np.mean(v))
- return score_dict
- # Override the decoding parameters of Seq2SeqTrainer
- training_args.generation_max_length = (
- training_args.generation_max_length
- if training_args.generation_max_length is not None
- else data_args.val_max_target_length
- )
- training_args.generation_num_beams = (
- data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
- )
- # Initialize our Trainer
- trainer = Seq2SeqTrainer(
- model=model,
- args=training_args,
- train_dataset=train_dataset if training_args.do_train else None,
- eval_dataset=eval_dataset if training_args.do_eval else None,
- tokenizer=tokenizer,
- data_collator=data_collator,
- compute_metrics=compute_metrics if training_args.predict_with_generate else None,
- save_changed=model_args.pre_seq_len is not None
- )
- # Training
- if training_args.do_train:
- checkpoint = None
- if training_args.resume_from_checkpoint is not None:
- checkpoint = training_args.resume_from_checkpoint
- # elif last_checkpoint is not None:
- # checkpoint = last_checkpoint
- model.gradient_checkpointing_enable()
- model.enable_input_require_grads()
- train_result = trainer.train(resume_from_checkpoint=checkpoint)
- # trainer.save_model() # Saves the tokenizer too for easy upload
- metrics = train_result.metrics
- max_train_samples = (
- data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
- )
- metrics["train_samples"] = min(max_train_samples, len(train_dataset))
- trainer.log_metrics("train", metrics)
- trainer.save_metrics("train", metrics)
- trainer.save_state()
- # Evaluation
- results = {}
- max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
- if training_args.do_eval:
- logger.info("*** Evaluate ***")
- metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
- max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
- metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
- trainer.log_metrics("eval", metrics)
- trainer.save_metrics("eval", metrics)
- if training_args.do_predict:
- logger.info("*** Predict ***")
- predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
- metrics = predict_results.metrics
- max_predict_samples = (
- data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
- )
- metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
- trainer.log_metrics("predict", metrics)
- trainer.save_metrics("predict", metrics)
- if trainer.is_world_process_zero():
- if training_args.predict_with_generate:
- predictions = tokenizer.batch_decode(
- predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
- )
- predictions = [pred.strip() for pred in predictions]
- labels = tokenizer.batch_decode(
- predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
- )
- labels = [label.strip() for label in labels]
- output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
- with open(output_prediction_file, "w", encoding="utf-8") as writer:
- for p, l in zip(predictions, labels):
- res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
- writer.write(f"{res}\n")
- return results
- def _mp_fn(index):
- # For xla_spawn (TPUs)
- main()
- if __name__ == "__main__":
- main()
prompt_column, response_column, history_column: 这些变量被定义为用于读取训练数据的列名。prompt_column和response_column分别是提问和回答的列,history_column是聊天记录的列。
max_target_length: 这个变量是指预测的最大长度。
preprocess_function_eval: 这是一个预处理函数,用于在评估阶段对数据进行处理。它创建了输入和目标列表,然后迭代数据集中的每个示例。对于每个示例,它检查是否有prompt和response,然后使用tokenizer将prompt和history转换为模型可以理解的格式。然后,所有的输入都被添加到一个前缀,并用tokenizer进行编码。最后,对目标进行同样的处理,并将处理后的输入和目标加入到模型输入中。
preprocess_function_train: 这是一个预处理函数,用于在训练阶段对数据进行处理。它的处理方式与eval的预处理函数类似,但有一些不同之处,例如它还添加了一个eos(end of sentence) token到输入和标签的末尾,并确保输入和标签的长度都符合最大序列长度。
print_dataset_example: 这个函数用于打印数据集中的一个示例。
training_args.do_train: 这是一个条件语句,如果训练参数中的do_train设定为True,那么它会执行训练数据的预处理并打印一个训练数据的示例。
training_args.do_eval: 这也是一个条件语句,如果训练参数中的do_eval设定为True,那么它会执行验证数据的预处理并打印一个验证数据的示例。
training_args.do_predict: 同样是一个条件语句,如果训练参数中的do_predict设定为True,那么它会执行测试数据的预处理并打印一个测试数据的示例。
label_pad_token_id, data_collator: 这些变量被定义为处理序列到序列任务的工具。label_pad_token_id是用于填充标签的token的ID,data_collator用于处理批量数据。
compute_metrics: 这个函数用于计算评估指标。它首先解码预测和标签,然后计算ROUGE和BLEU评分。
trainer: 这个变量是一个Seq2SeqTrainer对象,它用于训练模型。
if training_args.do_train: 如果训练参数中的do_train设定为True,那么它会执行训练,并保存训练的指标和状态。
if training_args.do_eval: 如果训练参数中的do_eval设定为True,那么它会执行评估,并保存评估的指标。
if training_args.do_predict: 如果训练参数中的do_predict设定为True,那么它会执行预测,并保存预测的指标。最后,它还将预测的结果写入文件。
def _mp_fn(index): 这个函数是用于TPUs的函数,如果在TPU上运行,则会调用这个函数。
if name == "main": 这个判断语句用来判断是否为脚本的运行入口,如果是,则执行main()函数。
整个脚本是用于训练和评估一个聊天机器人模型的代码。脚本先进行数据的预处理,然后定义一个模型训练器并使用它来训练模型。在训练后,脚本使用模型进行评估和预测,并将预测结果保存到文件中。在整个过程中,脚本使用了许多与序列到序列任务相关的工具,如数据整理器和评估指标计算函数。
当然,让我们逐行来解析这些代码。
prompt_column = data_args.prompt_column
:从数据参数中获取提示列名称,也就是用于提问的列。response_column = data_args.response_column
:从数据参数中获取回答列的名称,也就是作为回答或目标的列。history_column = data_args.history_column
:从数据参数中获取历史对话列的名称,如果存在的话,这些历史对话将被用作提问的上下文。以下是预处理函数,它们用于将输入和目标列进行格式化和分词。格式化的结果将被用于模型的训练和验证。
preprocess_function_eval
和preprocess_function_train
:这两个函数是为评估和训练准备数据的。它们从示例数据中提取问题和回答,并根据需要将其进行格式化和分词。然后它们会将输入和目标添加到model_inputs
列表中,然后返回这个列表。在接下来的代码中,我们根据是否要进行训练、评估或预测,以及提供的数据集中是否包含所需的部分(训练、验证或测试),来分别处理数据集。
if training_args.do_train:
:如果设置了训练标志,那么就需要检查是否提供了训练数据集,然后根据需要进行预处理。然后打印出第一个训练样例。
if training_args.do_eval:
:类似地,如果设置了评估标志,那么就需要检查是否提供了评估数据集,并进行预处理。然后打印出第一个评估样例。
if training_args.do_predict:
:对于预测,我们需要检查是否提供了测试数据集,并进行预处理。然后打印出第一个测试样例。
data_collator = DataCollatorForSeq2Seq(...)
:创建一个数据整理器,用于将预处理后的输入数据组装成可以直接喂入模型的批次。
接下来是评估指标的计算函数。这个函数将模型的预测结果与实际标签进行比较,然后计算并返回指标分数。
compute_metrics
:这个函数接收预测和标签,首先进行解码,然后计算rouge和bleu分数。接着,我们覆盖一些解码参数,然后初始化训练器,并进行训练、评估和预测。
trainer = Seq2SeqTrainer(...)
:初始化一个训练器,它将用于训练、评估和预测。
if training_args.do_train:
:如果设置了训练标志,就进行训练,并在训练结束后保存模型和指标。
if training_args.do_eval:
:如果设置了评估标志,就进行评估,并记录并保存评估指标。
if training_args.do_predict:
:如果设置了预测标志,就进行预测,并记录并保存预测指标。如果预测是使用生成方法完成的,就将预测和标签保存到文件中。
最后,如果此脚本是作为主脚本运行的,就调用main
函数。
if __name__ == "__main__":
:如果此脚本是作为主脚本运行的,就调用main
函数。这是Python的一种常见模式,用于检查脚本是直接运行还是作为模块导入。只有在直接运行脚本时,__name__
的值才会是"__main__"
,因此只有在这种情况下,才会调用main
函数。- for i in range(len(examples[prompt_column])):
- if examples[prompt_column][i] and examples[response_column][i]:
- query = examples[prompt_column][i]
- history = examples[history_column][i] if history_column is not None else None
- prompt = tokenizer.build_prompt(query, history)
- inputs.append(prompt)
- targets.append(examples[response_column][i])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。