当前位置:   article > 正文

【ChatGLM3】微调指南

【ChatGLM3】微调指南

下载数据集ToolAlpaca

  • 从GitHub下载
cd ChatGLM3/finetune_chatmodel_demo
git clone https://github.com/tangqiaoyu/ToolAlpaca.git
  • 1
  • 2
  • 除基础的 torch 依赖外,示例代码运行还需要依赖: pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed
  • 处理数据集格式
./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
  • 1
  • 处理后的数据: formatted_data/tool_alpaca.jsonl
  • 开始微调: ./scripts/finetune_pt_multiturn.sh
  • 如果出现显存不足的报错提示,需要修改finetune_pt_multiturn.sh
  • 把MAX_SEQ_LEN=2048改成MAX_SEQ_LEN=1024,MAX_SEQ_LEN会影响输入文本的长度限制
  • 使用2张16G的T4显卡,每张显卡都需要加载完整的模型,只是把任务分成2部分
  • 使用MAX_SEQ_LEN=2048需要单张显卡21G以上
  • 参数调整参考
数据量MAX_STEP x BATCHSIZE x gradient_accumulation_steps
100500
10003000
100000100000
  • 训练完成后,checkpoint的路径在: output/tool_alpaca_pt-20240104-184837-128-2e-2

下载数据集AdvertiseGen

  • 从清华大学网站下载
cd ChatGLM3/finetune_chatmodel_demo
curl -O https://cloud.tsinghua.edu.cn/seafhttp/files/93349217-b0ae-4b3e-875e-303fa05d7f08/AdvertiseGen.tar.gz

# 解压下载的文件
tar zxvf AdvertiseGen.tar.gz
  • 1
  • 2
  • 3
  • 4
  • 5
  • 处理数据集格式
./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"
  • 1
  • 开始训练: ./scripts/finetune_pt.sh

加载PT训练的checkpoint

加载pt训练微调后的checkpoint的关键代码

MODEL_PATH = os.environ.get('MODEL_PATH', '/ChatGLM3/THUDM/chatglm3-6b-32k')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
PT_PATH = os.environ.get('PT_PATH', '/ChatGLM3/finetune_chatmodel_demo/output/tool_alpaca_pt-20240105-185333-128-2e-2')
PT_PRE_SEQ_LEN = 128

@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    if PT_PATH is not None and os.path.exists(PT_PATH):
        config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=PT_PRE_SEQ_LEN)
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, config=config, device_map="auto").eval()

        prefix_state_dict = torch.load(os.path.join(PT_PATH, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

    return tokenizer, model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/963288
推荐阅读
  

闽ICP备14008679号