当前位置:   article > 正文

基于 P-Tuning的高效微调ChatGLM2-6B_chatglm2-6b ptuning微调原理

chatglm2-6b ptuning微调原理

1 ChatGLM2-6B介绍

ChatGLM是清华技术成果转化的公司智谱AI研发的支持中英双语的对话机器人。ChatGLM基于GLM130B千亿基础模型训练,它具备多领域知识、代码能力、常识推理及运用能力;支持与用户通过自然语言对话进行交互,处理多种自然语言任务。比如:对话聊天、智能问答、创作文章、创作剧本、事件抽取、生成代码等等。

代码地址:https://github.com/THUDM/ChatGLM2-6B

ChatGLM2-6B是第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,又增加许多新特性:

  • 更强大的性能:基于ChatGLM初代模型的开发经验,全面升级了ChatGLM2-6B的基座模型。ChatGLM2-6B使用了GLM的混合目标函数,经过了1.4T中英标识符的预训练与人类偏好对齐训练。评测结果显示,与初代模型相比,ChatGLM2-6B在MMLU(+23%)、CEval(+33%)、GSM8K(+571%) 、BBH(+60%)等数据集上的性能取得了大幅度的提升,在同尺寸开源模型中具有较强的竞争力。

  • 更长的上下文:基于 FlashAttention 技术,研究人员将基座模型的上下文长度由 ChatGLM-6B 的2K扩展到了32K,并在对话阶段使用8K的上下文长度训练,允许更多轮次的对话。但当前版本的ChatGLM2-6B对单轮超长文档的理解能力有限,会在后续迭代升级中着重进行优化。

  • 更高效的推理:基于 Multi-Query Attention 技术,ChatGLM2-6B有更高效的推理速度和更低的显存占用。在官方的模型实现下,推理速度相比初代提升了42%,INT4量化下,6G显存支持的对话长度由1K提升到了8K。

  • 更开放的协议:ChatGLM2-6B权重对学术研究完全开放,在获得官方的书面许可后,亦允许商业使用。

相比于初代模型,ChatGLM2-6B在数理逻辑、知识推理、长文档理解等多个维度的能力上,都取得了巨大的提升。

2 P-Tuning V2介绍

论文地址:https://arxiv.org/pdf/2110.07602.pdf

代码地址:https://github.com/THUDM/P-tuning-v2

2.1 背景

之前的Prompt Tuning和P-Tuning等方法存在两个主要的问题:

第一,缺乏模型参数规模和任务通用性。

  • 缺乏规模通用性:Prompt Tuning论文中表明当模型规模超过100亿个参数时,提示优化可以与全量微调相媲美。但是对于那些较小的模型(从100M到1B),提示优化和全量微调的表现有很大差异,这大大限制了提示优化的适用性。

  • 缺乏任务普遍性:尽管Prompt Tuning和P-tuning在一些 NLU 基准测试中表现出优势,但提示调优对硬序列标记任务(即序列标注)的有效性尚未得到验证。

第二,缺少深度提示优化,在Prompt Tuning和P-tuning中,连续提示只被插入transformer第一层的输入embedding序列中,在接下来的transformer层中,插入连续提示的位置的embedding是由之前的transformer层计算出来的,这可能导致两个可能的优化挑战。

  • 由于序列长度的限制,可调参数的数量是有限的。
  • 输入embedding对模型预测只有相对间接的影响。

考虑到这些问题,作者提出了Ptuning v2,它利用深度提示优化(如:Prefix Tuning),对Prompt Tuning和P-Tuning进行改进,作为一个跨规模和NLU任务的通用解决方案。

P-Tuning v2是对prefix-tuning和p-tuning进行的优化。prefix-tuning等存在一些问题:

  • 是针对于生成任务而言的,不能处理困难的序列标注任务、抽取式问答等,缺乏普遍性。【解决方法,分类还是使用CLS或者token。】
  • 当模型规模较小,特别是小于100亿个参数时,它仍然不如微调法。【解决方法:在每一层都加上prompt。】

2.2 技术原理

