赞
踩
目录
LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。
ModelArguments - 模型参数
DataArguments - 数据集参数
TrainingArguments - 训练参数
GeneratingArguments - 生成参数
- @dataclass
- class ModelArguments:
- model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:
参数名称 | 默认 | 类型 | 含义 |
model_name_or_path | None | str | 模型地址或名称 |
cache_dir | None | str | 缓存地址 |
use_fast_tokenizer | False | bool | 使用快速 tokenizer |
padding_side | left | str | 模型 pad 选择 |
quantization_bit | None | int | 量化 bit 选择 |
compute_type | None | torch.dtype | 模型参数类型 |
checkpoint_dir | None | str | 微调参数地址 |
mode | None | str | reward、lora |
plot_loss | False | bool | 打印训练 Loss |
- @dataclass
- class DataArguments:
- data_path: str = field(
- default=None, metadata={"help": "Path to the training data."}
- )
DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:
参数名称 | 默认 | 类型 | 含义 |
data_path | None | str | 数据集地址 |
process_num | None | int | 并行处理 |
max_source_length | 512 | int | source 最大长度 |
max_target_length | 512 | int | target 最大长度 |
max_samples | None | int | 最大样本数 |
ignore_pad_token | None | int | loss 计算是否忽略 |
prompt_template | None | str | 样本生成 prompt 模板 |
- @dataclass
- class TrainingArguments(transformers.TrainingArguments):
- cache_dir: Optional[str] = field(default=None)
- optim: str = field(default="adamw_torch")
- model_max_length: int = field(
- default=512,
- metadata={
- "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
- },
- )
- use_lora: bool = field(default=False)
- output_dir: str = field(default="")
TrainingArguments 主要存储模型微调,训练相关的参数:
参数名称 | 默认 | 类型 | 含义 |
finetuning_type | lora | str | 微调类型 |
lora_target | q_proj,v_proj | str | 微调 Layer |
lora_rank | 8 | int | lora 降维维度 |
lora_alpha | 32.0 | float | lora 微调比例因子 |
lora_dropout | 0.1 | float | dropout 比例 |
num_hidden_layers | 32 | int | Decode 数量 |
num_layer_trainable | 3 | int | freeze layer 数量 |
name_module_trainable | mlp | str | freeze 训练层选择 |
output_dir | None | str | 模型输出地址 |
- @dataclass
- class GeneratingArguments:
- do_sample: Optional[bool] = field(
- default=True,
- metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
- )
GeneratingArguments 主要负责 model generate 生成的配置:
参数名称 | 默认 | 类型 | 含义 |
do_sample | True | bool | 采样或贪心 |
temperature | 0.95 | float | 调整下一个 token 的概率 |
top_p | 0.7 | float | token 概率 top 区间 |
top_k | 50 | int | token 词库数量 |
num_beams | 1 | int | beam search 数量 |
max_length | None | int | 最大生成 token 数 |
max_new_tokens | 512 | int | 最多新 toekn 生成数 |
repatition_penalty | 1.0 | float | 重复惩罚 |
length_penalty | 1.0 | float | 长度惩罚 |
之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本
- from typing import Optional
- from dataclasses import dataclass, field
- import transformers
-
-
- ...
-
- 添加上述的 Argument Class
-
- ...
-
-
- if __name__ == '__main__':
- parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
- model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()
-
- print(model_args)
- print(data_args)
- print(training_args)
- print(generate_args)
两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。
- #!/bin/bash
-
- python GetConfigByArgs.py \
- --report_to "none" \
- --data_path "data/belle_chat_ramdon_10k.json" \
- --model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
- --output_dir "output" \
- --model_max_length 512 \
- --num_train_epochs 4 \
- --per_device_train_batch_size 16 \
- --gradient_accumulation_steps 1 \
- --save_strategy epoch \
- --learning_rate 2e-5 \
- --lr_scheduler_type constant \
- --adam_beta1 0.9 \
- --adam_beta2 0.98 \
- --adam_epsilon 1e-8 \
- --max_grad_norm 1.0 \
- --weight_decay 1e-4 \
- --warmup_ratio 0.0 \
- --logging_steps 1 \
- --gradient_checkpointing True \
- --deepspeed ds_config.json \
- --bf16 False \
- --tf32 False
通过 -- 传递我们需要的参数即可。
这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。