赞
踩
接上一篇:
云服务器部署开源ChatGLM-6B,让你也能拥有自己的ChatGPThttps://gblfy.blog.csdn.net/article/details/130682359
声明:基于ChatGLM-6B模型微调,因此大家请参考上一篇完成ChatGLM-6B模型
官网链接:https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning
cd ChatGLM-6B/ptuning/
pip install transformers==4.27.1 rouge_chinese nltk jieba datasets -i https://pypi.tuna.tsinghua.edu.cn/simple some-package
从 Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen 目录放到本目录下。
wget https://cloud.tsinghua.edu.cn/seafhttp/files/c471a957-72fe-49d7-94fd-3a847c7990e6/AdvertiseGen.tar.gz
tar -zxvf AdvertiseGen.tar.gz
vim train.sh
调整如下配置
train_file 默认
validation_file 默认
model_name_or_path 模型路径
output_dir 输出路径
PRE_SEQ_LEN=128 LR=2e-2 CUDA_VISIBLE_DEVICES=0 python3 main.py \ --do_train \ --train_file ./AdvertiseGen/train.json \ --validation_file ./AdvertiseGen/dev.json \ --prompt_column content \ --response_column summary \ --overwrite_cache \ --model_name_or_path /home/user/imported_models/model/chatglm-6b \ --output_dir ./ouput/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ --overwrite_output_dir \ --max_source_length 64 \ --max_target_length 64 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 16 \ --predict_with_generate \ --max_steps 3000 \ --logging_steps 10 \ --save_steps 1000 \ --learning_rate $LR \ --pre_seq_len $PRE_SEQ_LEN \ --quantization_bit 4
运行以下指令进行训练:
bash train.sh
训练完成后会生成checkpoint-1000、checkpoint-2000、checkpoint-3000三个文件夹
vim evaluate.sh
调整配置
validation_file 默认
test_file
model_name_or_path 模型路径
ptuning_checkpoint 训练后生成的checkpoint-3000
PRE_SEQ_LEN=128 CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2 STEP=3000 CUDA_VISIBLE_DEVICES=0 python3 main.py \ --do_predict \ --validation_file ./AdvertiseGen/dev.json \ --test_file ./AdvertiseGen/dev.json \ --overwrite_cache \ --prompt_column content \ --response_column summary \ --model_name_or_path /home/user/imported_models/model/chatglm-6b \ --ptuning_checkpoint ./ouput/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \ --output_dir ./output/$CHECKPOINT \ --overwrite_output_dir \ --max_source_length 64 \ --max_target_length 64 \ --per_device_eval_batch_size 1 \ --predict_with_generate \ --pre_seq_len $PRE_SEQ_LEN \ --quantization_bit 4
bash evaluate.sh
推理完成后会生成-评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 ./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt。
修改web_demo.sh
vim web_demo.sh
PRE_SEQ_LEN=128
CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
--model_name_or_path /home/user/imported_models/model/chatglm-6b \
--ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
--pre_seq_len $PRE_SEQ_LEN
bash web_demo.sh
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。