当前位置:   article > 正文

LLM - Model、Data、Training、Generate Agruments 超参解析_modelargs

modelargs

目录

一.引言

二.常用参数

◆ ModelArguments

◆ DataArguments

◆ TrainingArguments

◆ GeneratingArguments

三.代码实现

◆ Python 代码

◆ Shell 代码

四.总结


一.引言

LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。

ModelArguments - 模型参数

DataArguments - 数据集参数

TrainingArguments - 训练参数

GeneratingArguments - 生成参数

二.常用参数

◆ ModelArguments

  1. @dataclass
  2. class ModelArguments:
  3. model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:

参数名称默认类型含义
model_name_or_pathNonestr模型地址或名称
cache_dirNonestr缓存地址
use_fast_tokenizerFalsebool使用快速 tokenizer
padding_sideleftstr模型 pad 选择
quantization_bitNoneint量化 bit 选择
compute_typeNonetorch.dtype模型参数类型
checkpoint_dirNonestr微调参数地址
modeNonestrreward、lora
plot_lossFalsebool打印训练 Loss

◆ DataArguments

  1. @dataclass
  2. class DataArguments:
  3. data_path: str = field(
  4. default=None, metadata={"help": "Path to the training data."}
  5. )

DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:

参数名称默认类型含义
data_pathNonestr数据集地址
process_numNoneint并行处理
max_source_length512intsource 最大长度
max_target_length512inttarget 最大长度
max_samplesNoneint最大样本数
ignore_pad_tokenNoneintloss 计算是否忽略
prompt_templateNonestr样本生成 prompt 模板

◆ TrainingArguments

  1. @dataclass
  2. class TrainingArguments(transformers.TrainingArguments):
  3. cache_dir: Optional[str] = field(default=None)
  4. optim: str = field(default="adamw_torch")
  5. model_max_length: int = field(
  6. default=512,
  7. metadata={
  8. "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
  9. },
  10. )
  11. use_lora: bool = field(default=False)
  12. output_dir: str = field(default="")

TrainingArguments 主要存储模型微调,训练相关的参数:

参数名称默认类型含义
finetuning_typelorastr微调类型
lora_targetq_proj,v_projstr微调 Layer
lora_rank8intlora 降维维度
lora_alpha32.0floatlora 微调比例因子
lora_dropout0.1floatdropout 比例
num_hidden_layers32intDecode 数量
num_layer_trainable3intfreeze layer 数量
name_module_trainablemlpstrfreeze 训练层选择
output_dirNonestr模型输出地址

◆ GeneratingArguments

  1. @dataclass
  2. class GeneratingArguments:
  3. do_sample: Optional[bool] = field(
  4. default=True,
  5. metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
  6. )

GeneratingArguments 主要负责 model generate 生成的配置:

参数名称默认类型含义
do_sampleTruebool采样或贪心
temperature0.95float调整下一个 token 的概率
top_p0.7floattoken 概率 top 区间
top_k50inttoken 词库数量
num_beams1intbeam search 数量
max_lengthNoneint最大生成 token 数
max_new_tokens512int最多新 toekn 生成数
repatition_penalty1.0float重复惩罚
length_penalty1.0float长度惩罚

之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本

三.代码实现

◆ Python 代码

  1. from typing import Optional
  2. from dataclasses import dataclass, field
  3. import transformers
  4. ...
  5. 添加上述的 Argument Class
  6. ...
  7. if __name__ == '__main__':
  8. parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
  9. model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()
  10. print(model_args)
  11. print(data_args)
  12. print(training_args)
  13. print(generate_args)

两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。

Shell 代码

  1. #!/bin/bash
  2. python GetConfigByArgs.py \
  3. --report_to "none" \
  4. --data_path "data/belle_chat_ramdon_10k.json" \
  5. --model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
  6. --output_dir "output" \
  7. --model_max_length 512 \
  8. --num_train_epochs 4 \
  9. --per_device_train_batch_size 16 \
  10. --gradient_accumulation_steps 1 \
  11. --save_strategy epoch \
  12. --learning_rate 2e-5 \
  13. --lr_scheduler_type constant \
  14. --adam_beta1 0.9 \
  15. --adam_beta2 0.98 \
  16. --adam_epsilon 1e-8 \
  17. --max_grad_norm 1.0 \
  18. --weight_decay 1e-4 \
  19. --warmup_ratio 0.0 \
  20. --logging_steps 1 \
  21. --gradient_checkpointing True \
  22. --deepspeed ds_config.json \
  23. --bf16 False \
  24. --tf32 False

通过 -- 传递我们需要的参数即可。

四.总结

这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/601561
推荐阅读
相关标签
  

闽ICP备14008679号