当前位置:   article > 正文

ChatGLM-6B模型训练自己的数据集_chatglm-6b训练自己数据

chatglm-6b训练自己数据

ChatGLM-6B模型训练自己的数据集

上期我主要分享了一下ChatGLM-6B官方模型的部署、官方数据集的微调、推理以及测试过程,这期我将主要分享一下使用ChatGLM-6B微调自己数据集的过程。上期链接

1.首先将自己处理好的数据集拷贝到’ChatGLM-6B/ptuning/’文件夹下,可以新建一个自己的数据集文件夹如mydata。

我新建的文件夹mydata

2.首先要修改train.sh中的参数,官方train.sh文档:

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

参数解释如下(自己的理解,不正确欢迎指正):

1. PRE_SEQ_LEN:这个是输入的最大序列长度,可以根据你的数据集适当调大或调小,一般64-128是一个合理范围。
2. LR:学习率,可以适当调小,如1e-3 - 5e-3,因为这是微调,学习率不需要太大。
3. CUDA_VISIBLE_DEVICES:设置使用的GPU设备。
4. --train_file 和 --validation_file:设置你自己的数据集路径。
5. --model_name_or_path:设置为 chatglm-6b 的模型路径。
6. --output_dir:设置微调后的模型和日志输出路径。
7.- max_source_length 指定我们输入对话文本(即软提示)的最大长度。如果某个输入文本超过这个长度,则会截断;如果短于这个长度,则会在末尾填充padding。
- max_target_length 指定模型输出的响应文本的最大长度。如果模型生成的响应超过这个长度,则会截断;如果短于这个长度,则会在末尾填充padding。
-所以,这两个参数的设置值需要根据你的硬件情况和数据特点来确定。一般来说,64-128对于输入,32-64对于输出是比较合理的范围。你可以在开发集上进行测试,观察模型输出的响应结果和计算开销来选择最优值。
设置这两个参数的目的是:
(1) 使所有训练实例输入和输出的长度统一,以方便模型的训练和batch的构造。
(2) 避免过长的输入和输出导致的计算开销大和内存超限的问题。
(3) 鼓励模型学习如何在约束长度内生成更精准和相关的响应。
7. --gradient_accumulation_steps:设置梯度累积步数,可以适当增大,如8-32,以便使用更大的batch size。
8. --per_device_train_batch_size 和 --per_device_eval_batch_size:可以适当增大,以加快训练速度,如4-8。
9. --max_steps:设置最大训练步数,根据数据集大小适当调整,一般3000-10000步是一个合理范围。
11. --logging_steps 和 --save_steps:设置日志记录步数和模型保存步数,可以根据需要自己更改。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
12.--quantization_bit:模型参数的精度选择。
这个参数是关于 P-Tuning 方法中的模型量化(model quantization)方法的选择。
P-Tuning 方法中的量化是指将模型参数从浮点数(FP32)精度量化为低精度,如FP16(半精度)或更低。这可以大大减小模型的参数量,加速推理速度,同时牺牲一定的精度。
在这句话中,作者指出:
1. P-Tuning 方法会冻结(freeze)预训练模型(如GPT2)的全部参数,以此作为量化的起点。
2. 通过调整quantization_bit 参数可以选择不同的量化级别。不设置此参数则默认使用FP16半精度。
3. 调整量化级别会影响模型精度,需要在验证集上测试不同的量化级别,选择精度损失最小的配置。
也就是说,P-Tuning 方法采用冻结预训练模型,然后对其进行量化和微调的策略。设置quantization_bit参数可以选择FP16之外的更低精度,如8bit或4bit,以获得更高的加速比,同时尽可能保留精度。
量化会使模型参数变得更加稀疏,进而可以大幅压缩模型体积和加速推理。但同时也会损失一定精度。所以需要根据实际情况选择一个合理的trade-off。
总的来说,这句话的意思就是说明P-Tuning方法采用了模型冻结和可调量化级别来实现从FP32到低精度的转换,以获得较高的加速比。quantization_bit参数可以选择不同的量化精度,无此参数则默认为FP16。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

3.我的train.sh参数设置

PRE_SEQ_LEN=128
LR=2e-3 #因为是微调所以我考虑减小了学习率

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file mydata/train.json \  #我的训练数据集地址
    --validation_file mydata/test.json \  #我的测试数据集地址
    --prompt_column question\  #我的数据集标签
    --response_column answer\  #我的数据集标签
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir myoutput/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \   #输入句子的最大目标长度
    --max_target_length 512\   #输出句子的最大目标长度
    --per_device_train_batch_size 4 \  #由于我的内存比较大,我增大了batch_size
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \   
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 50 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

这里我去掉了 --quantization_bit 4,考虑到我有足够的内存,我采用了精度较高的FP16

4.推理

推理参数调整参考微调。

5.利用微调后的模型进行验证

参考上期

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/358036
推荐阅读
相关标签
  

闽ICP备14008679号