MLM · GLM-4: GLM-4-9B source-code walkthrough (finetune.py), a complete implementation of model fine-tuning and evaluation: define CLI arguments → load the fine-tuning config/model/tokenizer/data manager → build the datasets (train/validation/test) → train the model (gradient checkpointing, resumable from checkpoints) → evaluate the model (when a test set exists, using ROUGE and BLEU scores)
Source code: GLM-4/finetune_demo at main · THUDM/GLM-4 · GitHub
# -*- coding: utf-8 -*-


# MLM · GLM-4: GLM-4-9B source-code walkthrough (finetune.py), a complete implementation of model fine-tuning and evaluation: define CLI arguments → load the fine-tuning config/model/tokenizer/data manager → build the datasets (train/validation/test) → train the model (gradient checkpointing, resumable from checkpoints) → evaluate the model (when a test set exists, using ROUGE and BLEU scores)
'''
This script loads a fine-tuning configuration from a YAML file, loads the
pretrained model and the datasets, and runs training and evaluation through a
Seq2SeqTrainer. The key techniques are configuration loading and management,
data processing, model loading, training-loop control, gradient checkpointing
with input-gradient enabling, and evaluation-metric computation. Together they
form a complete sequence-to-sequence fine-tuning and evaluation pipeline.
'''


import json
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, NamedSplit, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

# Create the typer.Typer() application that handles the command-line interface (CLI).
app = typer.Typer(pretty_exceptions_show_locals=False)


# Custom data collator.
# DataCollatorForSeq2Seq extends the transformers collator so that, in addition
# to the usual input padding, the per-example 'output_ids' are padded to a
# common length, keeping every sequence in the batch the same size.
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                        (max_output_length + self.pad_to_multiple_of - 1) //
                        self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                        max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):
                    feature['output_ids'] = feature['output_ids'] + remainder
                else:
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)

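
# Illustrative sketch (not part of the original finetune.py): the integer
# arithmetic the collator uses to round the padded length up to the next
# multiple of pad_to_multiple_of.
_pad_to_multiple_of, _max_output_length = 8, 13
assert (_max_output_length + _pad_to_multiple_of - 1) // _pad_to_multiple_of * _pad_to_multiple_of == 16
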
# Seq2SeqTrainer extends the transformers trainer. In the prediction step it pops
# the reference 'output_ids' before generation, strips the prompt tokens from the
# generated sequence, and returns 'output_ids' as the labels, adapting the step
# to a generation task.
class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.args.predict_with_generate:
            output_ids = inputs.pop('output_ids')
        input_ids = inputs['input_ids']
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
        )
        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
        labels = output_ids
        return loss, generated_tokens, labels

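
# Illustrative sketch (not part of the original finetune.py): slicing with
# input_ids.size()[1] removes the prompt, keeping only the newly generated tokens.
_prompt = torch.tensor([[1, 2, 3]])                # prompt of length 3
_generated = torch.tensor([[1, 2, 3, 7, 8, 9]])    # generate() returns prompt + completion
assert _generated[:, _prompt.size()[1]:].tolist() == [[7, 8, 9]]
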
# Data configuration and loading:
# DataConfig and FinetuningConfig are the two configuration dataclasses.
# DataConfig: manages the data-file paths and related options.
# FinetuningConfig: manages the fine-tuning setup, including the training
# arguments, the data config, sequence-length limits, and an optional PEFT config.
@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }
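
# Illustrative sketch (not part of the original finetune.py): splits whose file
# is None are simply dropped from data_files.
_cfg = DataConfig(train_file='train.jsonl', test_file='test.jsonl')
assert _cfg.data_format == '.jsonl'
assert set(_cfg.data_files) == {Split.TRAIN, Split.TEST}
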

@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = 'no'
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            # TODO: a bit hacky
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        path = Path(path)
        parser = yaml.YAML(typ='safe', pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)

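
# Illustrative sketch (not part of the original finetune.py): the shape of a YAML
# file that FinetuningConfig.from_file() can parse. Field names follow the
# dataclasses above; the values are placeholders, not the official demo configs.
#
#   data_config:
#     train_file: train.jsonl
#     val_file: dev.jsonl
#     test_file: test.jsonl
#     num_proc: 1
#   max_input_length: 512
#   max_output_length: 512
#   training_args:
#     output_dir: ./output
#     max_steps: 3000
#     learning_rate: 5e-4
#     per_device_train_batch_size: 1
#   peft_config:
#     peft_type: LORA
#     task_type: CAUSAL_LM
#     r: 8
#     lora_alpha: 32
#     lora_dropout: 0.1
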
# _load_datasets and the DataManager class load the datasets described by the
# config and expose per-split preprocessing.
# _load_datasets: loads the raw splits according to the data config.
# DataManager: manages dataset loading and processing.
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional


def _load_datasets(
        data_dir: str,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    if data_format == '.jsonl':
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


# Data-processing functions:
# process_message: normalizes a single message, filtering out invalid tool fields.
# process_batch and process_batch_eval: batch-process conversations into model
# input ids and label ids.
def process_message(message):
    if 'tools' in message and message['role'] == 'system':
        for tool in message['tools']:
            parameters = tool['function']['parameters']['properties']
            tool['function']['parameters']['properties'] = \
                {k: v for k, v in parameters.items() if
                 v is not None}
    elif 'tools' in message:
        del message['tools']
    return message

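
# Illustrative sketch (not part of the original finetune.py): a non-system
# message has its 'tools' field stripped entirely.
_msg = process_message({'role': 'user', 'content': 'hi', 'tools': []})
assert 'tools' not in _msg
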
# process_batch and process_batch_eval: batch-process the training and evaluation
# data, producing input ids and label ids.
def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        for message in conv:
            message = process_message(message)
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            input_ids += new_input_ids
            loss_masks += new_loss_masks
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}
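
# Illustrative sketch (not part of the original finetune.py): positions whose
# loss mask is False become -100, the index that CrossEntropyLoss ignores, so
# only assistant tokens contribute to the training loss.
_ids = [151331, 151333, 11, 12, 13]
_masks = [False, False, False, True, True]
assert [i if m else -100 for i, m in zip(_ids, _masks)] == [-100, -100, -100, 12, 13]
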

def process_batch_eval(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_output_ids = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        for message in conv:
            if len(input_ids) >= max_input_length:
                break
            else:
                message = process_message(message)
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
                if message['role'] == 'assistant':
                    output_prompt, output_ids = (
                        new_input_ids[:1],
                        new_input_ids[1:],
                    )
                    output_ids.append(tokenizer.eos_token_id)
                    batched_input_ids.append(
                        input_ids[:max_input_length] + output_prompt[:1]
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}

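
# Illustrative sketch (not part of the original finetune.py): for each assistant
# turn, the first token stays with the prompt as output_prompt, and the rest
# becomes the reference output_ids for generation-based evaluation.
_turn = [900, 11, 12, 13]  # hypothetical token ids of one assistant turn
_output_prompt, _output_ids = _turn[:1], _turn[1:]
assert _output_prompt == [900] and _output_ids == [11, 12, 13]
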
# load_tokenizer_and_model: loads the pretrained model and tokenizer, optionally
# wrapping the model with a PEFT adapter for parameter-efficient fine-tuning.
def load_tokenizer_and_model(
        model_dir: str,
        peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16  # must use bfloat16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16
        )
    return tokenizer, model

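
# Illustrative sketch (not part of the original finetune.py): a LoRA config
# equivalent to what get_peft_config builds from a YAML 'peft_config' section.
# The target module name is an assumption for GLM-4's fused attention projection.
from peft import LoraConfig, TaskType

_example_peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['query_key_value'],
)
# tokenizer, model = load_tokenizer_and_model('THUDM/glm-4-9b-chat', peft_config=_example_peft_config)
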
# compute_metrics: computes ROUGE and BLEU scores to evaluate the quality of
# generated text.
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    batched_pred_ids, batched_label_ids = eval_preds
    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        metrics_dct['bleu-4'].append(
            sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
    return {k: np.mean(v) for k, v in metrics_dct.items()}

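
# Illustrative sketch (not part of the original finetune.py): the same
# jieba-segment-then-score pipeline applied to a toy prediction/reference pair.
_pred_tokens = list(jieba.cut('今天天气很好'))
_label_tokens = list(jieba.cut('今天天气不错'))
_rouge_1_f = Rouge().get_scores(' '.join(_pred_tokens), ' '.join(_label_tokens))[0]['rouge-1']['f']
_bleu_4 = sentence_bleu([_label_tokens], _pred_tokens, smoothing_function=SmoothingFunction().method3)
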
@app.command()
def main(
        data_dir: Annotated[str, typer.Argument(help='Directory containing the dataset files')],
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        config_file: Annotated[str, typer.Argument(help='Path to the fine-tuning YAML configuration file')],
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If "yes", automatically resume from the latest saved checkpoint. If a number such as 12 or 15, resume from the checkpoint with that number. If empty or "no", start training from scratch.'
        ),
):
    # 0. CLI arguments: data directory, model directory or pretrained model id,
    #    config-file path, and whether to auto-resume from the latest checkpoint.
    # 1. Load the fine-tuning config, model, tokenizer, and data manager.
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    # 2. Build the datasets: fetch the train/validation/test splits and process
    #    them with process_batch and process_batch_eval respectively.
    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)
    # --- Added code: inspect the first few training samples ---
    if train_dataset is not None:
        print("Training dataset samples:")
        for i in range(3):
            print(tokenizer.decode(train_dataset[i]['input_ids'], skip_special_tokens=True))

    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    # 3. Model training.
    # 3.1 Enable gradient checkpointing and require gradients on the inputs
    #     (needed so checkpointed activations still receive gradients).
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    # 3.2 Create the Seq2SeqTrainer used for training and evaluating the
    #     sequence-to-sequence model, with support for resuming from checkpoints.
    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,  # training arguments loaded from the config file: learning rate, epochs, batch size, etc.
        data_collator=DataCollatorForSeq2Seq(  # collates examples into padded batches of tensors
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))) if val_dataset is not None else None,  # evaluate on the first 50 validation examples (guard added: the original crashes when no validation set exists)
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),  # how evaluation metrics are computed
    )

    # 3.3 Train: auto_resume_from_checkpoint decides whether to resume from a
    #     checkpoint or start training from scratch.
    # Decide whether training should resume from a checkpoint.
    # (Note: the original checks .upper() before the None test, which would raise
    # on None; the None check must come first.)
    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
        trainer.train()
    else:
        # Resume training from a checkpoint.
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        # Scan the output directory for valid checkpoints (names containing
        # "checkpoint" but not "tmp") and record the highest checkpoint number.
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        # Choose the resume strategy based on auto_resume_from_checkpoint.
        if auto_resume_from_checkpoint.upper() == "YES":
            # "YES": resume from the latest checkpoint, if one was found.
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            # A positive integer: resume from that specific checkpoint.
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            # Neither "YES" nor a positive integer: report that the requested
            # checkpoint was not saved and ask the user to pick a valid one.
            else:
                print(auto_resume_from_checkpoint,
                      "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")

    # 4. Model evaluation: if a test set exists, run prediction on it and report
    #    ROUGE and BLEU scores.
    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == '__main__':
    app()
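
# Example invocations, following the GLM-4 finetune_demo README (the config file
# name is illustrative):
#   python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
#   python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes  # resume from latest checkpoint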