赞
踩
大模型微调技术——LoRA
大模型在预训练完收敛后模型包含很多进行矩阵乘法的稠密层,这些层通常都是满秩的,在微调过程中改变量其实也很小,在矩阵乘法中表现为低秩的改变。所以对于预训练权重矩阵
W
0
∈
R
d
×
k
W_0 ∈R^{d×k}
W0∈Rd×k, 可以用低秩分解表示更新过程:
W
0
+
▲
W
=
W
0
+
B
A
,
B
∈
R
d
×
r
,
A
∈
R
r
×
k
,
r
<
<
m
i
n
(
d
,
k
)
W_0 +▲W = W_0 +BA,B∈R^{d×r},A∈R^{r×k},r<<min(d,k)
W0+▲W=W0+BA,B∈Rd×r,A∈Rr×k,r<<min(d,k)。训练时,W_0是冻结的,没有梯度更新,A和B是可训练的,然后他们都会乘以相同的输入x:
h
=
(
W
0
+
▲
W
)
x
=
(
W
0
+
B
A
)
x
=
W
0
x
+
B
A
x
h = (W_0 +▲W)x= (W_0 +BA)x = W_0x +BAx
h=(W0+▲W)x=(W0+BA)x=W0x+BAx
注意三点:
知道以上这些,仍然不能帮助我们理解究竟LoRA怎么用到大模型训练上。首先要有几点认识
1、 理论上LoRA可以用到不同类型神经网络的权重矩阵,减少可训练的参数量
2、 不同的大语言模型基于Transformer的不同改版,Transformer结构主要self attention层(W_q,W_k,W_v,W_o)和MLP层。
本质上LoRA对网络层的权重矩阵进行低秩学习(d* k—>d* r , r*k)
stable Diffusion论文截图(对k,v,q矩阵的微调学习):
LLAMA2(来自BIT_666的博客):
官方提供微调层的建议:
‘q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj’
对应q,v,k,o 输出,还有MLP层的结构可以加入LoRA模块进行微调
LoRA微调主要关注的几个问题:
CUDA_VISIBLE_DEVICES=0: 单卡运行。
do_train: 是否执行训练。
model_name_or_path: 预训练模型路径。
dataset_dir: 训练数据存储目录。
dataset: 训练数据集名称,可在 data/dataset_info.json 中增加自定义数据集。
output_dir: 微调后的模型保存路径。
source_prefix: 训练时每个输入序列添加的前缀,可为空。
max_source_length: 输入序列的最大长度,即 source_prefix + instruction + input 的长度。
max_target_length: 输出序列的最大长度,即 output 的长度。
per_device_train_batch_size: 用于训练的批处理大小。可根据 GPU 显存大小自行设置。
gradient_accumulation_steps: 梯度累加次数。
logging_steps: 多少步输出一次 log。
save_steps: 多少步保存一次参数。
learning_rate: AdamW 优化器的初始学习率。设置过大会出现loss值无法收敛或过拟合现象即过度适应训练集而丧失泛化能力,对非训练集中的数据失去原本的计算能力
**lr_scheduler_type": 学习率策略,可以设置cosine(月线退火策略)
num_train_epochs: 训练轮数(若非整数,则最后一轮只训练部分数据)如果loss值没有收敛到理想值可以增加训练轮数或适当降低学习率
plot_loss: 微调后绘制损失函数曲线,图片保存在 output_dir 中 。
fp16: 使用半精度(混合精度)训练。
lora_target: 大模型内将要进行 LoRA 微调的模块名称。
lora_rank: LoRA 微调中的秩大小。
padding_side: pad对齐方式,左对齐或者右对齐。
LoraConfig:创建 LoRA 微调方法对应的配置【比较影响模型的效果】
task_type:指定任务类型。如:条件生成任务(SEQ_2_SEQ_LM),因果语言建模(CAUSAL_LM)等。
r: LoRA低秩矩阵的维数。关于秩的选择。
lora_alpha: LoRA低秩矩阵的缩放系数,为一个常数超参,调整alpha与调整学习率类似。
lora_dropout:LoRA 层的丢弃(dropout)率,取值范围为[0, 1)。
target_modules:要替换为 LoRA 的模块名称列表或模块名称的正则表达式。
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules = ["attn.c_proj", "attn.c_attn"]
)
解释:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。