当前位置:   article > 正文

小白也能微调大模型:LLaMA-Factory使用心得

llama-factory

大模型火了之后,相信不少人都在尝试将预训练大模型应用到自己的场景上,希望得到一个垂类专家,而不是通用大模型。

目前的思路,一是RAG(retrieval augmented generation),在模型的输入prompt中加入尽可能多的“目标领域”的相关知识,引导模型在生成时尽量靠拢目标领域,运用prompt中给予的目标知识;二是有监督微调,用适量的专业领域的数据(或混通用语料)让模型更能生成目标场景的内容。本文主要讲的就是微调。

什么是LLaMA-Factory

当我们想要微调大模型的时候,一个粗略的实验过程无外乎以下几个环节:

  1. 准备好硬件(GPU)、数据;通过各方面的资讯选中你想要微调的基座模型
  2. 准备好代码:输入数据 + 模型 -> 在GPU上反复训练
  3. 训练结束以后,得到训练过程中的checkpoint + 一些log信息
  4. 根据log信息选一些比较有希望的checkpoint在自己的测试集上推理,获得相应的结果
  5. 分析结果,获得下一轮实验(数据、训练方案的迭代)思路

而LLaMA-Factory就是一个很好的负责step 2的工具(当然它能做到的远不止step2,我们后面也会提),你可以理解为,他是一份写好的代码,你只需要把你准备好的数据、硬件、模型,以传参的方式传入,运行代码,模型就开始训练了。等训练结束以后,你把训练好的模型、测试集、硬件又作为参数传入,它就会帮你推理。

LLaMA-Factory的优点

LLaMA-Factory非常适合实验阶段使用,因为:

  1. 支持很多种开源大语言模型:

    实验阶段我们肯定有好几个觉得靠谱的模型,它们往往有自己的标准输入模板(尤其是代码补全这类任务,涉及较多的special token),你想试试的模型LLaMA-Factory基本都支持,通过template参数可以很方便地指定prompt的模板

  2. 支持非常多种训练方法:

    全量调参 vs Lora vs … 或预训练模型 vs 有监督fine-tuning,以及DPO PPO的对齐方案。

    你想试试的基本也都有,也是通过指定训练模式参数即可

  3. Log:

    训练过程中记录的内容比较全,除了同步能够输出loss曲线图以外,还自带bleu等评测指标

  4. 测试环节也很方便:

    支持merge model(比如微调后的adapter合并到原模型以便作为一个模型导出推理);

    支持各种时下比较流行的量化加速方案;

    支持vllm等高并发要求的推理框架;

    需要的话还可以快速搭建一个Gradio UI用于demo展示或可视化分析

使用心得

我没有用过LLaMA-Factory的全部功能,本文暂且以基本的微调任务为阐述重点,会覆盖上面提到的输入:数据 + GPU + 模型,输出微调后模型的使用。看完以后,应该基本就能完成任意一个支持的开源模型的微调任务了。此外,本文也会涉及一小部分LLaMA-Factory的代码文件目录讲解,方便你更好地探索其他的功能相关的参数来实现你的目标任务。

环境准备

首先是需要git clone两个文件目录,一个是目标大模型的仓库(包含模型权重文件等),一个是llama-factory的仓库

然后,我们通常会在两个地方遇到相关依赖的版本要求:

  1. llama-factory的Github仓库主页下,README的Requirement部分(目前已经很贴心地标注了最低要求和推荐要求),以及代码结构目录中的requirements.txt
  2. 想要使用的目标大模型的Huggingface或Github主页下,同样README部分、代码结构目录中的requirements.txt两个部分都会有相关依赖的版本要求

一般,我们以尽快跑通我们的实验目标为目的。

