当前位置:   article > 正文

ChatGLM-6B 模型训练 P-Tuning 微调实战_chatglm-6b 微调 项目 github 数据

chatglm-6b 微调 项目 github 数据

接上一篇:
云服务器部署开源ChatGLM-6B,让你也能拥有自己的ChatGPThttps://gblfy.blog.csdn.net/article/details/130682359

声明:基于ChatGLM-6B模型微调,因此大家请参考上一篇完成ChatGLM-6B模型

官网链接:https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning
在这里插入图片描述

一、软件依赖
cd ChatGLM-6B/ptuning/
pip install  transformers==4.27.1 rouge_chinese nltk jieba datasets -i https://pypi.tuna.tsinghua.edu.cn/simple some-package
  • 1
  • 2

Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen 目录放到本目录下。

二、下载数据集
wget https://cloud.tsinghua.edu.cn/seafhttp/files/c471a957-72fe-49d7-94fd-3a847c7990e6/AdvertiseGen.tar.gz

tar -zxvf AdvertiseGen.tar.gz
  • 1
  • 2
  • 3
三、模型训练
vim train.sh
  • 1

调整如下配置
train_file 默认
validation_file 默认
model_name_or_path 模型路径
output_dir 输出路径

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file ./AdvertiseGen/train.json \
    --validation_file ./AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/user/imported_models/model/chatglm-6b \
    --output_dir ./ouput/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

运行以下指令进行训练:

bash train.sh
  • 1

训练完成后会生成checkpoint-1000、checkpoint-2000、checkpoint-3000三个文件夹

四、模型推理
vim evaluate.sh
  • 1

调整配置
validation_file 默认
test_file
model_name_or_path 模型路径
ptuning_checkpoint 训练后生成的checkpoint-3000

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=3000

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_predict \
    --validation_file ./AdvertiseGen/dev.json \
    --test_file ./AdvertiseGen/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path /home/user/imported_models/model/chatglm-6b \
    --ptuning_checkpoint ./ouput/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
bash evaluate.sh
  • 1

推理完成后会生成-评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 ./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt。

五、推理验证

修改web_demo.sh

vim web_demo.sh
  • 1
PRE_SEQ_LEN=128

CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
    --model_name_or_path /home/user/imported_models/model/chatglm-6b \
    --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
    --pre_seq_len $PRE_SEQ_LEN
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
bash  web_demo.sh
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/804158
推荐阅读
相关标签
  

闽ICP备14008679号