LoRA (Low-Rank Adaptation) is a lightweight fine-tuning technique. It treats each layer of a neural network as a weighted feature extractor whose weights determine how much that layer influences the model's output. By adjusting part of an existing SD model's weights with low-rank update matrices, LoRA improves the quality or style of generated images. (Most LoRA models adjust the attention Linear layers inside the Transformer blocks; some also fine-tune Conv2d kernel weights.)

Technically, LoRA is quite simple. The basic flow is illustrated in the figure below: the blue part is the original pretrained weight, and the orange part is the pair of low-rank matrices A and B that LoRA actually trains.
W = W0 + BA

is used to update the model weights. A plain Linear LoRA only touches the Linear layer's weight. Here the number of dimensions of the B matrix, len(weight_up.shape), equals the number of dimensions of the A matrix, len(weight_down.shape) — both are 2. B (lora_up.weight) has shape [out_dim, rank] and A (lora_down.weight) has shape [rank, in_dim], so the product BA has shape [out_dim, in_dim], matching the [out_features, in_features] layout of an nn.Linear weight.
if isinstance(module, nn.Linear):
    assert len(self._up_weight.shape) == len(self._down_weight.shape) == 2
    in_dim = module.in_features
    out_dim = module.out_features
    self._lora_down = nn.Linear(in_dim, self._r, bias=False)
    self._lora_up = nn.Linear(self._r, out_dim, bias=False)
    self._lora_down.weight = nn.Parameter(self._down_weight)
    self._lora_up.weight = nn.Parameter(self._up_weight)
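As a quick sanity check, the Linear merge can be sketched with made-up dimensions (the numbers below are hypothetical, not taken from any real SD layer):

```python
import torch

# Hypothetical dimensions for illustration only
in_dim, out_dim, rank, alpha = 320, 640, 4, 1.0

# PyTorch nn.Linear stores its weight as [out_features, in_features]
w0 = torch.zeros(out_dim, in_dim)
weight_down = torch.randn(rank, in_dim)   # A: lora_down.weight
weight_up = torch.randn(out_dim, rank)    # B: lora_up.weight

# BA has the same shape as the original weight, so it can be added in place
delta = alpha * torch.mm(weight_up, weight_down)  # [out_dim, in_dim]
assert delta.shape == w0.shape
merged = w0 + delta  # W = W0 + BA
```

Only rank * (in_dim + out_dim) parameters are stored instead of in_dim * out_dim, which is the whole point of the low-rank factorization.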
Convolution and fully connected layers are different operations, but the LoRA treatment of a conv layer mainly modifies the kernel's weight tensor, so once you see that, the idea is essentially the same as for the fully connected case.

For Conv2d, after the 1x1 up-kernel is squeezed, the number of dimensions of the B matrix (len(weight_up.shape)) no longer equals that of the A matrix (len(weight_down.shape)). When fusing, the last two (spatial) dimensions must be preserved and the matrix multiplication is done over the first two. B (lora_up.weight) is a 1x1 conv with shape [out_dim, rank, 1, 1]; A (lora_down.weight) is a conv with shape [rank, in_dim, kernel=(h, w)]. After squeezing B to [out_dim, rank], the product has shape [out_dim, in_dim, h, w].

The shapes of A and B are designed this way for a reason: BA must end up with the same shape as the original conv kernel weight (self.weight), which in PyTorch is [out_channels, in_channels, h, w], so A and B are laid out to produce exactly that:
if isinstance(module, nn.Conv2d):
    assert len(self._up_weight.shape) == len(self._down_weight.shape) == 4
    r = self._r
    in_dim = module.in_channels
    out_dim = module.out_channels
    kernel = module.kernel_size
    stride = module.stride
    padding = module.padding
    self._lora_down = nn.Conv2d(in_dim, r, kernel, stride, padding, bias=False)
    self._lora_up = nn.Conv2d(r, out_dim, (1, 1), (1, 1), bias=False)
    self._lora_down.weight = nn.Parameter(self._down_weight)
    self._lora_up.weight = nn.Parameter(self._up_weight)
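The conv fusion can likewise be sketched with toy shapes (again, the dimensions are invented for illustration):

```python
import torch

# Hypothetical channel counts and a 3x3 kernel
in_dim, out_dim, rank, alpha = 320, 320, 4, 1.0
kh, kw = 3, 3

w0 = torch.zeros(out_dim, in_dim, kh, kw)        # original Conv2d weight
weight_down = torch.randn(rank, in_dim, kh, kw)  # A: Conv2d(in_dim, r, 3x3) weight
weight_up = torch.randn(out_dim, rank, 1, 1)     # B: Conv2d(r, out_dim, 1x1) weight

# Squeeze the 1x1 up-kernel down to 2-D, then contract over the rank axis,
# leaving the spatial dimensions of the down-kernel untouched
up2d = weight_up.squeeze(3).squeeze(2)           # [out_dim, rank]
delta = alpha * torch.einsum('a b, b c h w -> a c h w', up2d, weight_down)
assert delta.shape == w0.shape                   # [out_dim, in_dim, kh, kw]
```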
Each layer in a LoRA checkpoint contributes three tensors: .lora_down.weight, .lora_up.weight, and .alpha. down and up are the two low-rank factors, corresponding to A and B respectively in W = W0 + BA, and alpha is a per-layer scaling factor stored alongside them. The merged weight delta for each layer can therefore be written as:

w = alpha * (upMatrix @ downMatrix)
Taking the currently popular LCM LoRA as an example, we can visualize its keys (partial list):
lora_unet_down_blocks_0_downsamplers_0_conv.alpha
lora_unet_down_blocks_0_downsamplers_0_conv.lora_down.weight
lora_unet_down_blocks_0_downsamplers_0_conv.lora_up.weight
lora_unet_down_blocks_0_resnets_0_conv1.alpha
lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight
lora_unet_down_blocks_0_resnets_0_conv1.lora_up.weight
lora_unet_down_blocks_0_resnets_0_conv2.alpha
lora_unet_down_blocks_0_resnets_0_conv2.lora_down.weight
lora_unet_down_blocks_0_resnets_0_conv2.lora_up.weight
lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha
lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight
lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight
lora_unet_down_blocks_0_resnets_1_conv1.alpha
lora_unet_down_blocks_0_resnets_1_conv1.lora_down.weight
lora_unet_down_blocks_0_resnets_1_conv1.lora_up.weight
lora_unet_down_blocks_0_resnets_1_conv2.alpha
lora_unet_down_blocks_0_resnets_1_conv2.lora_down.weight
lora_unet_down_blocks_0_resnets_1_conv2.lora_up.weight
lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha
lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight
lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_proj_in.alpha
lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight
lora_unet_down_blocks_1_attentions_0_proj_in.lora_up.weight
lora_unet_down_blocks_1_attentions_0_proj_out.alpha
lora_unet_down_blocks_1_attentions_0_proj_out.lora_down.weight
lora_unet_down_blocks_1_attentions_0_proj_out.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_1_proj_in.alpha
lora_unet_down_blocks_1_attentions_1_proj_in.lora_down.weight
lora_unet_down_blocks_1_attentions_1_proj_in.lora_up.weight
lora_unet_down_blocks_1_attentions_1_proj_out.alpha
lora_unet_down_blocks_1_attentions_1_proj_out.lora_down.weight
lora_unet_down_blocks_1_attentions_1_proj_out.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_downsamplers_0_conv.alpha
lora_unet_down_blocks_1_downsamplers_0_conv.lora_down.weight
lora_unet_down_blocks_1_downsamplers_0_conv.lora_up.weight
Take loading a LoRA into the UNet as an example: first obtain the UNet's original weights W0 (state_dict_unet) and the LoRA's weights BA (state_dict_lora), then update each matching layer with

W = W0 + BA

i.e. state_dict_unet[name] += state_dict_lora[name].
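The per-key merge can be sketched with plain dictionaries (toy 4x4 tensors and a single hypothetical key, standing in for the full state dicts):

```python
import torch

# Toy stand-ins for the UNet state dict and the already-converted LoRA state dict
state_dict_unet = {"down_blocks.0.resnets.0.conv1.weight": torch.ones(4, 4)}
state_dict_lora = {"down_blocks.0.resnets.0.conv1.weight": 0.5 * torch.ones(4, 4)}

# W = W0 + BA, applied per matching key
for name in state_dict_lora:
    state_dict_unet[name] += state_dict_lora[name]
```

After the loop every merged entry holds W0 + BA, and the dict can be loaded back with load_state_dict.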
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
    state_dict_unet = unet.state_dict()
    state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
    state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
    if len(state_dict_lora) > 0:
        for name in state_dict_lora:
            state_dict_unet[name] += state_dict_lora[name].to(device=device)
        unet.load_state_dict(state_dict_unet)
Here convert_state_dict is what multiplies the down and up matrices together to obtain BA:

BA = alpha * (upMatrix @ downMatrix)
Note: the LoRA layers here include not only Linear LoRA but also Conv2d LoRA.
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
    special_keys = {
        "down.blocks": "down_blocks",
        "up.blocks": "up_blocks",
        "mid.block": "mid_block",
        "proj.in": "proj_in",
        "proj.out": "proj_out",
        "transformer.blocks": "transformer_blocks",
        "to.q": "to_q",
        "to.k": "to_k",
        "to.v": "to_v",
        "to.out": "to_out",
    }
    state_dict_ = {}
    for key in state_dict:
        if ".lora_up" not in key:
            continue
        if not key.startswith(lora_prefix):
            continue
        weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
        weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
        # LCM LoRA stores a per-layer alpha; note it overrides the alpha argument
        alpha = state_dict[key.replace(".lora_up.weight", ".alpha")].to(device="cuda", dtype=torch.float16)
        print(key, "@", key.replace(".lora_up", ".lora_down"))
        print(alpha, weight_up.shape, weight_down.shape)
        if len(weight_up.shape) == 4:  # Conv2d weight
            weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
            weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
            if len(weight_up.shape) == len(weight_down.shape):  # 1x1 conv: both squeeze to 2-D
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:  # k x k conv: down keeps its 4 dims, contract over rank
                lora_weight = alpha * torch.einsum('a b, b c h w -> a c h w', weight_up, weight_down)
        else:  # Linear weight
            lora_weight = alpha * torch.mm(weight_up, weight_down)
        target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
        for special_key in special_keys:
            target_name = target_name.replace(special_key, special_keys[special_key])
        state_dict_[target_name] = lora_weight.cpu()
    return state_dict_
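The key-name mapping inside convert_state_dict can be isolated into a small standalone sketch. The helper name lora_key_to_diffusers below is made up for illustration; the special_keys table is the same one shown above:

```python
# Standalone sketch of the key-name conversion used in convert_state_dict:
# underscores become dots, the prefix is stripped, and a lookup table
# restores the module names that legitimately contain underscores.
lora_prefix = "lora_unet_"
special_keys = {
    "down.blocks": "down_blocks",
    "up.blocks": "up_blocks",
    "mid.block": "mid_block",
    "proj.in": "proj_in",
    "proj.out": "proj_out",
    "transformer.blocks": "transformer_blocks",
    "to.q": "to_q",
    "to.k": "to_k",
    "to.v": "to_v",
    "to.out": "to_out",
}

def lora_key_to_diffusers(key: str) -> str:
    # Keep only the module part before the first ".", dot-ify, drop the prefix
    target = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
    for old, new in special_keys.items():
        target = target.replace(old, new)
    return target

print(lora_key_to_diffusers("lora_unet_down_blocks_0_resnets_0_conv1.lora_up.weight"))
# -> down_blocks.0.resnets.0.conv1.weight
```

This is why the lookup table exists: a blind underscore-to-dot conversion would split names like transformer_blocks or to_q, so those are patched back afterwards.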