当前位置:   article > 正文

ChatGLM2-6B源码解析./ptuning/main.py (二)_enable_input_require_grads()

enable_input_require_grads()
  1. # Get the column names for input/target.
  2. prompt_column = data_args.prompt_column
  3. response_column = data_args.response_column
  4. history_column = data_args.history_column
  5. # Temporarily set max_target_length for training.
  6. max_target_length = data_args.max_target_length
  7. def preprocess_function_eval(examples):
  8. inputs, targets = [], []
  9. for i in range(len(examples[prompt_column])):
  10. if examples[prompt_column][i] and examples[response_column][i]:
  11. query = examples[prompt_column][i]
  12. history = examples[history_column][i] if history_column is not None else None
  13. prompt = tokenizer.build_prompt(query, history)
  14. inputs.append(prompt)
  15. targets.append(examples[response_column][i])
  16. inputs = [prefix + inp for inp in inputs]
  17. model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
  18. labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
  19. if data_args.ignore_pad_token_for_loss:
  20. labels["input_ids"] = [
  21. [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
  22. ]
  23. model_inputs["labels"] = labels["input_ids"]
  24. return model_inputs
  25. def preprocess_function_train(examples):
  26. max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
  27. model_inputs = {
  28. "input_ids": [],
  29. "labels": [],
  30. }
  31. for i in range(len(examples[prompt_column])):
  32. if examples[prompt_column][i] and examples[response_column][i]:
  33. query, answer = examples[prompt_column][i], examples[response_column][i]
  34. history = examples[history_column][i] if history_column is not None else None
  35. prompt = tokenizer.build_prompt(query, history)
  36. prompt = prefix + prompt
  37. a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
  38. max_length=data_args.max_source_length)
  39. b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
  40. max_length=data_args.max_target_length)
  41. context_length = len(a_ids)
  42. input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
  43. labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
  44. pad_len = max_seq_length - len(input_ids)
  45. input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
  46. labels = labels + [tokenizer.pad_token_id] * pad_len
  47. if data_args.ignore_pad_token_for_loss:
  48. labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
  49. model_inputs["input_ids"].append(input_ids)
  50. model_inputs["labels"].append(labels)
  51. return model_inputs
  52. def print_dataset_example(example):
  53. print("input_ids", example["input_ids"])
  54. print("inputs", tokenizer.decode(example["input_ids"]))
  55. print("label_ids", example["labels"])
  56. print("labels", tokenizer.decode(example["labels"]))
  57. if training_args.do_train:
  58. if "train" not in raw_datasets:
  59. raise ValueError("--do_train requires a train dataset")
  60. train_dataset = raw_datasets["train"]
  61. if data_args.max_train_samples is not None:
  62. max_train_samples = min(len(train_dataset), data_args.max_train_samples)
  63. train_dataset = train_dataset.select(range(max_train_samples))
  64. with training_args.main_process_first(desc="train dataset map pre-processing"):
  65. train_dataset = train_dataset.map(
  66. preprocess_function_train,
  67. batched=True,
  68. num_proc=data_args.preprocessing_num_workers,
  69. remove_columns=column_names,
  70. load_from_cache_file=not data_args.overwrite_cache,
  71. desc="Running tokenizer on train dataset",
  72. )
  73. print_dataset_example(train_dataset[0])
  74. if training_args.do_eval:
  75. max_target_length = data_args.val_max_target_length
  76. if "validation" not in raw_datasets:
  77. raise ValueError("--do_eval requires a validation dataset")
  78. eval_dataset = raw_datasets["validation"]
  79. if data_args.max_eval_samples is not None:
  80. max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
  81. eval_dataset = eval_dataset.select(range(max_eval_samples))
  82. with training_args.main_process_first(desc="validation dataset map pre-processing"):
  83. eval_dataset = eval_dataset.map(
  84. preprocess_function_eval,
  85. batched=True,
  86. num_proc=data_args.preprocessing_num_workers,
  87. remove_columns=column_names,
  88. load_from_cache_file=not data_args.overwrite_cache,
  89. desc="Running tokenizer on validation dataset",
  90. )
  91. print_dataset_example(eval_dataset[0])
  92. if training_args.do_predict:
  93. max_target_length = data_args.val_max_target_length
  94. if "test" not in raw_datasets:
  95. raise ValueError("--do_predict requires a test dataset")
  96. predict_dataset = raw_datasets["test"]
  97. if data_args.max_predict_samples is not None:
  98. max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
  99. predict_dataset = predict_dataset.select(range(max_predict_samples))
  100. with training_args.main_process_first(desc="prediction dataset map pre-processing"):
  101. predict_dataset = predict_dataset.map(
  102. preprocess_function_eval,
  103. batched=True,
  104. num_proc=data_args.preprocessing_num_workers,
  105. remove_columns=column_names,
  106. load_from_cache_file=not data_args.overwrite_cache,
  107. desc="Running tokenizer on prediction dataset",
  108. )
  109. print_dataset_example(predict_dataset[0])
  110. # Data collator
  111. label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
  112. data_collator = DataCollatorForSeq2Seq(
  113. tokenizer,
  114. model=model,
  115. label_pad_token_id=label_pad_token_id,
  116. pad_to_multiple_of=None,
  117. padding=False
  118. )
  119. # Metric
  120. def compute_metrics(eval_preds):
  121. preds, labels = eval_preds
  122. if isinstance(preds, tuple):
  123. preds = preds[0]
  124. decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
  125. if data_args.ignore_pad_token_for_loss:
  126. # Replace -100 in the labels as we can't decode them.
  127. labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  128. decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  129. score_dict = {
  130. "rouge-1": [],
  131. "rouge-2": [],
  132. "rouge-l": [],
  133. "bleu-4": []
  134. }
  135. for pred, label in zip(decoded_preds, decoded_labels):
  136. hypothesis = list(jieba.cut(pred))
  137. reference = list(jieba.cut(label))
  138. rouge = Rouge()
  139. scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
  140. result = scores[0]
  141. for k, v in result.items():
  142. score_dict[k].append(round(v["f"] * 100, 4))
  143. bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
  144. score_dict["bleu-4"].append(round(bleu_score * 100, 4))
  145. for k, v in score_dict.items():
  146. score_dict[k] = float(np.mean(v))
  147. return score_dict
  148. # Override the decoding parameters of Seq2SeqTrainer
  149. training_args.generation_max_length = (
  150. training_args.generation_max_length
  151. if training_args.generation_max_length is not None
  152. else data_args.val_max_target_length
  153. )
  154. training_args.generation_num_beams = (
  155. data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
  156. )
  157. # Initialize our Trainer
  158. trainer = Seq2SeqTrainer(
  159. model=model,
  160. args=training_args,
  161. train_dataset=train_dataset if training_args.do_train else None,
  162. eval_dataset=eval_dataset if training_args.do_eval else None,
  163. tokenizer=tokenizer,
  164. data_collator=data_collator,
  165. compute_metrics=compute_metrics if training_args.predict_with_generate else None,
  166. save_changed=model_args.pre_seq_len is not None
  167. )
  168. # Training
  169. if training_args.do_train:
  170. checkpoint = None
  171. if training_args.resume_from_checkpoint is not None:
  172. checkpoint = training_args.resume_from_checkpoint
  173. # elif last_checkpoint is not None:
  174. # checkpoint = last_checkpoint
  175. model.gradient_checkpointing_enable()
  176. model.enable_input_require_grads()
  177. train_result = trainer.train(resume_from_checkpoint=checkpoint)
  178. # trainer.save_model() # Saves the tokenizer too for easy upload
  179. metrics = train_result.metrics
  180. max_train_samples = (
  181. data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
  182. )
  183. metrics["train_samples"] = min(max_train_samples, len(train_dataset))
  184. trainer.log_metrics("train", metrics)
  185. trainer.save_metrics("train", metrics)
  186. trainer.save_state()
  187. # Evaluation
  188. results = {}
  189. max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
  190. if training_args.do_eval:
  191. logger.info("*** Evaluate ***")
  192. metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
  193. max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
  194. metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
  195. trainer.log_metrics("eval", metrics)
  196. trainer.save_metrics("eval", metrics)
  197. if training_args.do_predict:
  198. logger.info("*** Predict ***")
  199. predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
  200. metrics = predict_results.metrics
  201. max_predict_samples = (
  202. data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
  203. )
  204. metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
  205. trainer.log_metrics("predict", metrics)
  206. trainer.save_metrics("predict", metrics)
  207. if trainer.is_world_process_zero():
  208. if training_args.predict_with_generate:
  209. predictions = tokenizer.batch_decode(
  210. predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
  211. )
  212. predictions = [pred.strip() for pred in predictions]
  213. labels = tokenizer.batch_decode(
  214. predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
  215. )
  216. labels = [label.strip() for label in labels]
  217. output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
  218. with open(output_prediction_file, "w", encoding="utf-8") as writer:
  219. for p, l in zip(predictions, labels):
  220. res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
  221. writer.write(f"{res}\n")
  222. return results
  223. def _mp_fn(index):
  224. # For xla_spawn (TPUs)
  225. main()
  226. if __name__ == "__main__":
  227. main()
  1. prompt_column, response_column, history_column: 这些变量被定义为用于读取训练数据的列名。prompt_column和response_column分别是提问和回答的列,history_column是聊天记录的列。

  2. max_target_length: 这个变量是指预测的最大长度。

  3. preprocess_function_eval: 这是一个预处理函数,用于在评估阶段对数据进行处理。它创建了输入和目标列表,然后迭代数据集中的每个示例。对于每个示例,它检查是否有prompt和response,然后使用tokenizer将prompt和history转换为模型可以理解的格式。然后,所有的输入都被添加到一个前缀,并用tokenizer进行编码。最后,对目标进行同样的处理,并将处理后的输入和目标加入到模型输入中。

  4. preprocess_function_train: 这是一个预处理函数,用于在训练阶段对数据进行处理。它的处理方式与eval的预处理函数类似,但有一些不同之处,例如它还添加了一个eos(end of sentence) token到输入和标签的末尾,并确保输入和标签的长度都符合最大序列长度。

  5. print_dataset_example: 这个函数用于打印数据集中的一个示例。

  6. training_args.do_train: 这是一个条件语句,如果训练参数中的do_train设定为True,那么它会执行训练数据的预处理并打印一个训练数据的示例。

  7. training_args.do_eval: 这也是一个条件语句,如果训练参数中的do_eval设定为True,那么它会执行验证数据的预处理并打印一个验证数据的示例。

  8. training_args.do_predict: 同样是一个条件语句,如果训练参数中的do_predict设定为True,那么它会执行测试数据的预处理并打印一个测试数据的示例。

  9. label_pad_token_id, data_collator: 这些变量被定义为处理序列到序列任务的工具。label_pad_token_id是用于填充标签的token的ID,data_collator用于处理批量数据。

  10. compute_metrics: 这个函数用于计算评估指标。它首先解码预测和标签,然后计算ROUGE和BLEU评分。

  11. trainer: 这个变量是一个Seq2SeqTrainer对象,它用于训练模型。

  12. if training_args.do_train: 如果训练参数中的do_train设定为True,那么它会执行训练,并保存训练的指标和状态。

  13. if training_args.do_eval: 如果训练参数中的do_eval设定为True,那么它会执行评估,并保存评估的指标。

  14. if training_args.do_predict: 如果训练参数中的do_predict设定为True,那么它会执行预测,并保存预测的指标。最后,它还将预测的结果写入文件。

  15. def _mp_fn(index): 这个函数是用于TPUs的函数,如果在TPU上运行,则会调用这个函数。

  16. if name == "main": 这个判断语句用来判断是否为脚本的运行入口,如果是,则执行main()函数。

