当前位置:   article > 正文

大语言模型 RLHF(一)——ChatGLM代码逐行解读

chatglm
介绍:
为方便学习,对ChatGlm的代码做了逐行解读,这里主要是main方法,里面核心的部分如数据的解析,评估过程定义,模型推理训练的框架流程等。后续会针对ChatGLM核心优化代码做个解读,

代码获取:
完整代码可以参考下面git链接里的ptuning/main.py里获取
GitHub - Pillars-Creation/ChatGLM-LoRA: ChatGLM-6B添加了LoRA实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

训练评估过程:
  1. if training_args.do_train: # 如果需要进行训练
  2. if "train" not in raw_datasets: # 如果原始数据集中没有训练集
  3. raise ValueError("--do_train requires a train dataset") # 抛出异常,提示需要提供训练集
  4. train_dataset = raw_datasets["train"] # 获取训练集
  5. if data_args.max_train_samples is not None: # 如果设置了最大训练样本数
  6. max_train_samples = min(len(train_dataset), data_args.max_train_samples) # 计算最大训练样本数
  7. train_dataset = train_dataset.select(range(max_train_samples)) # 选择最大训练样本数的子集
  8. with training_args.main_process_first(desc="train dataset map pre-processing"): # 在主进程中进行训练集预处理
  9. train_dataset = train_dataset.map( # 对训练集进行映射操作
  10. preprocess_function_train, # 预处理函数
  11. batched=True, # 是否对数据进行批处理
  12. num_proc=data_args.preprocessing_num_workers, # 预处理使用的进程数
  13. remove_columns=column_names, # 需要移除的列名
  14. load_from_cache_file=not data_args.overwrite_cache, # 是否从缓存文件中加载数据
  15. desc="Running tokenizer on train dataset", # 显示的描述信息
  16. )
  17. print_dataset_example(train_dataset[0]) # 打印训练集的第一个样本
评估指标定义过程:
  1. score_dict = {
  2. "rouge-1": [],
  3. "rouge-2": [],
  4. "rouge-l": [],
  5. "bleu-4": []
  6. }
  7. # 定义 score_dict,包含 rouge-1、rouge-2、rouge-l 和 bleu-4 四个指标
  8. for pred, label in zip(decoded_preds, decoded_labels):
  9. # 遍历解码后的 preds 和 labels
  10. hypothesis = list(jieba.cut(pred))
  11. reference = list(jieba.cut(label))
  12. # 使用 jieba 对预测值和真实值进行分词
  13. rouge = Rouge()
  14. scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
  15. result = scores[0]
  16. # 使用 Rouge 计算 ROUGE 指标,ROUGE(Recall-Oriented Understudy for Gisting Evaluation)
  17. # 是一种用于自动评估文本摘要和机器翻译的指标。它通过比较生成的摘要或翻译与参考摘要或翻译之间的重叠来计算得分。
  18. # ROUGE 指标包括 ROUGE-1、ROUGE-2 和 ROUGE-L 等,
  19. # 其中 ROUGE-1 表示单个词的重叠,ROUGE-2 表示两个词的重叠,ROUGE-L 表示最长公共子序列的重叠。
  20. # ROUGE 指标的取值范围为 0 到 1,值越大表示生成的摘要或翻译与参考摘要或翻译之间的重叠越多,即越好。
  21. # 在使用 Rouge 计算 ROUGE 指标时,rouge.get_scores() 方法返回一个包含多个指标的列表,每个指标都是一个字典,
  22. # 包含 precision、recall 和 f-measure 三个值。因此,scores[0] 表示第一个指标的字典,
  23. # 其中包含 precision、recall 和 f-measure 三个值。在这里,我们默认使用 ROUGE-1 指标,
  24. # 因此 scores[0] 表示 ROUGE-1 指标的字典。