P-Tuning v2是一个用于改进预训练语言模型(Pre-trained Language Model,PLM)偏见的方法。其原理可以总结如下:

  • 样本选择:首先,从一个大规模的文本语料库中选择一部分样本作为训练集。这些样本应当具有多样性,包括不同的文化、背景和价值观。

  • PLM预训练:在选定的样本上进行预训练,生成一个初始的PLM模型。这个模型包含了很多词语的上下文信息,以及它们之间的关联性。

  • 特征定义:定义一个特征函数,用来指示某个词对于特定偏见的敏感程度。这些特征函数可以包括词语的含义、出现上下文的频次等等。

  • 偏见调整:通过修改样本中某些词语的上下文,缩小其与偏见之间的相关性。具体来说,对于某个特定的偏见,可以通过修改相关的样本以降低这个偏见对PLM的影响。

  • 约束优化:为了控制偏见调整的程度,引入一个约束函数来度量样本的平衡性。这个约束函数可以包括不同群体的样本分布、词语的多样性等等。

  • 迭代训练:在约束优化的框架下,反复调整样本和PLM模型,直到达到平衡的状态。这样可以在尽量保持语言模型质量的同时,尽量减小偏见。

P-Tuning v2的原理是通过定义特征函数和约束函数,以及调整样本和PLM模型的方法,来优化预训练语言模型的偏见问题。这个方法可以用于改进PLM的性别、种族、政治观点等各种偏见。

2.3 P-Tuning v2的优点

  • P-tuning v2在不同的模型规模(从300M到100B的参数)和各种困难的NLU任务(如问答和序列标注)上的表现与微调相匹配。

  • 与微调相比,P-tuning v2每个任务的可训练参数为0.1%到3%,这大大降低了训练时间的内存消耗和每个任务的存储成本

4 基于P-Tuning微调ChatGLM2-6B

4.1 ChatGLM2-6B部署

部署文档详见:https://mp.csdn.net/mp_blog/creation/editor/135084490

4.2 数据集下载

ADGEN 数据集为根据输入(content)生成一段广告词(summary)。官方给出的数据处理格式:

  1. {
  2. "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
  3. "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
  4. }

数据格式是{“content”:"","summary":""}形式;如果是问答类数据,content就是问题,summary就是回答。 

数据下载地址:https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1

下载完成后,将解压后的 AdvertiseGen 目录放到ptuning本目录下,如下所示:

  1. (chatglm) [root@localhost ChatGLM2-6B]# ll ptuning/
  2. 总用量 96
  3. drwxr-xr-x 2 root root 40 1218 20:17 AdvertiseGen
  4. -rw-r--r-- 1 root root 8474 1218 20:17 arguments.py
  5. -rw-r--r-- 1 root root 489 1218 20:17 deepspeed.json
  6. -rw-r--r-- 1 root root 768 1218 20:17 ds_train_finetune.sh
  7. -rw-r--r-- 1 root root 603 1218 20:17 evaluate_finetune.sh
  8. -rw-r--r-- 1 root root 702 1218 20:17 evaluate.sh
  9. -rwxr-xr-x 1 root root 17804 1218 20:17 main.py
  10. drwxr-xr-x 2 root root 106 1218 20:18 __pycache__
  11. -rw-r--r-- 1 root root 9567 1218 20:17 README.md
  12. -rw-r--r-- 1 root root 823 1218 20:17 train_chat.sh
  13. -rw-r--r-- 1 root root 3155 1218 20:17 trainer.py
  14. -rw-r--r-- 1 root root 11508 1218 20:17 trainer_seq2seq.py
  15. -rw-r--r-- 1 root root 833 1218 20:17 train.sh
  16. -rw-r--r-- 1 root root 6014 1218 20:17 web_demo.py
  17. -rw-r--r-- 1 root root 219 1218 20:17 web_demo.sh

4.3 环境构建

  1. pip install rouge_chinese nltk jieba datasets
  2. pip install transformers==4.30.2

4.4 启动P-Tuning微调

仓库代码实现了对于 ChatGLM2-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

代码地址:https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

更改训练脚本:

