In the `\2023S-T1-A-Data[供选手] 0513\指令数据&标签` directory you can find our training-set and test-set data and labels, three `.json` files in total. In a Jupyter Notebook environment, load the dataset and build the fine-tuning pairs as follows:

```python
import json
import random

# Dataset paths
LABEL_PATH = '/kaggle/input/web-auto-navigation-dataset/'
label_trainset = json.load(open(LABEL_PATH + 'label_trainset.json', encoding="utf-8"))
instruction_trainset = json.load(open(LABEL_PATH + 'instruction_trainset.json', encoding="utf-8"))
instruction_testA = json.load(open(LABEL_PATH + 'instruction_testA.json', encoding="utf-8"))

# Build prompt/response pairs from the first 20 entries of the training set
chat_data = []
for idx in range(20):
    for instructions in instruction_trainset[idx]['instruction_detail']:
        chat_data.append({
            'prompt': instructions['instruction'],
            'response': ';'.join([f'{key}' for key, value in instructions['key-value'].items()]),
            "history": []
        })

random.shuffle(chat_data)

# Hold out the last 400 shuffled samples as the validation set
with open('train.json', 'w') as up:
    for line in chat_data[:-400]:
        up.write(json.dumps(line) + '\n')
with open('dev.json', 'w') as up:
    for line in chat_data[-400:]:
        up.write(json.dumps(line) + '\n')
```
Under the `/kaggle/working/` directory, clone the GitHub repository with `git clone`:

```shell
!git clone https://github.com/THUDM/ChatGLM-6B.git
```
Install the required packages from `requirements.txt`:

```shell
!pip install -r /kaggle/working/ChatGLM-6B/requirements.txt
```
Next, change into the ptuning directory (`cd /kaggle/working/ChatGLM-6B/ptuning`) and write the `train_chat.sh` file there. For `--model_name_or_path` we train the 4-bit quantized version `THUDM/chatglm-6b-int4`. The full `THUDM/chatglm-6b` version runs out of GPU memory, because the full-precision model is downloaded first and only quantized afterwards. Overly large `PRE_SEQ_LEN`, `max_source_length`, and `max_target_length` values make training slow, so adjust them as appropriate:

```python
# Write the P-Tuning training script into the ptuning directory
train_script = '''CHAT_TRAIN_DATA=/kaggle/working/train.json
CHAT_VAL_DATA=/kaggle/working/dev.json
CHECKPOINT_NAME=/kaggle/working
PRE_SEQ_LEN=128
LR=1e-3
CUDA_VISIBLE_DEVICES=0 python3 main.py \\
    --do_train \\
    --train_file $CHAT_TRAIN_DATA \\
    --validation_file $CHAT_VAL_DATA \\
    --prompt_column prompt \\
    --response_column response \\
    --history_column history \\
    --overwrite_cache \\
    --model_name_or_path THUDM/chatglm-6b-int4 \\
    --output_dir $CHECKPOINT_NAME \\
    --overwrite_output_dir \\
    --max_source_length 256 \\
    --max_target_length 256 \\
    --per_device_train_batch_size 1 \\
    --per_device_eval_batch_size 1 \\
    --gradient_accumulation_steps 16 \\
    --predict_with_generate \\
    --max_steps 3000 \\
    --logging_steps 100 \\
    --save_steps 1000 \\
    --learning_rate $LR \\
    --pre_seq_len $PRE_SEQ_LEN \\
    --quantization_bit 4
'''
with open("/kaggle/working/ChatGLM-6B/ptuning/train_chat.sh", mode='w') as f:
    f.write(train_script)
```
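A quick note on the batch settings above: with a per-device batch size of 1 and 16 gradient-accumulation steps, each optimizer update effectively sees 16 samples, so 3000 steps cover 48,000 training samples in total:

```python
# Effective batch size implied by the train_chat.sh settings above
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
max_steps = 3000

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
samples_seen = effective_batch_size * max_steps
print(effective_batch_size, samples_seen)  # 16 48000
```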
You could also edit the file with Linux's `vim` command. I tried the following two write methods from the notebook, and both led to errors:

```shell
!echo CHAT_TRAIN_DATA=/kaggle/working/train.json >> train_chat.sh
!echo CHAT_VAL_DATA=/kaggle/working/dev.json >> train_chat.sh
!echo CHECKPOINT_NAME=/kaggle/working >> train_chat.sh
```

and the `%%writefile` cell magic, which does not produce a file in a format Linux supports:

```
%%writefile /kaggle/working/ChatGLM-6B/ptuning/train_chat.sh
```
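One common cause of a notebook-written shell script failing under bash is Windows-style CRLF line endings (bash then reports errors such as `$'\r': command not found`). A minimal sketch, assuming line endings are the cause, that converts a script's bytes to LF-only form:

```python
def strip_crlf(data: bytes) -> bytes:
    """Convert Windows CRLF line endings to Unix LF."""
    return data.replace(b'\r\n', b'\n')

# Example: a script fragment saved with CRLF endings
broken = b'CHAT_TRAIN_DATA=/kaggle/working/train.json\r\nLR=1e-3\r\n'
print(strip_crlf(broken))  # b'CHAT_TRAIN_DATA=/kaggle/working/train.json\nLR=1e-3\n'
```

To repair the file in place, read it with mode `'rb'`, apply `strip_crlf`, and write the result back with mode `'wb'`.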
Install `rouge_chinese` (needed by the evaluation code) from the Tsinghua mirror:

```shell
!pip install rouge_chinese -i https://pypi.tuna.tsinghua.edu.cn/simple
```
Turn off `wandb` (accepting the defaults), start training, and wait for the run to finish:

```shell
!wandb off
!bash train_chat.sh
```
After training, load the fine-tuned prefix weights and check that deployment works:

```python
# First load the tokenizer
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Load the checkpoint config; pre_seq_len must match the value used for fine-tuning
config = AutoConfig.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
config.pre_seq_len = 128
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True, config=config).half()

# The GLM prefix weights produced by this fine-tuning run
prefix_state_dict = torch.load('./ChatGLM-6B/ptuning/checkpoint-100/pytorch_model.bin')
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Quantize as needed
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

# Check that deployment succeeded
response, history = model.chat(tokenizer, '请搜索:吉林批发和零售业的主板B股的首创环保的信息。', history=[])
print(response)
```
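The key-filtering step above can be illustrated on a dummy state dict: only `transformer.prefix_encoder.*` entries are kept, with the prefix stripped so the remaining keys match what `prefix_encoder.load_state_dict` expects (numbers stand in for tensors here, purely for illustration):

```python
# Dummy state dict standing in for the torch.load(...) result;
# real values would be tensors.
prefix_state_dict = {
    "transformer.prefix_encoder.embedding.weight": 1.0,
    "transformer.word_embeddings.weight": 2.0,  # not part of the prefix encoder
}

# Same filtering logic as the deployment code above
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v

print(new_prefix_state_dict)  # {'embedding.weight': 1.0}
```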
Finally, write the predictions to a `.json` file, compress the `.json` file, and submit it to the competition. Note the `gradient_accumulation_steps` parameter above: the model's parameters are updated only after gradients have been accumulated over multiple samples.
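Gradient accumulation can be sketched in plain Python: gradients from several micro-batches are summed, and the parameter is updated only once per accumulation window. This is a toy 1-D least-squares example to show the update schedule, not the ChatGLM training loop:

```python
# Toy example: fit w to minimize (w*x - y)^2 over micro-batches of size 1,
# updating w only every `accum_steps` micro-batches.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # y = 2x
w, lr, accum_steps = 0.0, 0.05, 2
grad_sum = 0.0

for step, (x, y) in enumerate(data, start=1):
    grad = 2 * (w * x - y) * x         # d/dw of (w*x - y)^2
    grad_sum += grad
    if step % accum_steps == 0:        # one optimizer update per window
        w -= lr * grad_sum / accum_steps
        grad_sum = 0.0

print(w)  # 2.375, moving toward the true slope 2.0
```

With `accum_steps = 2` the four micro-batches produce only two optimizer updates, which is exactly how `gradient_accumulation_steps 16` turns sixteen batch-size-1 forward passes into a single effective batch-size-16 update.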