当前位置:   article > 正文

为ChatGLM-6B模型的训练纪实:从数据集准备到LLamA-Factory的高效应用(二)_llama-factory 数据集

llama-factory 数据集

阶段成果展示

一、LLamA-Factory简介

LLamA-Factory 是一个高效的模型训练工具,支持多种大规模语言模型(如LLaMA、ChatGLM等)的微调。该工具集成了增量预训练、指令监督微调、奖励模型训练等多种方法,支持多种精度和先进算法(如LoRA、QLoRA)。LLamA-Factory 提供丰富的实验监控工具(如TensorBoard、Wandb),并优化了训练和推理速度。通过结合多种模型和训练方法,LLamA-Factory 能显著提升训练效率和模型性能。详细内容请访问 LLamA-Factory (GitHub)

二、环境配置与准备[^1]

  • 硬件要求:内存32G以上(最低)、GPU 1*NVIDIA V100、显存24G以上(最低)
  1. 创建DSW,模型和LLamA-Factory安装过程详见脚注文章在这里插入图片描述

  2. 启动LLamA-Factory web界面(命令可能因版本不同而异,以下为v0.6.0版本)

     cd LLaMA-Factory
     CUDA_VISIBLE_DEVICE=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
    
    • 1
    • 2

    点击即可运行在这里插入图片描述
    在这里插入图片描述

三、训练过程

1. 超参数的设置(针对本主机24G显存进行优化)

在使用LLamA-Factory训练ChatGLM-6B时,可以调节以下关键超参数:

  1. 学习率(Learning Rate)

    • 参数值:5e-5
    • 建议值:5e-5(保持不变)
    • 描述:AdamW优化器的初始学习率,决定模型更新权重的速度。默认学习率通常已较为合适。
  2. 训练轮数(Epochs)

    • 参数值:3.0
    • 描述:需要执行的训练总轮数,定义整个数据集被模型看过的次数。
  3. 最大梯度范数(Max Gradient Norm)

    • 参数值:1.0
    • 建议值:1.0(保持不变)
    • 描述:用于梯度裁剪的范数,防止梯度爆炸。
  4. 最大样本数(Max Samples)

    • 参数值:100000
    • 描述:每个数据集的最大样本数,限制训练数据量。
  5. 批处理大小(Batch Size)

    • 参数值:2
    • 建议值:8
    • 描述:每个GPU处理的样本数量。24G显存较大,可以适当增加批处理大小以充分利用显存资源。
  6. 梯度累积(Gradient Accumulation Steps)

    • 参数值:8
    • 建议值:2
    • 描述:梯度累积的步数,用于模拟更大的批处理大小。增加批处理大小后,可以减少梯度累积的步数,提高训练速度。
  7. 计算类型(Precision Type)

    • 参数值:fp16
    • 建议值:fp16(保持不变)
    • 描述:是否使用混合精度训练,fp16可加快训练速度并减少显存占用。
  8. 截断长度(Max Sequence Length)

    • 参数值:1024
    • 建议值:1024(保持不变)
    • 描述:输入序列分词后的最大长度,一般不能低于一个条目最低token数。输入序列的长度对显存的占用较大,1024较为合理。
      在这里插入图片描述

2. 训练策略

  1. 训练阶段(Training Phase)
    • 当前方式:Supervised Fine-Tuning(监督微调)
    • 描述:基于标注数据进行监督学习,提高模型在特定任务上的表现。
  2. 验证集比例(Validation Split)
    • 参数值:0
    • 建议值:0.1
    • 描述:验证集用于调整模型的超参数,以及监测模型是否过拟合或者欠拟合。它在训练过程中没有直接参与参数的更新。
  3. 学习率调节器(Learning Rate Scheduler)
    • 参数值:cosine
    • 建议值:cosine(保持不变)
    • 描述:学习率调度器的名称,使用cosine调度器动态调整学习率。

上述设置完成后选取数据集进行训练在这里插入图片描述

3.训练过程:

在这里插入图片描述在这里插入图片描述

4. 训练监控与调优

  1. 监控训练过程:实时监控训练过程中各项指标(如损失值、准确率等)。这类指标在LLamA-Factory中一站式提供。
  2. 模型保存与恢复:定期保存模型,支持从断点继续训练,确保训练过程的连续性和安全性。这样可以防止由于中断(如断电或系统故障)导致的训练进度丢失。此外,还可以在早停策略触发时恢复到最优模型。
  3. 早停策略:根据验证集的性能动态调整训练进度,避免过拟合,确保模型在实际应用中的表现。
    在这里插入图片描述

结合损失函数图,从图中可以看到损失值在训练过程中逐步下降,且趋于平稳。这表明模型在逐渐收敛。监控训练和验证集的损失值可以帮助判断训练进度和调整策略。例如,如果在多个epoch内验证集损失不再下降,则可以触发早停策略,停止训练以避免过拟合

在这里插入图片描述

从上图中的损失函数来看,可以观察到:1. 波动性:损失函数的值有显著的波动。这可能表明模型在训练过程中存在一定的不稳定性,可能是由于数据噪声或者模型参数调整引起的。2. 原始与平滑值:图中展示了原始损失函数值和经过平滑处理后的损失函数值。平滑处理后的曲线波动较小,更能反映出损失函数的总体趋势。这种处理有助于更好地观察损失函数的收敛情况。3. 趋势分析:从平滑后的曲线来看,损失函数并没有明显的单调下降趋势,这可能意味着模型尚未收敛到一个较优的状态。这种情况可能需要调整学习率、模型结构或者进一步优化训练数据。

一般此类情况出现在训练的刚开始阶段(因为横坐标步数很少的情况下纵坐标的改变微乎其微),此时可以继续观察损失函数是否呈下降趋势(如上图);否则可以及时采取早停策略。

[^ 1 ]参考组员文章:ChatGlm3-6B的部署及微调流程

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/834066
推荐阅读
相关标签
  

闽ICP备14008679号