赞
踩
我们的流程走到了,环境准备完毕。
装完依赖之后,上节结果为:
LoRA的核心思想是在保持预训练模型的大部分权重参数不变的情况下,通过添加额外的网络层来进行微调。这些额外的网络层通常包括两个线性层,一个用于将数据从较高维度降到较低维度(称为秩),另一个则是将其从低维度恢复到原始维度。这种方法的关键在于,这些额外的低秩层的参数数量远少于原始模型的参数,从而实现了高效的参数使用。
在fintuning_demo
目录下的 config
ds_zereo_2
/ ds_zereo_3.json
: deepspeed
配置文件。lora.yaml
/ ptuning.yaml
/ sft.yaml
: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。这里选择LoRA
,配置文件中的参数描述如下:
这里主要使用 finetune_hf.py
该文件进行微调操作。其中的参数
同多机多卡
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml
官方微调目录:/root/autodl-tmp/ChatGLM3/finetune_demo
配置文件目录:/root/autodl-tmp/ChatGLM3/finetune_demo/configs
,当中我们关注lora.yaml
训练的话,参数中需要一个:数据集的路径
官方推荐的数据集是:AdvertiseGen
,我们需要对其进行一些转换,才可以适配 ChatGLM3-6B
下载方式:
{"conversations": [{"role": "user", "content": "类型#裙*裙长#半身裙"}, {"role": "assistant", "content": "这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。"}]}
官方提供了一个脚本来支持我们进行转换,运行下面的脚本。
import json from typing import Union from pathlib import Path def _resolve_path(path: Union[str, Path]) -> Path: return Path(path).expanduser().resolve() def _mkdir(dir_name: Union[str, Path]): dir_name = _resolve_path(dir_name) if not dir_name.is_dir(): dir_name.mkdir(parents=True, exist_ok=False) def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]): def _convert(in_file: Path, out_file: Path): _mkdir(out_file.parent) with open(in_file, encoding='utf-8') as fin: with open(out_file, 'wt', encoding='utf-8') as fout: for line in fin: dct = json.loads(line) sample = {'conversations': [{'role': 'user', 'content': dct['content']}, {'role': 'assistant', 'content': dct['summary']}]} fout.write(json.dumps(sample, ensure_ascii=False) + '\n') data_dir = _resolve_path(data_dir) save_dir = _resolve_path(save_dir) train_file = data_dir / 'train.json' if train_file.is_file(): out_file = save_dir / train_file.relative_to(data_dir) _convert(train_file, out_file) dev_file = data_dir / 'dev.json' if dev_file.is_file(): out_file = save_dir / dev_file.relative_to(data_dir) _convert(dev_file, out_file) convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')
最终数据输出到:data/AdvertiseGen_fix
中。下面我们开始微调。
下面我们使用命令来进行微调:
CUDA_VISIBLE_DEVICES=0 /root/.pyenv/shims/python finetune_hf.py /root/autodl-tmp/data/AdvertiseGen_fix THUDM/chatglm3-6b configs/lora.yaml
正常训练
训练结束
CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=0 /root/.pyenv/shims/python inference_hf.py output/checkpoint-2000/ --prompt "类型#裙*版型#显瘦*材质#网纱* 风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式# 不规则"
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。