赞
踩
一、目录
1.原理 https://github.com/huggingface/peft
pip install loralib
pip install peft : 高效微调 Parameter-Efficient Fine-Tuning (PEFT)
2. 代码本质实现。
3. 常规大模型LoRA 微调训练。 低阶自适应 lora
4. 参数含义
5. lora 前后模型变化
二、实现
1 原理
通过低维度的矩阵 计算 代替高维度的矩阵运算,提高训练速度。
在原始 PLM (Pre-trained Language Model) 旁边增加一个旁路,做一个降维再升维的操作,来模拟所谓的intrinsic rank。
2 代码本质实现:
import loralib as lora
import torch.nn as nn
class Model(nn.Module):
def __init__(self,in_feature,d_dim,n_class):
super(Model, self).__init__()
self.layer1=lora.Linear(in_feature,d_dim,r=16) #本质:Lora 将nn.Linear() 二次封装,训练与常规模型一样
self.layer2=lora.Linear(d_dim,n_class,r=16)
self.relu=nn.ReLU()
self.log_softmax=nn.LogSoftmax(dim=1)
def forward(self,x):
x=self.layer1(x)
x=self.relu(x)
x=self.layer2(x)
return self.log_softmax(x)
3、常规大模型LoRA 微调训练。
https://www.philschmid.de/fine-tune-flan-t5-peft 代码实例。
4、lora 本质:将模型进行降阶,其余与常规模型一样。
class TaskType(str, enum.Enum): #lora 任务类型
SEQ_CLS = "SEQ_CLS"
SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
CAUSAL_LM = "CAUSAL_LM"
TOKEN_CLS = "TOKEN_CLS"
from peft import get_peft_model, LoraConfig, TaskType peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, # 设置任务类型,固定值, inference_mode=False, # 设置推理模式为 False r=8, # 设置 PEFT 模型的秩为 8 lora_alpha=32, # 设置 LORA 的 alpha 参数为 32 lora_dropout=0.1, # 设置 LORA 的 dropout 参数为 0.1 ) # 加载模型 model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) model = get_peft_model(model, peft_config) # 打印模型参数 model.print_trainable_parameters() # output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282
5、参数含义
lora_rank(int,optional): LoRA 微调中的秩大小。这里并不是越大越好,对于小型数据集如果r=1就可以达到很不错的效果,即便增加r得到的结果也没有太大差别。
lora_alpha(float,optional): LoRA 微调中的缩放系数。
lora_dropout(float,optional): LoRA 微调中的 Dropout 系数。
learning_rate(float,optional): AdamW 优化器的初始学习率。如果设置过大会出现loss值无法收敛或过拟合现象即过度适应训练集而丧失泛化能力,对非训练集中的数据失去原本的计算能力。
num_train_epochs(float,optional): 训练轮数,如果loss值没有收敛到理想值可以增加训练轮数或适当降低学习率。
from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, #
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(model)
model = get_peft_model(model, peft_config)
print(model)
6. lora 接口解读
6.1 代码示例: https://github.com/huggingface/peft
6.2 接口示例整合:https://huggingface.co/docs/transformers/peft
6.3 peft 接口示例:https://huggingface.co/docs/peft/package_reference/lora
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。