当前位置:   article > 正文

P-tuningv2微调ChatGLM2及微调细节剖析_p-tuning v2

p-tuning v2

本文基于ADGEN广告文本数据集,采用P-tuningv2技术微调ChatGLM2的简单案例,深入源码剖析微调细节,把握微调核心。

1.数据集介绍

ADGEN(广告数据集):希望通过一段吸引的多样的措辞来描述该产品,吸引用户购买,是一个广告文本。每条数据包括一个产品的广告文本,文件保存的格式是{“content”:“”,“summary”:“”},“content"包含产品描述的属性值对,即希望生成该产品的哪些描述属性,例如如果是上衣,希望可以给定"材质”,“颜色”,“风格”,“图案”等描述生成包括以上属性描述的一段广告文本,“summary”对应特定产品的限定属性生成的广告示例。
在这里插入图片描述
该数据集可以通过 Google Drive 或者 Tsinghua Cloud 下载。 解压后有两个json文件,包括训练文件(train.json)和测试文件(test.json)

2. 模型准备

2.1 ChatGLM2-6b简介

  • 清华大学的开发的开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM2-6B 具有新的特性:
    • 更强大性能:ChatGLM2升级基座模型,使用GLM的混合目标函数。
    • 更长的上下文:基于 FlashAttention 技术,有初代模型的2K------》扩展到32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。
      • 缺点:当前版本的 ChatGLM2-6B 对单轮超长文档的理解能力有限
    • 更高效的推理:基于 Multi-Query Attention 技术,推理速度提升了 42%

2.2 安装与配置

  1. 安装ChatGLM-6B模型
    https://github.com/THUDM/ChatGLM-6B 下载zip安装包后解压到服务器上
  2. 在anaconda 创建虚拟环境conda create -n <环境名称> python=3.X,并配置:
  • 激活刚创建的虚拟环境 conda activate <环境名称>
  • chatglm2模型下载并解压完成后,cd 进入 requirements.txt 所在的文件目录
  • 终端执行pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

3. 微调及微调细节

3.1 参数设置说明

跟训练相关的参数设置主要在train.sh文件,推理相关的参数在evaluate.sh文件。
训练参数 总结如下,推理部分参数与训练部分大致相同。
在这里插入图片描述
注意:

  1. arguments.py详细说明模型参数类相关的默认值设置(ModelArguments类)和数据训练参数类的默认值设置(DataTrainingArguments类)
  2. 请不要在train.sh 和 evaluate.sh 中用“#”设置注释,导致该变量参数失效,系统无法识别。

3.2 LLM输入和标签构建说明

train.sh 和 evaluate.sh 设置的prompt_column 和 response_column变量 最后经过main.py process_function_train函数构成了符合LLM的输入和输出。通过下面红框的这行代码,我们知道了json数据集中键值对是如何产生了LLM所需要的输入和输出。这启发我们如何定制私人数据集相应的LLM输入和输出。
在这里插入图片描述
可以看到query:查询问题的文本是由examples[prompt_column][i]得到,answer:回复是由examples[response_column][i]得到。

例如:
1 . 如果是多段输入进行拼接,得到一段输出的情况:
假设 josn数据集是有{“question”,“prompt”,“answer”}三段描述,其中我们希望query是由"question"+“prompt"构成,answer对应"answer”。这样设置的初衷可能是因为不同的question对应了不同的prompt文本。

我们可以首先对:
训练文件:train.sh 对应的部分进行修改:

--prompt_column [["question","prompt"]] \
--response_column answer\
  • 1
  • 2

其次对main.py 文件对应的的preprocess_function_train部分:

for i in range(len(examples[prompt_column])):
    if examples[prompt_column][i] and examples[response_column][i]:
        query = "".join([examples[prompt_column][0][i] for i in range(len(examples[prompt_column][i]))])
        answer =  examples[response_column][i]
  • 1
  • 2
  • 3
  • 4
  1. 如果我们希望对query是有"question"+ 共同的“prompt”构成,answer对应"answer"。这样设置的初衷可能是希望让它更好学习某种特定规则的输出。
    训练文件:train.sh 对应的部分进行修改:
  --prompt_column question \
  --response_column answer \
  • 1
  • 2

其次对main.py 文件对应的的preprocess_function_train部分:

for i in range(len(examples[prompt_column])):
    if examples[prompt_column][i] and examples[response_column][i]:
        query, answer = examples[prompt_column][i]+"回答时,请加上前缀‘你好’。", examples[response_column][i]
  • 1
  • 2
  • 3

基于以上,有几点认识:

  1. 训练集和测试集分别只有一个json或csv文件。
  2. 每条数据(以json为例)需要包括输入和输出的键值对,但是输入和输出的键值对并一定只有一个
  3. 输入和输出键值对名称可以改变,不一定设置为content和summary
  4. 结合train.sh和preprocess_function_train函数,可以更好根据自定义数据集构建合适的LLM输入和标签。