整个脚本是用于训练和评估一个聊天机器人模型的代码。脚本先进行数据的预处理,然后定义一个模型训练器并使用它来训练模型。在训练后,脚本使用模型进行评估和预测,并将预测结果保存到文件中。在整个过程中,脚本使用了许多与序列到序列任务相关的工具,如数据整理器和评估指标计算函数。

当然,让我们逐行来解析这些代码。

  1. prompt_column = data_args.prompt_column:从数据参数中获取提示列名称,也就是用于提问的列。
  2. response_column = data_args.response_column:从数据参数中获取回答列的名称,也就是作为回答或目标的列。
  3. history_column = data_args.history_column:从数据参数中获取历史对话列的名称,如果存在的话,这些历史对话将被用作提问的上下文。

以下是预处理函数,它们用于将输入和目标列进行格式化和分词。格式化的结果将被用于模型的训练和验证。

  1. preprocess_function_evalpreprocess_function_train:这两个函数是为评估和训练准备数据的。它们从示例数据中提取问题和回答,并根据需要将其进行格式化和分词。然后它们会将输入和目标添加到model_inputs列表中,然后返回这个列表。

在接下来的代码中,我们根据是否要进行训练、评估或预测,以及提供的数据集中是否包含所需的部分(训练、验证或测试),来分别处理数据集。

  1. if training_args.do_train::如果设置了训练标志,那么就需要检查是否提供了训练数据集,然后根据需要进行预处理。然后打印出第一个训练样例。

  2. if training_args.do_eval::类似地,如果设置了评估标志,那么就需要检查是否提供了评估数据集,并进行预处理。然后打印出第一个评估样例。

  3. if training_args.do_predict::对于预测,我们需要检查是否提供了测试数据集,并进行预处理。然后打印出第一个测试样例。

  4. data_collator = DataCollatorForSeq2Seq(...):创建一个数据整理器,用于将预处理后的输入数据组装成可以直接喂入模型的批次。