vi ptuning/train.sh 
  1. PRE_SEQ_LEN=128
  2. LR=2e-2
  3. NUM_GPUS=1
  4. torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
  5. --do_train \
  6. --train_file AdvertiseGen/train.json \
  7. --validation_file AdvertiseGen/dev.json \
  8. --preprocessing_num_workers 10 \
  9. --prompt_column content \
  10. --response_column summary \
  11. --overwrite_cache \
  12. --model_name_or_path ../THUDM/chatglm2-6b \
  13. --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
  14. --overwrite_output_dir \
  15. --max_source_length 64 \
  16. --max_target_length 128 \
  17. --per_device_train_batch_size 1 \
  18. --per_device_eval_batch_size 1 \
  19. --gradient_accumulation_steps 16 \
  20. --predict_with_generate \
  21. --max_steps 3000 \
  22. --logging_steps 10 \
  23. --save_steps 1000 \
  24. --learning_rate $LR \
  25. --pre_seq_len $PRE_SEQ_LEN \
  26. --quantization_bit 4
  • train.sh 中的 PRE_SEQ_LENLR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

  • 在默认配置 quantization_bit=4per_device_train_batch_size=1gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

  • 为本地加载模型,将 train.sh 中的 THUDM/chatglm2-6b 改为你本地的模型路径。

通过bash命令启动训练:

  1. cd ptuning/
  2. bash train.sh

查看GPU使用:

  1. (chatglm) [root@localhost ptuning]# nvidia-smi
  2. Thu Jan 4 14:16:59 2024
  3. +-----------------------------------------------------------------------------+
  4. | NVIDIA-SMI 525.85.05 Driver Version: 525.85.05 CUDA Version: 12.0 |
  5. |-------------------------------+----------------------+----------------------+
  6. | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
  7. | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
  8. | | | MIG M. |
  9. |===============================+======================+======================|
  10. | 0 Tesla V100S-PCI... Off | 00000000:0B:00.0 Off | 0 |
  11. | N/A 75C P0 245W / 250W | 7830MiB / 32768MiB | 97% Default |
  12. | | | N/A |
  13. +-------------------------------+----------------------+----------------------+
  14. +-----------------------------------------------------------------------------+
  15. | Processes: |
  16. | GPU GI CI PID Type Process name GPU Memory |
  17. | ID ID Usage |
  18. |=============================================================================|
  19. | 0 N/A N/A 16098 C ...nvs/chatglm/bin/python3.9 7826MiB |
  20. +-----------------------------------------------------------------------------+

4.5 P-Tuning微调结果测试

修改ptuning/web_demo.py

vi ptuning/web_demo.py 
  1. demo.queue().launch(share=False, inbrowser=True)
  2. 修改为:
  3. demo.queue().launch(share=False, server_name='0.0.0.0', inbrowser=True)

修改ptuning/web_demo.sh

vi ptuning/web_demo.sh 
  1. PRE_SEQ_LEN=128
  2. CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
  3. --model_name_or_path THUDM/chatglm2-6b \
  4. --ptuning_checkpoint output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000 \
  5. --pre_seq_len $PRE_SEQ_LEN
  6. 修改为:
  7. PRE_SEQ_LEN=128
  8. CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
  9. --model_name_or_path ../THUDM/chatglm2-6b \
  10. --ptuning_checkpoint output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000 \
  11. --pre_seq_len $PRE_SEQ_LEN

 运行微调后的模型:

  1. cd ptuning/
  2. bash web_demo.sh

 

测试1:

  1. input:类型#上衣材质#牛仔布颜色#白色风格#简约图案#刺绣衣样式#外套衣款式#破洞
  2. Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
  3. 微调前输出:这件上衣由牛仔布材质制成,采用了简约风格,图案为刺绣设计,衣样式为外套,衣款式为破洞。
  4. 微调后输出:白色牛仔外套,简约大方的款式,白色,搭配经典牛仔色,更加大气。领口处白色的刺绣,浪漫精致。下摆破洞,独特个性。

测试2:

  1. input:类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领
  2. Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
  3. 微调前输出:这件上衣由牛仔布材质制成,采用了简约风格,图案为刺绣设计,衣样式为外套,衣款式为破洞。
  4. 微调后输出:这一款连衣裙,简约的白色系,搭配上撞色的印染,带来一种时髦的文艺风。修身的版型,勾勒出曼妙的身材曲线,显得高挑。

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/613803
推荐阅读
相关标签
  

闽ICP备14008679号