当前位置:   article > 正文

MLM之GLM-4:GLM-4-9B源码解读(finetune.py)模型微调与评估的完整实现——定义命令行参数→加载微调配置/模型/分词器/数据管理器→定义数据集(训练集/验证集/测试集)→模型训练_glm-4-9b微调

glm-4-9b微调

MLM之GLM-4:GLM-4-9B源码解读(finetune.py)模型微调与评估的完整实现——定义命令行参数→加载微调配置/模型/分词器/数据管理器→定义数据集(训练集/验证集/测试集)→模型训练(梯度检查点/支持从检查点恢复训练)→模型评估(存在测试数据集/基于ROUGE和BLEU分数)

目录

GLM-4-9B源码解读(finetune.py)模型微调与评估的完整实现——定义命令行参数→加载微调配置/模型/分词器/数据管理器→定义数据集(训练集/验证集/测试集)→模型训练(梯度检查点/支持从检查点恢复训练)→模型评估(存在测试数据集/基于ROUGE和BLEU分数)

实现代码


GLM-4-9B源码解读(finetune.py)模型微调与评估的完整实现——定义命令行参数→加载微调配置/模型/分词器/数据管理器→定义数据集(训练集/验证集/测试集)→模型训练(梯度检查点/支持从检查点恢复训练)→模型评估(存在测试数据集/基于ROUGE和BLEU分数)

源码地址:GLM-4/finetune_demo at main · THUDM/GLM-4 · GitHub

实现代码

  1. # -*- coding: utf-8 -*-
  2. # MLM之GLM-4:GLM-4-9B源码解读(finetune.py)模型微调与评估的完整实现——定义命令行参数→加载微调配置/模型/分词器/数据管理器→定义数据集(训练集/验证集/测试集)→模型训练(梯度检查点/支持从检查点恢复训练)→模型评估(存在测试数据集/基于ROUGE和BLEU分数)
  3. '''
  4. 代码实现了从配置文件加载微调配置,加载预训练模型和数据集,通过 Seq2SeqTrainer 进行训练和评估的过程。
  5. 核心技术点包括配置加载与管理、数据处理、模型加载、训练过程控制、梯度检查点与输入梯度设置以及评估指标计算。
  6. 这些技术点的结合实现了一个完整的序列到序列模型微调和评估流程。
  7. '''
  8. import json
  9. import os
  10. import jieba
  11. import dataclasses as dc
  12. import functools
  13. from collections.abc import Callable, Mapping, Sequence
  14. from pathlib import Path
  15. from typing import Annotated, Any, Optional, Union
  16. import numpy as np
  17. import ruamel.yaml as yaml
  18. import torch
  19. import typer
  20. from datasets import Dataset, NamedSplit, Split
  21. from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
  22. from peft import PeftConfig, get_peft_config, get_peft_model
  23. from rouge_chinese import Rouge
  24. from torch import nn
  25. from transformers import (
  26. AutoModelForCausalLM,
  27. AutoTokenizer,
  28. EvalPrediction,
  29. GenerationConfig,
  30. PreTrainedTokenizer,
  31. Seq2SeqTrainingArguments,
  32. )
  33. from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
  34. from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
# Typer CLI application object; pretty_exceptions_show_locals=False keeps
# tracebacks from dumping local variables (which may be huge tensors/configs).
app = typer.Typer(pretty_exceptions_show_locals=False)
  37. # 自定义数据整理类
  38. # DataCollatorForSeq2Seq继承并修改了transformers库中的数据整理类,主要用于处理输入和输出的填充(padding),使得输入批次的数据长度一致。
  39. class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
  40. def __call__(self, features, return_tensors=None):
  41. output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
  42. if output_ids is not None:
  43. max_output_length = max(len(out) for out in output_ids)
  44. if self.pad_to_multiple_of is not None:
  45. max_output_length = (
  46. (
  47. max_output_length + self.pad_to_multiple_of - 1) //
  48. self.pad_to_multiple_of * self.pad_to_multiple_of
  49. )
  50. for feature in features:
  51. remainder = [self.tokenizer.pad_token_id] * (
  52. max_output_length - len(feature['output_ids'])
  53. )
  54. if isinstance(feature['output_ids'], list):
  55. feature['output_ids'] = feature['output_ids'] + remainder
  56. else:
  57. feature['output_ids'] = np.concatenate(
  58. [feature['output_ids'], remainder]
  59. ).astype(np.int64)
  60. return super().__call__(features, return_tensors)
