cd ChatGLM3/finetune_chatmodel_demo
git clone https://github.com/tangqiaoyu/ToolAlpaca.git
pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed
./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
The formatted data is written to formatted_data/tool_alpaca.jsonl.
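Before training, it can help to confirm the formatted file exists and to count its samples, since the dataset-size table further down is keyed on that number. A minimal standard-library sketch; it assumes only that the file is JSON Lines (one object per line) and makes no assumption about the record schema, so it just tallies the top-level keys it sees. summarize_jsonl is a hypothetical helper, not part of the ChatGLM3 scripts.

```python
import json
from collections import Counter


def summarize_jsonl(lines):
    """Count JSON objects in an iterable of JSON Lines and tally their top-level keys."""
    count = 0
    keys = Counter()
    for line in lines:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        record = json.loads(line)  # one JSON object per line
        count += 1
        keys.update(record.keys())
    return count, keys
```

Usage: `with open("formatted_data/tool_alpaca.jsonl", encoding="utf-8") as f: print(summarize_jsonl(f))`. Taking an iterable of lines rather than a path keeps the function easy to test and works directly with an open file handle.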
./scripts/finetune_pt_multiturn.sh
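- If training fails with an out-of-memory error, edit finetune_pt_multiturn.sh:
- Change MAX_SEQ_LEN=2048 to MAX_SEQ_LEN=1024. MAX_SEQ_LEN caps the length of the input text (in tokens), so lowering it reduces memory use.
- With two 16 GB T4 GPUs, each GPU still loads a complete copy of the model; the work is only split into two parts (data parallelism), so adding a second card does not lower the memory needed on each card.
- MAX_SEQ_LEN=2048 needs more than 21 GB on a single GPU.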
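The MAX_SEQ_LEN change can be made by hand in an editor, or scripted. A small sketch, assuming the variable is assigned on its own line as MAX_SEQ_LEN=<number> (as the bullet above implies); set_max_seq_len and patch_script are hypothetical helpers.

```python
import re

# Matches an assignment like "MAX_SEQ_LEN=2048" at the start of a line.
_MAX_SEQ_LEN_RE = re.compile(r"^([ \t]*MAX_SEQ_LEN=)\d+", re.MULTILINE)


def set_max_seq_len(script_text: str, new_len: int) -> str:
    """Return script_text with every MAX_SEQ_LEN=<n> assignment set to new_len."""
    # A function replacement avoids backreference ambiguity such as "\11024".
    new_text, n = _MAX_SEQ_LEN_RE.subn(lambda m: f"{m.group(1)}{new_len}", script_text)
    if n == 0:
        raise ValueError("no MAX_SEQ_LEN=<number> line found")
    return new_text


def patch_script(path: str, new_len: int = 1024) -> None:
    """Rewrite the shell script at `path` in place with the new MAX_SEQ_LEN."""
    with open(path, encoding="utf-8") as f:
        text = f.read()
    with open(path, "w", encoding="utf-8") as f:
        f.write(set_max_seq_len(text, new_len))
```

For example, `patch_script("scripts/finetune_pt_multiturn.sh", 1024)` run from ChatGLM3/finetune_chatmodel_demo.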
Dataset size (samples) | MAX_STEP x BATCHSIZE x gradient_accumulation_steps |
---|---|
100 | 500 |
1000 | 3000 |
100000 | 100000 |
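The table gives a target for the product MAX_STEP x BATCHSIZE x gradient_accumulation_steps, so once the batch size and gradient accumulation steps are fixed, MAX_STEP follows by dividing and rounding up. A minimal sketch; max_step_for is a hypothetical helper, not part of the ChatGLM3 scripts.

```python
def max_step_for(target_samples: int, batch_size: int, grad_accum_steps: int) -> int:
    """Smallest MAX_STEP with MAX_STEP * batch_size * grad_accum_steps >= target_samples."""
    if target_samples <= 0 or batch_size <= 0 or grad_accum_steps <= 0:
        raise ValueError("all arguments must be positive")
    per_step = batch_size * grad_accum_steps  # samples consumed per optimizer step
    return -(-target_samples // per_step)  # integer ceiling division
```

For a 1000-sample dataset (target 3000) with batch size 1 and 16 accumulation steps, `max_step_for(3000, 1, 16)` gives 188.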
The checkpoint is saved to: output/tool_alpaca_pt-20240104-184837-128-2e-2
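The directory name appears to encode the run name, the date and time of the run, PRE_SEQ_LEN (128) and the learning rate (2e-2); this is inferred from the name, not documented here. PRE_SEQ_LEN matters later because the value used when loading the checkpoint must match the one used in training. A sketch that splits such a name, assuming the scheme <run_name>-<YYYYMMDD>-<HHMMSS>-<pre_seq_len>-<lr>; parse_checkpoint_dir is hypothetical.

```python
import os
import re

# Assumed naming scheme: <run_name>-<YYYYMMDD>-<HHMMSS>-<pre_seq_len>-<lr>
# The lr group is last and may itself contain "-" (e.g. "2e-2").
_CKPT_RE = re.compile(
    r"^(?P<run_name>.+?)-(?P<date>\d{8})-(?P<time>\d{6})-(?P<pre_seq_len>\d+)-(?P<lr>.+)$"
)


def parse_checkpoint_dir(path: str) -> dict:
    """Split a checkpoint directory name into its assumed components."""
    name = os.path.basename(os.path.normpath(path))  # tolerate a trailing slash
    m = _CKPT_RE.match(name)
    if m is None:
        raise ValueError(f"unexpected checkpoint directory name: {name}")
    return {
        "run_name": m.group("run_name"),
        "date": m.group("date"),
        "time": m.group("time"),
        "pre_seq_len": int(m.group("pre_seq_len")),
        "lr": float(m.group("lr")),
    }
```

The parsed pre_seq_len can then be compared with the value configured at load time instead of being retyped by hand.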
cd ChatGLM3/finetune_chatmodel_demo
curl -O https://cloud.tsinghua.edu.cn/seafhttp/files/93349217-b0ae-4b3e-875e-303fa05d7f08/AdvertiseGen.tar.gz
# Extract the downloaded archive
tar zxvf AdvertiseGen.tar.gz
./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"
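For a single-turn dataset like AdvertiseGen, formatting amounts to mapping each raw record onto a prompt/response pair. The sketch below is not the repository's format_advertise_gen.py; it only illustrates that kind of conversion, assuming each AdvertiseGen line is a JSON object with "content" and "summary" fields and that the target schema is {"prompt": ..., "response": ...}. convert_advertise_gen is a hypothetical helper.

```python
import json


def convert_advertise_gen(lines):
    """Yield prompt/response JSON lines from AdvertiseGen content/summary lines.

    Assumption: input objects have "content" and "summary" keys, and the
    output format is {"prompt": ..., "response": ...}.
    """
    for line in lines:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        sample = json.loads(line)
        out = {"prompt": sample["content"], "response": sample["summary"]}
        # ensure_ascii=False keeps Chinese text readable in the output file
        yield json.dumps(out, ensure_ascii=False)
```

Writing one yielded line per output line produces a JSON Lines file.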
./scripts/finetune_pt.sh
Key code for loading the checkpoint produced by P-Tuning (pt) fine-tuning:
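import os

import streamlit as st
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

MODEL_PATH = os.environ.get('MODEL_PATH', '/ChatGLM3/THUDM/chatglm3-6b-32k')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
# Directory of the P-Tuning checkpoint written by the training script
PT_PATH = os.environ.get('PT_PATH', '/ChatGLM3/finetune_chatmodel_demo/output/tool_alpaca_pt-20240105-185333-128-2e-2')
# Must equal the PRE_SEQ_LEN used during training (128 for this checkpoint)
PT_PRE_SEQ_LEN = 128


@st.cache_resource  # load the model once per Streamlit process
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    if PT_PATH is not None and os.path.exists(PT_PATH):
        # Build the base model with a prefix encoder of length PT_PRE_SEQ_LEN
        config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=PT_PRE_SEQ_LEN)
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, config=config, device_map="auto").eval()
        # Load on CPU first; load_state_dict copies the tensors onto the model's devices
        prefix_state_dict = torch.load(os.path.join(PT_PATH, "pytorch_model.bin"), map_location="cpu")
        # Keep only prefix-encoder weights and strip the prefix so keys match the submodule
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        # No checkpoint found: fall back to the base model
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
    return tokenizer, model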
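The key-renaming loop in get_model exists because the checkpoint names its tensors relative to the full model (transformer.prefix_encoder.<param>), while model.transformer.prefix_encoder.load_state_dict expects names relative to the prefix encoder itself and, by default, rejects unexpected or missing keys. The same filter-and-strip step written as a plain function, so it can be checked without loading a model; the key names in the test are illustrative, and extract_prefix_encoder_state is a hypothetical helper.

```python
PREFIX = "transformer.prefix_encoder."


def extract_prefix_encoder_state(state_dict, prefix=PREFIX):
    """Keep only entries under `prefix`, with the prefix removed from each key."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
```

Inside get_model this would replace the explicit loop with `new_prefix_state_dict = extract_prefix_encoder_state(prefix_state_dict)`.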