赞
踩
chatglm2-6b, chatglm-6b微调/LORA/推理
源码地址:https://github.com/yongzhuo/ChatGLM2-SFT
1. torch>=2.0, 否则微调会报很多错误(单纯推理可以用低版本); 2. tokenizer.encode输出为 [gMASK, sop, 真实文本token] 64789 = {str} '[MASK]' 64790 = {str} '[gMASK]' 64791 = {str} '[sMASK]' 64792 = {str} 'sop' 64793 = {str} 'eop' 3. modeling_chatglm.py自带get_masks()的代码full_attention_mask -= padding_mask.unsqueeze(-1) - 1改为 full_attention_mask = full_attention_mask.long() - padding_mask.unsqueeze(-1).long() - 1 4. 不支持gradient_checkpointing, 修复的话需要modeling_chatglm.py新增get_input_embeddings, set_input_embeddings; 5. modeling_chatglm.py中的ChatGLMForConditionalGeneration类forward函数中的 if full_attention_mask is None: 前加入 batch_size, seq_length = input_ids.shape 6. get_mask(), 一直以来都对chatglm的mask/position有一些疑惑; def get_masks(seq, bos_token_id): """ code from model_chatglm.py """ if seq.count(bos_token_id) == 2: context_length = seq[2:].index(bos_token_id) + 2 else: context_length = seq.index(bos_token_id) attention_mask = torch.ones((1, len(seq), len(seq))) attention_mask.tril_() attention_mask[..., :context_length] = 1 # attention_mask.unsqueeze_(1) attention_mask = (attention_mask < 0.5).bool() return attention_mask 7. 严格按照官方prompt构建输入输出: 输入:"[Round 1]\n\n问:{}\n\n答:" 输出:"{}" 输入id: [gMASK, BOS, 输入tokens] 输出id: [gMASK, BOS, 输出tokens, EOS]
transformers==4.27.1
torch>=2.0
sentencepiece
cpm_kernels
mdtex2html
accelerate
protobuf
gradio
地址: chatglm2_6b/ft_chatglm2
配置: chatglm2_6b/ft_chatglm2/config.py
训练: python train.py
推理: python predict.py
验证: python evaluation.py
接口: python post_api.py
本项目相关资源仅供学术研究之用,使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目不对其准确性作出保证。对于模型输出的任何内容,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。