# Trainer subclass: in the generation-based prediction step it strips the
# prompt tokens from the generated sequence and substitutes the reference
# 'output_ids' as labels, so metrics compare only the generated continuation.
class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        # NOTE(review): output_ids/input_ids are only bound inside this branch,
        # yet both are used unconditionally below — this method appears to
        # assume the trainer always runs with predict_with_generate=True;
        # confirm against the training arguments used by callers.
        if self.args.predict_with_generate:
            output_ids = inputs.pop('output_ids')
            input_ids = inputs['input_ids']
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
        )
        # Drop the prompt prefix: keep only tokens generated after the input.
        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
        # Use the reference continuation as labels for metric computation.
        labels = output_ids
        return loss, generated_tokens, labels
  80. # 数据配置与加载:
  81. # DataConfig和FinetuningConfig:两个数据配置类,分别用于数据文件路径的管理和微调配置。
  82. # DataConfig:管理数据文件路径和相关配置。
  83. # FinetuningConfig:管理微调配置,包括训练参数、数据配置等。
  84. @dc.dataclass
  85. class DataConfig(object):
  86. train_file: Optional[str] = None
  87. val_file: Optional[str] = None
  88. test_file: Optional[str] = None
  89. num_proc: Optional[int] = None
  90. @property
  91. def data_format(self) -> str:
  92. return Path(self.train_file).suffix
  93. @property
  94. def data_files(self) -> dict[NamedSplit, str]:
  95. return {
  96. split: data_file
  97. for split, data_file in zip(
  98. [Split.TRAIN, Split.VALIDATION, Split.TEST],
  99. [self.train_file, self.val_file, self.test_file],
  100. )
  101. if data_file is not None
  102. }
  103. @dc.dataclass
  104. class FinetuningConfig(object):
  105. data_config: DataConfig
  106. max_input_length: int
  107. max_output_length: int
  108. training_args: Seq2SeqTrainingArguments = dc.field(
  109. default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
  110. )
  111. peft_config: Optional[PeftConfig] = None
  112. def __post_init__(self):
  113. if not self.training_args.do_eval or self.data_config.val_file is None:
  114. self.training_args.do_eval = False
  115. self.training_args.evaluation_strategy = 'no'
  116. self.data_config.val_file = None
  117. else:
  118. self.training_args.per_device_eval_batch_size = (
  119. self.training_args.per_device_eval_batch_size
  120. or self.training_args.per_device_train_batch_size
  121. )
  122. @classmethod
  123. def from_dict(cls, **kwargs) -> 'FinetuningConfig':
  124. training_args = kwargs.get('training_args', None)
  125. if training_args is not None and not isinstance(
  126. training_args, Seq2SeqTrainingArguments
  127. ):
  128. gen_config = training_args.get('generation_config')
  129. # TODO: a bit hacky
  130. if not isinstance(gen_config, GenerationConfig):
  131. training_args['generation_config'] = GenerationConfig(
  132. **gen_config
  133. )
  134. kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
  135. data_config = kwargs.get('data_config')
  136. if not isinstance(data_config, DataConfig):
  137. kwargs['data_config'] = DataConfig(**data_config)
  138. peft_config = kwargs.get('peft_config', None)
  139. if peft_config is not None and not isinstance(peft_config, PeftConfig):
  140. kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
  141. return cls(**kwargs)
  142. @classmethod
  143. def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
  144. path = Path(path)
  145. parser = yaml.YAML(typ='safe', pure=True)
  146. parser.indent(mapping=2, offset=2, sequence=4)
  147. parser.default_flow_style = False
  148. kwargs = parser.load(path)
  149. return cls.from_dict(**kwargs)
  150. # _load_datasets和DataManager类:负责根据配置加载数据集,并提供数据处理功能。
  151. # _load_datasets:根据数据配置加载数据集。
  152. # DataManager:管理数据集加载和处理。
  153. from datasets import load_dataset, DatasetDict, NamedSplit
  154. from typing import Optional
  155. def _load_datasets(
  156. data_dir: str,
  157. data_format: str,
  158. data_files: dict[NamedSplit, str],
  159. num_proc: Optional[int],
  160. ) -> DatasetDict:
  161. if data_format == '.jsonl':
  162. dataset_dct = load_dataset(
  163. data_dir,
  164. data_files=data_files,
  165. split=None,
  166. num_proc=num_proc,
  167. )
  168. else:
  169. raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
  170. return dataset_dct
  171. class DataManager(object):
  172. def __init__(self, data_dir: str, data_config: DataConfig):
  173. self._num_proc = data_config.num_proc
  174. self._dataset_dct = _load_datasets(
  175. data_dir,
  176. data_config.data_format,
  177. data_config.data_files,
  178. self._num_proc,
  179. )
  180. def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
  181. return self._dataset_dct.get(split, None)
  182. def get_dataset(
  183. self,
  184. split: NamedSplit,
  185. process_fn: Callable[[dict[str, Any]], dict[str, Any]],
  186. batched: bool = True,
  187. remove_orig_columns: bool = True,
  188. ) -> Optional[Dataset]:
  189. orig_dataset = self._get_dataset(split)
  190. if orig_dataset is None:
  191. return
  192. if remove_orig_columns:
  193. remove_columns = orig_dataset.column_names
  194. else:
  195. remove_columns = None
  196. return orig_dataset.map(
  197. process_fn,
  198. batched=batched,
  199. remove_columns=remove_columns,
  200. num_proc=self._num_proc,
  201. )
  202. # 数据处理函数:
  203. # process_message:处理单条消息,过滤无效字段。
  204. # process_message:过滤无效字段。
  205. # process_batch和process_batch_eval:批量处理数据,生成模型输入和标签。
  206. def process_message(message):
  207. if 'tools' in message and message['role'] == 'system':
  208. for tool in message['tools']:
  209. parameters = tool['function']['parameters']['properties']
  210. tool['function']['parameters']['properties'] = \
  211. {k: v for k, v in parameters.items() if
  212. v is not None}
  213. elif 'tools' in message:
  214. del message['tools']
  215. return message
  216. # process_batch和process_batch_eval:批量处理训练和评估数据,生成输入ID和标签ID。
  217. def process_batch(
  218. batch: Mapping[str, Sequence],
  219. tokenizer: PreTrainedTokenizer,
  220. max_input_length: int,
  221. max_output_length: int,
  222. ) -> dict[str, list]:
  223. batched_conv = batch['messages']
  224. batched_input_ids = []
  225. batched_labels = []
  226. for conv in batched_conv:
  227. input_ids = [151331, 151333]
  228. loss_masks = [False, False]
  229. for message in conv:
  230. message = process_message(message)
  231. loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
  232. new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
  233. new_loss_masks = [loss_mask_val] * len(new_input_ids)
  234. input_ids += new_input_ids
  235. loss_masks += new_loss_masks
  236. input_ids.append(tokenizer.eos_token_id)
  237. loss_masks = [False, *loss_masks]
  238. labels = []
  239. for input_id, mask in zip(input_ids, loss_masks):
  240. if mask:
  241. labels.append(input_id)
  242. else:
  243. labels.append(-100)
  244. max_length = max_input_length + max_output_length + 1
  245. batched_input_ids.append(input_ids[:max_length])
  246. batched_labels.append(labels[:max_length])
  247. return {'input_ids': batched_input_ids, 'labels': batched_labels}
  248. def process_batch_eval(
  249. batch: Mapping[str, Sequence],
  250. tokenizer: PreTrainedTokenizer,
  251. max_input_length: int,
  252. max_output_length: int,
  253. ) -> dict[str, list]:
  254. batched_conv = batch['messages']
  255. batched_input_ids = []
  256. batched_output_ids = []
  257. for conv in batched_conv:
  258. input_ids = [151331, 151333]
  259. for message in conv:
  260. if len(input_ids) >= max_input_length:
  261. break
  262. else:
  263. message = process_message(message)
  264. new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
  265. if message['role'] == 'assistant':
  266. output_prompt, output_ids = (
  267. new_input_ids[:1],
  268. new_input_ids[1:],
  269. )
  270. output_ids.append(tokenizer.eos_token_id)
  271. batched_input_ids.append(
  272. input_ids[:max_input_length] + output_prompt[:1]
  273. )
  274. batched_output_ids.append(output_ids[:max_output_length])
  275. input_ids += new_input_ids
  276. return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
  277. # load_tokenizer_and_model:加载预训练模型和分词器,支持PEFT配置(参数高效微调)。
  278. def load_tokenizer_and_model(
  279. model_dir: str,
  280. peft_config: Optional[PeftConfig] = None,
  281. ):
  282. tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
  283. if peft_config is not None:
  284. model = AutoModelForCausalLM.from_pretrained(
  285. model_dir,
  286. trust_remote_code=True,
  287. empty_init=False,
  288. use_cache=False,
  289. torch_dtype=torch.bfloat16 # Must use BFloat 16
  290. )
  291. model = get_peft_model(model, peft_config)
  292. model.print_trainable_parameters()
  293. else:
  294. model = AutoModelForCausalLM.from_pretrained(
  295. model_dir,
  296. trust_remote_code=True,
  297. empty_init=False,
  298. use_cache=False,
  299. torch_dtype=torch.bfloat16
  300. )
  301. return tokenizer, model
  302. # compute_metrics:计算ROUGE和BLEU分数,用于模型生成效果评估。
  303. def compute_metrics(eval_preds: EvalPrediction, tokenizer):
  304. batched_pred_ids, batched_label_ids = eval_preds
  305. metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
  306. for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
  307. pred_txt = tokenizer.decode(pred_ids).strip()
  308. label_txt = tokenizer.decode(label_ids).strip()
  309. pred_tokens = list(jieba.cut(pred_txt))
  310. label_tokens = list(jieba.cut(label_txt))
  311. rouge = Rouge()
  312. scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
  313. for k, v in scores[0].items():
  314. metrics_dct[k].append(round(v['f'] * 100, 4))
  315. metrics_dct['bleu-4'].append(
  316. sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
  317. return {k: np.mean(v) for k, v in metrics_dct.items()}
  318. @app.command()
  319. def main(
  320. data_dir: Annotated[str, typer.Argument(help='')],
  321. model_dir: Annotated[
  322. str,
  323. typer.Argument(
  324. help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
  325. ),
  326. ],
  327. config_file: Annotated[str, typer.Argument(help='')],
  328. auto_resume_from_checkpoint: str = typer.Argument(
  329. default='',
  330. help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
  331. ),
  332. ):
  333. # 0、定义命令行参数:数据目录、模型目录或预训练模型ID、配置文件路径、是否自动从最新检查点恢复训练。
  334. # 1、加载微调配置、模型、分词器、数据管理器
  335. ft_config = FinetuningConfig.from_file(config_file)
  336. tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
  337. data_manager = DataManager(data_dir, ft_config.data_config)
  338. # 2、定义数据集:获取训练集、验证集和测试集数据,分别调用 process_batch 和 process_batch_eval 进行处理
  339. train_dataset = data_manager.get_dataset(
  340. Split.TRAIN,
  341. functools.partial(
  342. process_batch,
  343. tokenizer=tokenizer,
  344. max_input_length=ft_config.max_input_length,
  345. max_output_length=ft_config.max_output_length,
  346. ),
  347. batched=True,
  348. )
  349. print('train_dataset:', train_dataset)
  350. #############新增代码,查看训练集的前几个样本###############
  351. if train_dataset is not None:
  352. print("Training dataset samples:")
  353. for i in range(3):
  354. print(tokenizer.decode(train_dataset[i]['input_ids'], skip_special_tokens=True))
  355. val_dataset = data_manager.get_dataset(
  356. Split.VALIDATION,
  357. functools.partial(
  358. process_batch_eval,
  359. tokenizer=tokenizer,
  360. max_input_length=ft_config.max_input_length,
  361. max_output_length=ft_config.max_output_length,
  362. ),
  363. batched=True,
  364. )
  365. if val_dataset is not None:
  366. print('val_dataset:', val_dataset)
  367. test_dataset = data_manager.get_dataset(
  368. Split.TEST,
  369. functools.partial(
  370. process_batch_eval,
  371. tokenizer=tokenizer,
  372. max_input_length=ft_config.max_input_length,
  373. max_output_length=ft_config.max_output_length,
  374. ),
  375. batched=True,
  376. )
  377. if test_dataset is not None:
  378. print('test_dataset:', test_dataset)
  379. # 3、模型训练
  380. # 3.1、设置模型训练属性:启用模型的梯度检查点、设置输入需要梯度
  381. model.gradient_checkpointing_enable()
  382. model.enable_input_require_grads()
  383. # 3.2、创建训练器:创建 Seq2SeqTrainer对象,用于训练和评估序列到序列模型。并且支持从检查点恢复训练。
  384. trainer = Seq2SeqTrainer(
  385. model=model,
  386. args=ft_config.training_args, # 指定训练参数:从配置文件加载的训练参数,包括学习率、训练轮数、批量大小等。
  387. data_collator=DataCollatorForSeq2Seq( # 将输入数据整理成批次,进行填充和转换为张量。
  388. tokenizer=tokenizer,
  389. padding='longest',
  390. return_tensors='pt',
  391. ),
  392. train_dataset=train_dataset,
  393. eval_dataset=val_dataset.select(list(range(50))), # 从验证数据集中选择前50条数据进行评估
  394. compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer), # 指定评估指标的计算方法
  395. )
  396. # 3.3、训练模型:根据 auto_resume_from_checkpoint参数决定是否从检查点恢复训练或重新开始训练。
  397. # 判断是否需要从检查点恢复训练
  398. if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
  399. trainer.train()
  400. else:
  401. # 处理从检查点恢复训练的情况
  402. output_dir = ft_config.training_args.output_dir
  403. dirlist = os.listdir(output_dir)
  404. checkpoint_sn = 0
  405. # 遍历目录列表,找到所有有效的检查点(文件名包含 "checkpoint" 且不包含 "tmp"),并记录最新的检查点编号 checkpoint_sn。
  406. for checkpoint_str in dirlist:
  407. if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
  408. checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
  409. if checkpoint > checkpoint_sn:
  410. checkpoint_sn = checkpoint
  411. # 根据 auto_resume_from_checkpoint 的值确定恢复策略
  412. if auto_resume_from_checkpoint.upper() == "YES":
  413. # 如果 auto_resume_from_checkpoint 的值为 "YES",且找到有效的检查点编号。
  414. if checkpoint_sn > 0:
  415. model.gradient_checkpointing_enable()
  416. model.enable_input_require_grads()
  417. checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
  418. print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
  419. trainer.train(resume_from_checkpoint=checkpoint_directory)
  420. else:
  421. trainer.train()
  422. else:
  423. # 如果 auto_resume_from_checkpoint 是一个正整数。
  424. if auto_resume_from_checkpoint.isdigit():
  425. if int(auto_resume_from_checkpoint) > 0:
  426. checkpoint_sn = int(auto_resume_from_checkpoint)
  427. model.gradient_checkpointing_enable()
  428. model.enable_input_require_grads()
  429. checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
  430. print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
  431. trainer.train(resume_from_checkpoint=checkpoint_directory)
  432. # 如果 auto_resume_from_checkpoint 既不是 "YES" 也不是正整数。打印错误信息,说明指定的检查点编号无效,需要用户手动检查和选择正确的检查点。
  433. else:
  434. print(auto_resume_from_checkpoint,
  435. "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
  436. # 4、模型评估:如果存在测试数据集,则对测试数据集进行预测评估,包括 ROUGE 和 BLEU 分数。
  437. if test_dataset is not None:
  438. trainer.predict(test_dataset)
# Script entry point: delegate to the Typer CLI application.
if __name__ == '__main__':
    app()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/1007228
推荐阅读
相关标签
  

闽ICP备14008679号