具体main代码如下: 
完整代码可以参考下面git链接里的ptuning/main.py里获取
GitHub - Pillars-Creation/ChatGLM-LoRA: ChatGLM-6B添加了LoRA实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

  1. def main():
  2. parser = argparse.ArgumentParser()
  3. parser.add_argument("--do_train", action="store_true",default="true")
  4. parser.add_argument("--train_file", type=str, default="train.json")
  5. # 解析参数
  6. args = parser.parse_args()
  7. # 将参数赋值给 sys.argv
  8. sys.argv = [sys.argv[0]] + [f"--{k}={v}" for k, v in vars(args).items()]
  9. # 解析命令行参数
  10. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
  11. if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
  12. # 如果命令行参数只有一个,并且是一个 JSON 文件的路径,则解析该文件以获取参数
  13. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
  14. else:
  15. # 否则,解析命令行参数
  16. model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  17. # 设置日志记录的格式和级别
  18. logging.basicConfig(
  19. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  20. datefmt="%m/%d/%Y %H:%M:%S",
  21. handlers=[logging.StreamHandler(sys.stdout)],
  22. )
  23. if training_args.should_log:
  24. # 如果需要记录日志,则设置日志级别为 info
  25. transformers.utils.logging.set_verbosity_info()
  26. log_level = training_args.get_process_log_level()
  27. logger.setLevel(log_level)
  28. # 获取日志级别,并设置 logger 的日志级别
  29. transformers.utils.logging.set_verbosity(log_level)
  30. transformers.utils.logging.enable_default_handler()
  31. transformers.utils.logging.enable_explicit_format()
  32. # 设置 transformers 的日志级别,并启用默认处理程序和显式格式
  33. # Log on each process the small summary:
  34. logger.warning(
  35. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  36. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  37. )
  38. logger.info(f"Training/evaluation parameters {training_args}")
  39. # 打印进程、设备、GPU 数量、分布式训练和 16 位训练等信息,并打印训练/评估参数
  40. # Set seed before initializing model.
  41. set_seed(training_args.seed)
  42. # 设置随机种子
  43. # Load dataset
  44. data_files = {}
  45. if data_args.train_file is not None:
  46. data_files["train"] = data_args.train_file
  47. extension = data_args.train_file.split(".")[-1]
  48. if data_args.validation_file is not None:
  49. data_files["validation"] = data_args.validation_file
  50. extension = data_args.validation_file.split(".")[-1]
  51. if data_args.test_file is not None:
  52. data_files["test"] = data_args.test_file
  53. extension = data_args.test_file.split(".")[-1]
  54. # 加载数据集文件
  55. raw_datasets = load_dataset(
  56. extension,
  57. data_files=data_files,
  58. cache_dir=model_args.cache_dir,
  59. use_auth_token=True if model_args.use_auth_token else None,
  60. )
  61. # 加载数据集,并设置缓存目录和身份验证令牌
  62. # Load pretrained model and tokenizer
  63. config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
  64. config.pre_seq_len = model_args.pre_seq_len
  65. config.prefix_projection = model_args.prefix_projection
  66. # 从预训练模型中加载配置文件,并设置 pre_seq_len 和 prefix_projection
  67. tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
  68. # 从预训练模型中加载 tokenizer
  69. if model_args.ptuning_checkpoint is not None:
  70. # 如果设置了 ptuning_checkpoint
  71. # Evaluation
  72. # Loading extra state dict of prefix encoder
  73. model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
  74. # 从预训练模型中加载模型
  75. prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
  76. new_prefix_state_dict = {}
  77. for k, v in prefix_state_dict.items():
  78. if k.startswith("transformer.prefix_encoder."):
  79. new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
  80. model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
  81. # 加载额外的 prefix encoder 的状态字典
  82. else:
  83. # Finetune
  84. model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
  85. # 从预训练模型中加载模型
  86. if model_args.quantization_bit is not None:
  87. print(f"Quantized to {model_args.quantization_bit} bit")
  88. model = model.quantize(model_args.quantization_bit).to(device)
  89. # 如果设置了 quantization_bit,将模型量化为指定的位数
  90. if model_args.pre_seq_len is not None:
  91. # 如果设置了 pre_seq_len
  92. # P-tuning v2
  93. model = model.half()
  94. model.transformer.prefix_encoder.float()
  95. # 将模型和 prefix encoder 转换为半精度浮点数
  96. else:
  97. # Finetune
  98. model = model.float()
  99. # 将模型转换为浮点数
  100. prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
  101. # 如果设置了 source_prefix,将其赋值给 prefix,否则将其设置
  102. # Preprocessing the datasets.
  103. # We need to tokenize inputs and targets.
  104. if training_args.do_train:
  105. column_names = raw_datasets["train"].column_names
  106. elif training_args.do_eval:
  107. column_names = raw_datasets["validation"].column_names
  108. elif training_args.do_predict:
  109. column_names = raw_datasets["test"].column_names
  110. else:
  111. logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
  112. return
  113. # Get the column names for input/target.
  114. prompt_column = data_args.prompt_column
  115. response_column = data_args.response_column
  116. history_column = data_args.history_column
  117. # Temporarily set max_target_length for training.
  118. max_target_length = data_args.max_target_length
  119. def preprocess_function_eval(examples):
  120. # 定义预处理函数,输入为 examples,输出为 model_inputs
  121. inputs, targets = [], []
  122. # 定义 inputs 和 targets 列表
  123. for i in range(len(examples[prompt_column])):
  124. # 遍历 examples 中的每个样本
  125. if examples[prompt_column][i] and examples[response_column][i]:
  126. # 如果 prompt 和 response 都不为空
  127. query = examples[prompt_column][i]
  128. # 将 prompt 赋值给 query
  129. if history_column is None or len(examples[history_column][i]) == 0:
  130. prompt = query
  131. else:
  132. prompt = ""
  133. history = examples[history_column][i]
  134. for turn_idx, (old_query, response) in enumerate(history):
  135. prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
  136. prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
  137. # 如果有 history_column,将历史对话和当前的 prompt 拼接起来,否则直接使用 prompt
  138. inputs.append(prompt)
  139. targets.append(examples[response_column][i])
  140. # 将 prompt 和 response 分别添加到 inputs 和 targets 中
  141. inputs = [prefix + inp for inp in inputs]
  142. # 将 prefix 和 inputs 中的每个元素拼接起来
  143. model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
  144. # 使用 tokenizer 对 inputs 进行编码,并进行截断和填充
  145. labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
  146. # 使用 tokenizer 对 targets 进行编码,并进行截断
  147. if data_args.ignore_pad_token_for_loss:
  148. labels["input_ids"] = [
  149. [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
  150. ]
  151. # 如果 ignore_pad_token_for_loss 为 True,将 labels 中的 pad_token_id 替换为 -100
  152. model_inputs["labels"] = labels["input_ids"]
  153. # 将 labels 中的 input_ids 赋值给 model_inputs 中的 labels
  154. return model_inputs
  155. # 返回 model_inputs
  156. def preprocess_function_train(examples):
  157. # 定义预处理函数,输入为 examples,输出为 model_inputs
  158. max_seq_length = data_args.max_source_length + data_args.max_target_length
  159. # 计算最大序列长度,即 prompt 和 response 的最大长度之和
  160. model_inputs = {
  161. "input_ids": [],
  162. "labels": [],
  163. }
  164. # 定义 model_inputs,包含 input_ids 和 labels 两个字段
  165. for i in range(len(examples[prompt_column])):
  166. # 遍历 examples 中的每个样本
  167. if examples[prompt_column][i] and examples[response_column][i]:
  168. # 如果 prompt 和 response 都不为空
  169. query, answer = examples[prompt_column][i], examples[response_column][i]
  170. # 将 prompt 和 response 分别赋值给 query 和 answer
  171. if history_column is None:
  172. prompt = query
  173. else:
  174. prompt = ""
  175. history = examples[history_column][i]
  176. for turn_idx, (old_query, response) in enumerate(history):
  177. prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
  178. prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
  179. # 如果有 history_column,将历史对话和当前的 prompt 拼接起来,否则直接使用 prompt
  180. prompt = prefix + prompt
  181. # 将 prefix 和 prompt 拼接起来
  182. a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
  183. b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
  184. # 使用 tokenizer 对 prompt 和 response 进行编码
  185. if len(a_ids) > data_args.max_source_length - 1:
  186. a_ids = a_ids[: data_args.max_source_length - 1]
  187. # 如果 prompt 的长度超过了最大长度减一,截断 prompt
  188. if len(b_ids) > data_args.max_target_length - 2:
  189. b_ids = b_ids[: data_args.max_target_length - 2]
  190. # 如果 response 的长度超过了最大长度减二,截断 response
  191. input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
  192. # 将 prompt 和 response 拼接起来,并加上特殊的 token
  193. context_length = input_ids.index(tokenizer.bos_token_id)
  194. mask_position = context_length - 1
  195. labels = [-100] * context_length + input_ids[mask_position + 1:]
  196. # 构造 labels,其中 context_length 是 bos_token_id 的位置,mask_position 是下一个 token 的位置
  197. pad_len = max_seq_length - len(input_ids)
  198. input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
  199. labels = labels + [tokenizer.pad_token_id] * pad_len
  200. # 将 input_ids 和 labels 补齐到最大长度
  201. if data_args.ignore_pad_token_for_loss:
  202. labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
  203. # 如果 ignore_pad_token_for_loss 为 True,将 labels 中的 pad_token_id 替换为 -100
  204. model_inputs["input_ids"].append(input_ids)
  205. model_inputs["labels"].append(labels)
  206. # 将 input_ids 和 labels 添加到 model_inputs 中
  207. return model_inputs
  208. # 返回 model_inputs
  209. def print_dataset_example(example):
  210. print("input_ids",example["input_ids"])
  211. print("inputs", tokenizer.decode(example["input_ids"]))
  212. print("label_ids", example["labels"])
  213. print("labels", tokenizer.decode(example["labels"]))
  214. if training_args.do_train: # 如果需要进行训练
  215. if "train" not in raw_datasets: # 如果原始数据集中没有训练集
  216. raise ValueError("--do_train requires a train dataset") # 抛出异常,提示需要提供训练集
  217. train_dataset = raw_datasets["train"] # 获取训练集
  218. if data_args.max_train_samples is not None: # 如果设置了最大训练样本数
  219. max_train_samples = min(len(train_dataset), data_args.max_train_samples) # 计算最大训练样本数
  220. train_dataset = train_dataset.select(range(max_train_samples)) # 选择最大训练样本数的子集
  221. with training_args.main_process_first(desc="train dataset map pre-processing"): # 在主进程中进行训练集预处理
  222. train_dataset = train_dataset.map( # 对训练集进行映射操作
  223. preprocess_function_train, # 预处理函数
  224. batched=True, # 是否对数据进行批处理
  225. num_proc=data_args.preprocessing_num_workers, # 预处理使用的进程数
  226. remove_columns=column_names, # 需要移除的列名
  227. load_from_cache_file=not data_args.overwrite_cache, # 是否从缓存文件中加载数据
  228. desc="Running tokenizer on train dataset", # 显示的描述信息
  229. )
  230. print_dataset_example(train_dataset[0]) # 打印训练集的第一个样本
  231. if training_args.do_eval: # 如果需要进行评估
  232. max_target_length = data_args.val_max_target_length # 获取最大目标长度
  233. if "validation" not in raw_datasets: # 如果原始数据集中没有验证集
  234. raise ValueError("--do_eval requires a validation dataset") # 抛出异常,提示需要提供验证集
  235. eval_dataset = raw_datasets["validation"] # 获取验证集
  236. if data_args.max_eval_samples is not None: # 如果设置了最大评估样本数
  237. max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) # 计算最大评估样本数
  238. eval_dataset = eval_dataset.select(range(max_eval_samples)) # 选择最大评估样本数的子集
  239. with training_args.main_process_first(desc="validation dataset map pre-processing"): # 在主进程中进行验证集预处理
  240. eval_dataset = eval_dataset.map( # 对验证集进行映射操作
  241. preprocess_function_eval, # 预处理函数
  242. batched=True, # 是否对数据进行批处理
  243. num_proc=data_args.preprocessing_num_workers, # 预处理使用的进程数
  244. remove_columns=column_names, # 需要移除的列名
  245. load_from_cache_file=not data_args.overwrite_cache, # 是否从缓存文件中加载数据
  246. desc="Running tokenizer on validation dataset", # 显示的描述信息
  247. )
  248. print_dataset_example(eval_dataset[0]) # 打印验证集的第一个样本
  249. if training_args.do_predict:
  250. # 如果设置了 do_predict 标志
  251. max_target_length = data_args.val_max_target_length
  252. # 设置最大目标长度为 val_max_target_length
  253. if "test" not in raw_datasets:
  254. raise ValueError("--do_predict requires a test dataset")
  255. # 如果 raw_datasets 中没有 test 数据集,抛出 ValueError 异常
  256. predict_dataset = raw_datasets["test"]
  257. # 将 test 数据集赋值给 predict_dataset
  258. if data_args.max_predict_samples is not None:
  259. max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
  260. predict_dataset = predict_dataset.select(range(max_predict_samples))
  261. # 如果设置了 max_predict_samples,将 predict_dataset 截取到指定长度
  262. with training_args.main_process_first(desc="prediction dataset map pre-processing"):
  263. predict_dataset = predict_dataset.map(
  264. preprocess_function_eval,
  265. batched=True,
  266. num_proc=data_args.preprocessing_num_workers,
  267. remove_columns=column_names,
  268. load_from_cache_file=not data_args.overwrite_cache,
  269. desc="Running tokenizer on prediction dataset",
  270. )
  271. # 使用 preprocess_function_eval 对 predict_dataset 进行预处理
  272. print_dataset_example(predict_dataset[0])
  273. # 打印 predict_dataset 的第一个样本
  274. # Data collator
  275. label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
  276. # 如果 ignore_pad_token_for_loss 为 True,将 label_pad_token_id 设置为 -100,否则设置为 tokenizer.pad_token_id
  277. data_collator = DataCollatorForSeq2Seq(
  278. tokenizer,
  279. model=model,
  280. label_pad_token_id=label_pad_token_id,
  281. pad_to_multiple_of=None,
  282. padding=False
  283. )
  284. # 定义 data_collator,用于将数据转换为模型所需的格式
  285. # Metric
  286. def compute_metrics(eval_preds):
  287. # 定义评估指标函数,输入为 eval_preds,输出为 score_dict
  288. preds, labels = eval_preds
  289. # 将 preds 和 labels 赋值给 preds 和 labels
  290. if isinstance(preds, tuple):
  291. preds = preds[0]
  292. # 如果 preds 是元组类型,将其转换为列表类型
  293. decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
  294. # 使用 tokenizer 对 preds 进行解码
  295. if data_args.ignore_pad_token_for_loss:
  296. # 如果 ignore_pad_token_for_loss 为 True,将 labels 中的 -100 替换为 pad_token_id
  297. labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  298. decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  299. # 使用 tokenizer 对 labels 进行解码
  300. score_dict = {
  301. "rouge-1": [],
  302. "rouge-2": [],
  303. "rouge-l": [],
  304. "bleu-4": []
  305. }
  306. # 定义 score_dict,包含 rouge-1、rouge-2、rouge-l 和 bleu-4 四个指标
  307. for pred, label in zip(decoded_preds, decoded_labels):
  308. # 遍历解码后的 preds 和 labels
  309. hypothesis = list(jieba.cut(pred))
  310. reference = list(jieba.cut(label))
  311. # 使用 jieba 对预测值和真实值进行分词
  312. rouge = Rouge()
  313. scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
  314. result = scores[0]
  315. # 使用 Rouge 计算 ROUGE 指标,ROUGE(Recall-Oriented Understudy for Gisting Evaluation)
  316. # 是一种用于自动评估文本摘要和机器翻译的指标。它通过比较生成的摘要或翻译与参考摘要或翻译之间的重叠来计算得分。
  317. # ROUGE 指标包括 ROUGE-1、ROUGE-2 和 ROUGE-L 等,
  318. # 其中 ROUGE-1 表示单个词的重叠,ROUGE-2 表示两个词的重叠,ROUGE-L 表示最长公共子序列的重叠。
  319. # ROUGE 指标的取值范围为 0 到 1,值越大表示生成的摘要或翻译与参考摘要或翻译之间的重叠越多,即越好。
  320. # 在使用 Rouge 计算 ROUGE 指标时,rouge.get_scores() 方法返回一个包含多个指标的列表,每个指标都是一个字典,
  321. # 包含 precision、recall 和 f-measure 三个值。因此,scores[0] 表示第一个指标的字典,
  322. # 其中包含 precision、recall 和 f-measure 三个值。在这里,我们默认使用 ROUGE-1 指标,
  323. # 因此 scores[0] 表示 ROUGE-1 指标的字典。
  324. for k, v in result.items():
  325. score_dict[k].append(round(v["f"] * 100, 4))
  326. # 将 ROUGE 指标添加到 score_dict 中
  327. bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
  328. score_dict["bleu-4"].append(round(bleu_score * 100, 4))
  329. # 计算 BLEU 指标,并将其添加到 score_dict 中
  330. for k, v in score_dict.items():
  331. score_dict[k] = float(np.mean(v))
  332. # 计算每个指标的平均值,并将其转换为浮点数类型
  333. return score_dict
  334. # 返回 score_dict
  335. # Override the decoding parameters of Seq2SeqTrainer
  336. training_args.generation_max_length = (
  337. training_args.generation_max_length
  338. if training_args.generation_max_length is not None
  339. else data_args.val_max_target_length
  340. )
  341. training_args.generation_num_beams = (
  342. data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
  343. )
  344. # 如果设置了 generation_max_length,将其赋值给 training_args.generation_max_length,否则将 val_max_target_length 赋值给其
  345. # 如果设置了 num_beams,将其赋值给 training_args.generation_num_beams,否则将其保持不变
  346. # Initialize our Trainer
  347. trainer = Seq2SeqTrainer(
  348. model=model,
  349. args=training_args,
  350. train_dataset=train_dataset if training_args.do_train else None,
  351. eval_dataset=eval_dataset if training_args.do_eval else None,
  352. tokenizer=tokenizer,
  353. data_collator=data_collator,
  354. compute_metrics=compute_metrics if training_args.predict_with_generate else None,
  355. save_prefixencoder=model_args.pre_seq_len is not None
  356. )
  357. # 初始化 Seq2SeqTrainer,包括模型、参数、训练集、验证集、tokenizer、data_collator、compute_metrics 和 save_prefixencoder
  358. # Training
  359. if training_args.do_train:
  360. checkpoint = None
  361. if training_args.resume_from_checkpoint is not None:
  362. checkpoint = training_args.resume_from_checkpoint
  363. # elif last_checkpoint is not None:
  364. # checkpoint = last_checkpoint
  365. # 启用梯度检查点,减少显存使用,提高训练效率
  366. model.gradient_checkpointing_enable()
  367. # 启用输入梯度,使模型在训练时计算输入的梯度,提高训练效果
  368. model.enable_input_require_grads()
  369. # 开始训练模型,resume_from_checkpoint 表示是否从检查点恢复训练
  370. train_result = trainer.train(resume_from_checkpoint=checkpoint)
  371. # trainer.save_model() # Saves the tokenizer too for easy upload
  372. # 将训练指标保存到 metrics 变量中
  373. metrics = train_result.metrics
  374. # 判断是否设置了 max_train_samples 参数,如果设置了,则将 train_samples 设置为 max_train_samples 和训练数据集大小的较小值,否则将其设置为训练数据集的大小
  375. max_train_samples = (
  376. data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
  377. )
  378. metrics["train_samples"] = min(max_train_samples, len(train_dataset))
  379. # 将训练指标记录到日志中
  380. trainer.log_metrics("train", metrics)
  381. # 将训练指标保存到文件中
  382. trainer.save_metrics("train", metrics)
  383. # 保存训练状态,包括模型和优化器的参数
  384. trainer.save_state()
  385. # Evaluation
  386. results = {}
  387. # 定义一个空字典 results
  388. max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
  389. # 计算最大序列长度
  390. if training_args.do_eval:
  391. # 如果设置了 do_eval 为 True
  392. logger.info("*** Evaluate ***")
  393. # 打印日志信息
  394. metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length,
  395. temperature=0.95)
  396. # 使用 evaluate 方法评估模型,并返回评估指标
  397. max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
  398. metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
  399. # 计算评估样本数,并将其添加到 metrics 中
  400. trainer.log_metrics("eval", metrics)
  401. trainer.save_metrics("eval", metrics)
  402. # 打印评估指标,并将其保存到文件中
  403. if training_args.do_predict:
  404. # 如果设置了 do_predict 为 True
  405. logger.info("*** Predict ***")
  406. # 打印日志信息
  407. predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length,
  408. do_sample=True, top_p=0.7, temperature=0.95)
  409. # 使用 predict 方法预测模型,并返回预测结果
  410. metrics = predict_results.metrics
  411. max_predict_samples = (
  412. data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
  413. )
  414. metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
  415. # 计算预测样本数,并将其添加到 metrics 中
  416. trainer.log_metrics("predict", metrics)
  417. trainer.save_metrics("predict", metrics)
  418. # 打印预测指标,并将其保存到文件中
  419. if trainer.is_world_process_zero():
  420. # 如果是主进程
  421. if training_args.predict_with_generate:
  422. # 如果设置了 predict_with_generate 为 True
  423. predictions = tokenizer.batch_decode(
  424. predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
  425. )
  426. # 使用 tokenizer 对预测值进行解码,并去除特殊标记和空格
  427. predictions = [pred.strip() for pred in predictions]
  428. # 去除预测值中的空格
  429. labels = tokenizer.batch_decode(
  430. predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
  431. )
  432. # 使用 tokenizer 对真实值进行解码,并去除特殊标记和空格
  433. labels = [label.strip() for label in labels]
  434. # 去除真实值中的空格
  435. output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
  436. with open(output_prediction_file, "w", encoding="utf-8") as writer:
  437. for p, l in zip(predictions, labels):
  438. res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
  439. writer.write(f"{res}\n")
  440. # 将预测值和真实值写入文件中
  441. return results
  442. # 返回结果字典 results
  443. def _mp_fn(index):
  444. # For xla_spawn (TPUs)
  445. main()
  446. if __name__ == "__main__":
  447. main()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/137232
推荐阅读
相关标签
  

闽ICP备14008679号