3.3 微调实现

    1. 首先进入配置了ChatGLM的虚拟环境,conda acitivate <环境名>
    1. 终端执行命令进入文件夹 cd <文件夹目录>
    1. 修改train.sh文件参数后,终端执行:bash train.sh(注意,sh文件里不能用#来标注注释,否则会导致该变量系统无法识别)
    1. 查看输出训练日志
      注意:
  1. sh文件不能带注释
  2. model_name_or_path 如果为“THUDM/chatglm-6b” 类似的带有“THUDM”会每次调用都从网络上下载模型,最好改为自己本地保存模型的路径。

输出打印日志,调用成功。
在这里插入图片描述

4. 评估细节

4.1 评估实现

经过p-tuning v2训练后保存的参数只有PrefixEncoder 部分的参数,所以推理时需要加载原始ChatGLM-6B模型参数以及PrefixEncoder的权重。推理部分修改主要在evaluate.sh文件。

修改地方注意两点

  1. 将 evaluate.sh 中的 CHECKPOINT 更改为训练时保存的 checkpoint 名称,同时如果你想要从本地加载模型,可以将model_name_or_path中的THUDM/chatglm-6b改为你本地的模型路径。
  2. 其他参数设置和训练一致。
    执行成功后会输出 各项指标:
    在这里插入图片描述

4.2 评估指标说明

BLEU:

  • BLEU(Bilingual Evaluation Understudy)根据精确率(precision)来衡量机器翻译质量,BLEU 值越高,机器翻译结果与人工翻译结果越相似。

ROUGE:

  • 全称是 (Recall-Oriented Understudy for Gisting Evaluation),根据召回率(recall)来衡量机器翻译质量
    • ROUGE-N: 在 N-gram 上计算召回率
    • ROUGE-L: 考虑了机器译文和参考译文之间的最长公共子序列
    • ROUGE-W: 改进了ROUGE-L,用加权的方法计算最长公共子序列

Q : 什么是n-gram?

  • n代表连续的n个词的组合,它的值可以是1,2,3,或者更高
    • 当n=1时,我们称之为unigram,即1-gram,指单个词语,如句子“我喜欢学习自然语言处理”,1-gram:[“我”, “喜欢”, “学习”, “自然语言处理”, “。”]
    • 当n=2时,我们称之为bigram,即2-gram,指相邻两个词语组合,如句子“我喜欢学习自然语言处理”,2-gram为:[“我喜欢”, “喜欢学习”, “学习自然语言处理”, “自然语言处理。”]
    • 当n=3时,我们称之为trigram,即3-gram,指相邻三个词语组合,如句子“我喜欢学习自然语言处理”,3-gram为:[“我喜欢学习”, “喜欢学习自然语言处理”, “学习自然语言处理。”]
  • 使用n-gram可以捕捉一定长度的上下文信息,有助于更好地理解文本和评估翻译质量。

BLEU 和ROUGE的计算

  • 参考翻译:“今天天气晴朗。”
  • 系统生成的翻译: “今天的天气是晴朗的。”

a. BLEU指标计算:

  • 首先将 参考翻译 和 系统生成的翻译 拆分成 n-gram序列。
    比如:参考翻译的1-gram:[“今天”, “天气”, “晴朗”, “。”] 而 系统生成的翻译的1-gram:[“今天”, “的”, “天气”, “是”, “晴朗”, “的”, “。”]
  • 接下来,计算两者n-gram的匹配数。例如,1-gram中有3个匹配:[“今天”, “天气”, “晴朗”]
  • 计算精确度:系统生成的翻译中n-gram匹配的个数/系统生成的翻译中n-gram的总个数
  • 考虑到较长的翻译可能具有较高的n-gram匹配,使用短文本惩罚(brevity penalty)来调整精确度。防止短翻译在BLEU中得分过高。
  • 计算BLEU得分:BLEU = 短文本惩罚 * exp(1/n * (log(p1) + log(p2) + … + log(pn)))
    其中,p1, p2, …, pn是1-gram, 2-gram, …, n-gram的精确度,n是n-gram的最大长度。

b. ROUGE指标计算:

  • 用于评估文本摘要任务,着重于信息完整性和涵盖程度,所以将参考翻译和系统生成的翻译视为两个文本摘要。
  • 计算系统生成的翻译(Predict)中包含的n-gram在参考翻译(Label)中出现的次数。
  • 计算召回率(recall):将匹配的n-gram总数除以参考翻译中的总n-gram数。
  • ROUGE得分可以根据需要使用不同的n-gram大小,通常使用ROUGE-1、ROUGE-2和ROUGE-L。
    • ROUGE-1 = 召回率(系统生成的1-gram匹配数 / 参考翻译中的1-gram总数)
    • ROUGE-2 = 召回率(系统生成的2-gram匹配数 / 参考翻译中的2-gram总数)
    • ROUGE-L = 最长公共子序列(Longest Common Subsequence,LCSS)的长度 / 参考翻译的总长度

BLEU指标 和 ROUGE指标 取值范围是[0,1],可以在main.py源码中看到,指标计算过程中将数值都乘以100。因此原本取值范围:[0,1],扩大后变成[0,100]。
在这里插入图片描述
另外需要注意的:

  • BLEU-4并不是只看4-gram的情况,而是计算从1-gram到4-gram的累积分数,加权策略为 1-gram, 2-gram, 3-gram和4-gram 的权重各占25%。
  • 默认情况下, sentence_bleu()和corpus_bleu()都是计算累积的4-gram BLEU分数的, 也称之为BLEU-4。
  • BLEU中的smooth_function有7种方法,不同smooth方法结果差异较大,可以查阅论文:https://aclanthology.org/P02-1040.pdf
    在这里插入图片描述
    其他没有注意到的细节,欢迎评论区讨论。源码提供了很多值得深入的细节。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/277461
推荐阅读
相关标签
  

闽ICP备14008679号