接下来是评估指标的计算函数。这个函数将模型的预测结果与实际标签进行比较,然后计算并返回指标分数。

  1. compute_metrics:这个函数接收预测和标签,首先进行解码,然后计算rouge和bleu分数。

接着,我们覆盖一些解码参数,然后初始化训练器,并进行训练、评估和预测。

  1. trainer = Seq2SeqTrainer(...):初始化一个训练器,它将用于训练、评估和预测。

  2. if training_args.do_train::如果设置了训练标志,就进行训练,并在训练结束后保存模型和指标。

  3. if training_args.do_eval::如果设置了评估标志,就进行评估,并记录并保存评估指标。

  4. if training_args.do_predict::如果设置了预测标志,就进行预测,并记录并保存预测指标。如果预测是使用生成方法完成的,就将预测和标签保存到文件中。

最后,如果此脚本是作为主脚本运行的,就调用main函数。

  1. if __name__ == "__main__"::如果此脚本是作为主脚本运行的,就调用main函数。这是Python的一种常见模式,用于检查脚本是直接运行还是作为模块导入。只有在直接运行脚本时,__name__的值才会是"__main__",因此只有在这种情况下,才会调用main函数。

  1. for i in range(len(examples[prompt_column])):
  2. if examples[prompt_column][i] and examples[response_column][i]:
  3. query = examples[prompt_column][i]
  4. history = examples[history_column][i] if history_column is not None else None
  5. prompt = tokenizer.build_prompt(query, history)
  6. inputs.append(prompt)
  7. targets.append(examples[response_column][i])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/479248
推荐阅读
相关标签
  

闽ICP备14008679号