如果是自己掌控度比较高的环境(自己的GPU),装包装cuda什么的都比较擅长:

  1. 检查llama-factory主页README中的要求,把几个依赖库的版本检查一下,保持在规定范围

    这主要是因为llama-factory的requirements.txt里面的相关依赖可能比较多,你不一定会用到llama-factory的所有功能

  2. 基于目标大模型文件目录中的requirements.txt,使用pip install -r requirements.txt

    这主要是因为,这个文件中基本包含这个模型要运行起来的所有依赖

  3. 尝试运行,缺啥补啥

    这里建议按照目标大模型主页的quick start,写一个简单的脚本就可以了

如果是实验室的服务器或者公司的服务器这种掌控度小的场景,记得要自己创建一个虚拟环境,或者起一个自己的docker容器,在虚拟环境或docker容器内操作,具体使用conda还是docker取决于你们公共服务器的权限管理,哪个方便用哪个,或者其他人平常用什么你就用什么。

数据集准备

模型能够跑通以后,我们准备用于微调的数据集。

这里需要理清楚几个概念

数据的内容组织方式,取决于训练场景的输入和输出。

通常一个样本由(输入,输出)的pair构成,场景上主要是下面3种(更多的可以参考readme里关于数据集准备的部分)

  1. 预训练场景:在一句话里并没有特别关注某个位置的内容,想要提升整个训练集语料上的general效果,此时对于GPT架构的模型,一般使用的输入和输出是相同的,所以如果我们指定了训练模式为预训练,那么llama-factory会自动copy输入内容作为输出label的
  2. sft(有监督微调supervised fine tuning):这个场景下,我们特别关注字符串里某个位置的内容,想要针对性地提升。比如NL2SQL专门去调SQL部分的风格或者内容,那就可以只把SQL部分作为输出,NL部分作为输入,而不是把NL+SQL一整句话作为输入和输出;再比如代码补齐场景,一般前文后文作为输入,补齐的中间部分作为输出,针对补齐部分做loss的计算
  3. 偏好对齐场景:主要是输出部分会有两个label,一个更好的,一个更差的,主要是适应DPO等热门的微调方法,模型不光可以从具体的label中学习,还可以通过两个label的差距来学习,目前后者带来的效果大体上更好,学习的目标更精细,有很多文章可以按兴趣去学习。

无论哪种场景,我们都可以按照llama-factory要求的标准格式组织数据集,保存成一个文件,比如下面这种.json文件:

[
  {
    "instruction": "user instruction (required)",
    "input": "user input (optional)",
    "output": "model response (required)",
    "system": "system prompt (optional)",
    "history": [
      ["user instruction in the first round (optional)", "model response in the first round (optional)"],
      ["user instruction in the second round (optional)", "model response in the second round (optional)"]
    ]
  }
]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

这里的instruction,input,output等键实际上不重要,你想叫什么都可以,关键在于你要把这个数据集的相关信息,注册到/data/dataset_info.json中,什么叫“注册”呢,就是说参照它里面已经有的数据集注册信息的格式,再添加一个键到其中,比如:

 "给这个数据集取一个名字(传参时使用)": {
    "file_name": "把你的数据集按照上面说的保存成一个文件,也放在这个目录下,这里填文件的名字,如xxx.json",
    "file_sha1": "可以用一些算sha1值的函数对文件算一下,也可以省略这个键",
    
    # 这里是关键,llama-factory实际使用的是prompt,query这些键,你要在这里完成映射关系的描述,这也是上面说 instruction,input,output这些键你想叫什么名字都可以的原因
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output",
      "history": "history"
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

这其中,query对应的列的内容会拼接在prompt列对应的内容后面,变成{prompt\nquery}一起作为模型的输入,response列表示这个样本的期望输出,会用于计算loss

在启动时,通过--stage参数告诉llama-factory用什么方式使用这个数据集,比如说--stage pt,那么llama-factory就只会使用prompt列对应的内容,response列的内容会忽略(我们说了pt模式的输入和输出一般是一样的)。因此,我们要做的是根据我们的场景,是预训练(prompt),有监督(query prompt response)还是强化学习(response)等,是多轮对话(history)还是单纯的补全,把它们会涉及的数据对应的键,都映射到正确的数据集里的键上,具体的参考LLama-factory/data/README.md即可。

可能遇到的问题

我们在构造数据集.json文件的时候,可能有的人会用一个这样的脚本,伪代码如下:

# 参考官方示例创建一个空list存样本
samples = []
# 遍历自己的原始数据源
for data in src:
	# 各种处理
	process...sample...
	
	# 处理完了以后变成一个字典结构
	sample = {
		'instruction': xxxxx,
		'input': ....,
		'output': xxxx,
	}
	
	samples.append(sample)

# 保存成.json文件
with open(.....) as wf:
	json.dumps(samples, wf)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

这里面有两个问题:

  1. 如果你的数据集特别大,或者一个样本包含的信息特别多,内存里要维护一个超级大的list,可能会导致你处理过程中就内存溢出了
  2. json.dumps实际上会构造出一个超长string,llama-factory里面的读取函数可能是基于transformers.load_datasets,这个函数使用pyarrow去读取字符串,读特别长的json会卡住,我就遇到了load不报错但是也不运行的情况

实际上,我们并不一定非要构造.json数据集,构造一个.jsonl数据集也是完全可以运行的,并且pyarrow更喜欢,伪代码如下

with open(保存数据集,'w', encoding='utf-8') as wf:
	# 遍历自己的原始数据源
	for data in src:
		# 各种处理
		process....sample...
		
		# 处理完以后变成一个字典结构
		sample = {
			....
		}
		
		wf.write(json.dumps(sample, ensure_ascii=False) + '\n')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

这样就行了,既不用维护一个巨大的list在内存中,也不用担心读取的时候出问题,使用方法上没有任何变化,还是一样把这个文件注册到dataset_info.json里面。

训练参数

template参数:决定数据集中的prompt和response如何连接

数据准备好以后,最重要的就是--template参数了,每个大模型都有一些自己的special tokens,比如对话大模型往往会有user,assistant这样的标识,把这些标识插入到prompt和response之间,才能构成一个完整的模型输入。而我们准备数据的时候,只需要准备prompt,response的内容,这些标识是不需要我们插入的。

不要想当然地把模型的名字填上去,不如亲自去源码里看看

我们前往/src/llmtuner/data/template.py能够看到所有的template参数都会对我们的数据做什么,比如下面是deepseek两个模板对应的处理
在这里插入图片描述

有的时候我们想要使用的标识和提供的template不同,比如我们使用的base模型,其预训练任务就是直白的输入+输出,没有什么user,assistant;再比如我们做代码补全任务,有自己的special token想要插入,prompt形如<fim-prefix>xxxx<fim-suffix>xxxx<fim-middle>

这种情况我比较推荐使用vanilla作为template的参数,没有做任何的添加,我们可以在构造数据集的阶段,就自由地把special tokens和我们的文本内容拼接好,一起放入prompt对应的键中。
在这里插入图片描述

注意不是default,default仍然会添加Human, Assistant
在这里插入图片描述

资源相关的参数

官方的readme里面给出的单卡、多卡训练示例已经很详细了,这里不多赘述。

多卡训练时,可能没接触过的人会有些疑惑ds_config.json写什么样子

deepspeed --num_gpus 8 src/train_bash.py \
    --deepspeed ds_config.json \
    --ddp_timeout 180000000 \
    ... # arguments (same as above)
  • 1
  • 2
  • 3
  • 4

这里对于初学者的实验阶段,我建议直接copy官方示例构造一个ds_config.json也可以,直接去掉这个参数也可以,先跑通,再根据实际需要回来调整,一步一步来,总会逐渐了解的。

比较常见的是使用实验室或者公司的公共服务器的场景,需要指定在哪些卡上面训练,添加include按如下形式指定即可:

deepspeed --num_gpus 2 --master_port=9901 --include localhost:2,3

其他的和多卡训练相关的参数,比如每张卡的batch size等,理解都比较直接,自行查阅。运行起来以后,根据每个batch差不多要的时间估算一下,再根据自己的耗时需求调整即可。

这里可能会遇到一个小问题:起训练任务起失败了以后,master_port显示被占用,不得不换一个port。这主要是因为任务挂了以后,程序清理不干净,比如网络方面还占着。建议你使用ps -ef | grep 你的用户名 去检查是不是还有和刚刚那个挂掉的任务相关的进程,kill掉即可

其他参数

实际上llama-factory有很多训练参数可以设置,并不局限于示例中给出的参数,你应该积极地去/src/llmtuner/hparams这里看看,以实现你的需求

这里举例几个我的需求:

  1. 我的数据集非常大,需要加速generating train split的过程。

    增加--num_workers 8

  2. 我希望从训练集里面切一个小验证集,在训练过程中每个epoch结束的时候,在验证集上eval一遍,并保存结果在log中

​ 增加--val_size 0.01 --evaluation_strategy epoch等,思路和transformers的trainer差不多

  1. 同时,训练完毕输出的loss曲线图片,我也需要这个验证集上的loss曲线

    设置--plot_loss参数

  2. greedy推理

    如果你去了我说的位置看,就会发现默认temperature=0.95, top_p=0.7,如果我希望unset,并设置do_sample=False,需要设置二者为1.0(transformers的default就是这样)
    在这里插入图片描述

另外一个最重要的事情是,学会在issue中搜索,会有很多同样的问题已经得到了解答

训练完毕后的推理

训练完毕后,我们要在自己的测试集或者公用测试集上测试模型的效果:

如果是公用测试集,比如MMLU之类的任务,你可以直接参考官方的readme,使用evaluation即可

如果是自己的测试集,我们往往需要测试这个测试集上的指标,同样的:

  1. 首先把你的测试集,像训练集一样构造成一个.json / .jsonl,并注册到dataset_info.json中
  2. 参考官方的demo predict的部分即可

如果不急于算指标,只是想要看看具体的case,官方也提供了命令行demo和浏览器的demo,照抄着改就可以了,应该不难。

我遇到的问题是,我的测试集特别大,使用predict非常慢,并且得到的结果只保存了label和结果,我希望快速推完测试集,并且保存好输入、输出、label,这样方便我后续自己的可视化。这种场景,最好是将模型部署成一个服务,自己写一个脚本去发请求,边发请求边把自己想要保存的样本保存到文件中。

关于模型部署,为了避免本文过长就暂时不多叙述,可以用以下方法:

  1. TGI部署:使用llama-factory的模型导出功能(如果你是lora微调的就会顺便merge weights),将导出的模型用TGI部署。

​ 比较推荐,TGI部署很简单,适配非常多种模型,所以很适合实验阶段使用

  1. vllm部署:vllm相比之下更适合生产环境,面向高并发的真实场景,如果涉及前后处理策略、时延等方面的测试,建议保持和线上一致
  2. llama-factory的api demo,其实应该差不多就是把上面的封装了,我没有使用过,但是思路是一样的,把模型变成一个服务,可以参考官方的demo学习使用。

总结

本文介绍了初学者如何使用llama-factory这个工具进行大模型的微调任务,包含用自己的数据构造训练、测试数据,起训练任务时候的相关参数,训练完毕后的测试集推理环节等,虽然可能不够全面,但多多少少以授人以渔的思路介绍了应该去源码的什么位置获得进一步的信息,以满足文中没有覆盖到的需求。

值得一提的是,因为作者的使用经验也有些时日,llama-factory也一直保持着更新,文中的一些内容可能有谬误,一切以源码为主。初学者应该养成看源码、看issue,自行找答案的能力。

如果有任何我理解有误的地方,还希望多多指正,感激不尽!下一个新技能点再见!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/834032
推荐阅读
相关标签
  

闽ICP备14008679号