Stable Diffusion LoRA: Principles and Code Walkthrough

LoRA principles

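LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method. The pretrained weights of a layer are frozen and a trainable low-rank update BA is added to them, so only a small number of extra parameters has to be learned. Applied to an existing SD model, LoRA adjusts part of the model's weights in order to change or improve the generated images. (Most LoRA models adjust the Linear layers of the attention blocks in the Transformer; some also fine-tune the kernel weights of Conv2D layers.)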

Linear LoRA

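Technically, LoRA is simple. The basic flow is shown in the figure below: the blue part is the original pretrained weights, and the orange part is the weights A and B that LoRA trains.
[Figure: LoRA structure, with the pretrained weights in blue and the trainable weights A and B in orange]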

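  • Training: A and B in the figure above are the trainable weights. During training the blue part is frozen and only A and B are trained; when the model is saved, only the parameters of A and B are stored. The basic steps are:
    [Figure: basic training steps]
  • Inference: merge the weights as W = W0 + BA and use the model as usual.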
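The two stages above can be sketched in PyTorch as follows. This is a minimal illustration, not the code of any particular trainer: the class name LoRALinear, the rank, and the layer sizes are made up for the example, and initializing B to zero follows the original LoRA paper, which is an assumption here because the article does not discuss initialization.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical wrapper: frozen base Linear plus a trainable low-rank branch."""

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # W0 (blue part) stays frozen
        # A: in_features -> rank, B: rank -> out_features (orange part)
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # assumption: B starts at zero

    def forward(self, x):
        return self.base(x) + self.lora_up(self.lora_down(x))

    def merged_weight(self):
        # W = W0 + B @ A, with B = lora_up.weight and A = lora_down.weight
        return self.base.weight + self.lora_up.weight @ self.lora_down.weight


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(8, 6), rank=2)

# Training: only A and B receive gradients.
trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
optimizer = torch.optim.SGD([p for p in layer.parameters() if p.requires_grad], lr=0.1)
for _ in range(3):
    loss = layer(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Saving: keep only the LoRA parameters.
lora_state = {k: v for k, v in layer.state_dict().items() if k.startswith("lora_")}

# Inference: fold B @ A into the base weight and use a plain Linear.
x = torch.randn(5, 8)
merged = nn.Linear(8, 6)
with torch.no_grad():
    merged.weight.copy_(layer.merged_weight())
    merged.bias.copy_(layer.base.bias)
    same = torch.allclose(merged(x), layer(x), atol=1e-5)
```

Because the LoRA branch is linear, the merged layer and the two-branch layer give the same output up to floating-point error, which is why the merge can be done once before inference.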

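A plain Linear LoRA only touches the weight of a Linear layer. B (weight_up) and A (weight_down) have the same number of dimensions: len(weight_up.shape) == len(weight_down.shape) == 2. Following PyTorch's [out_features, in_features] layout for Linear weights, B has shape [out_dim, rank] and A has shape [rank, in_dim], so the product BA has shape [out_dim, in_dim], the same as the weight of the original Linear layer.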

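        # Fragment of a LoRA module's constructor; assumes `import torch.nn as nn`.
        if isinstance(module, nn.Linear):
            # Both LoRA weights of a Linear layer are 2-D matrices.
            assert len(self._up_weight.shape) == len(self._down_weight.shape) == 2

            in_dim = module.in_features
            out_dim = module.out_features
            # A (down) maps in_dim -> r, B (up) maps r -> out_dim.
            self._lora_down = nn.Linear(in_dim, self._r, bias=False)
            self._lora_up = nn.Linear(self._r, out_dim, bias=False)

        # Load the stored LoRA weights into the two layers.
        self._lora_down.weight = nn.Parameter(self._down_weight)
        self._lora_up.weight = nn.Parameter(self._up_weight)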

Conv2d LoRA

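Convolution layers and fully connected layers are different operations. Applying LoRA to a convolution layer mainly means decomposing the weight tensor of the convolution kernel. With that in mind, the approach is essentially the same as for fully connected layers.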

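For Conv2d, both B (weight_up) and A (weight_down) are stored as 4-D tensors, but only A carries the spatial kernel: A has shape [rank, in_dim, h, w] and B has shape [out_dim, rank, 1, 1]. Once the trailing 1×1 dimensions of B are squeezed away, len(weight_up.shape) no longer equals len(weight_down.shape) unless the kernel itself is 1×1. When merging, the last two (kernel) dimensions are kept unchanged and the product is taken over the rank dimension of the first two, so BA has shape [out_dim, in_dim, h, w].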

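The shapes of A and B are chosen for a reason: BA must end up with the same shape as the kernel weight of the original convolution (self.weight), so both are designed from the shape of self.weight.
[Figure: shapes of A and B derived from the shape of self.weight]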

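        # Fragment of a LoRA module's constructor; assumes `import torch.nn as nn`.
        if isinstance(module, nn.Conv2d):
            # Both LoRA weights of a Conv2d layer are stored as 4-D tensors.
            assert len(self._up_weight.shape) == len(self._down_weight.shape) == 4

            r = self._r
            in_dim = module.in_channels
            out_dim = module.out_channels
            kernel = module.kernel_size
            stride = module.stride
            padding = module.padding

            # A (down) reuses the kernel size, stride and padding of the original layer.
            self._lora_down = nn.Conv2d(in_dim, r, kernel, stride, padding, bias=False)
            # B (up) is a 1x1 convolution that maps the r channels back to out_dim.
            self._lora_up = nn.Conv2d(r, out_dim, (1, 1), (1, 1), bias=False)

        # Load the stored LoRA weights into the two layers.
        self._lora_down.weight = nn.Parameter(self._down_weight)
        self._lora_up.weight = nn.Parameter(self._up_weight)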
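To see why this layout allows the two convolutions to be folded into the original kernel, the following sketch checks the equivalence numerically. The layer sizes are arbitrary values chosen for the example, and the einsum is the compact spelling of the contraction over the rank dimension described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_dim, out_dim, r = 4, 6, 2
base = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1)

# Same layout as the snippet above: down uses the original kernel, up is 1x1.
lora_down = nn.Conv2d(in_dim, r, 3, 1, 1, bias=False)        # A: [r, in_dim, 3, 3]
lora_up = nn.Conv2d(r, out_dim, (1, 1), (1, 1), bias=False)  # B: [out_dim, r, 1, 1]

with torch.no_grad():
    # Drop the 1x1 spatial dims of B, then contract over the rank dimension.
    up_2d = lora_up.weight.squeeze(3).squeeze(2)              # [out_dim, r]
    delta = torch.einsum("ab,bchw->achw", up_2d, lora_down.weight)
    merged_weight = base.weight + delta                       # [out_dim, in_dim, 3, 3]

    x = torch.randn(1, in_dim, 16, 16)
    two_branch = base(x) + lora_up(lora_down(x))
    one_conv = F.conv2d(x, merged_weight, base.bias, stride=1, padding=1)
    same = torch.allclose(two_branch, one_conv, atol=1e-5)
```

The check passes because a k×k convolution followed by a 1×1 convolution is itself a single k×k convolution whose kernel is the product of the two weight tensors over the rank channels.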

LoRA file contents

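Each layer in a LoRA file stores three entries, with the suffixes .lora_down.weight, .lora_up.weight, and .alpha. Here down and up are the down-projection and up-projection weights, corresponding to A and B respectively, and alpha is a scalar scaling factor stored per layer. In the widely used kohya-ss training scripts alpha is a fixed hyperparameter rather than a trained weight, and the effective scale there is alpha / rank, whereas the code in this article multiplies by alpha directly. With the convention used in this article, the weight update of each layer can be written as:
$$w = \text{alpha} \cdot (\text{upMatrix} \mathbin{@} \text{downMatrix})$$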

Taking the widely used LCM LoRA as an example, the keys look like this (partial list):

lora_unet_down_blocks_0_downsamplers_0_conv.alpha
lora_unet_down_blocks_0_downsamplers_0_conv.lora_down.weight
lora_unet_down_blocks_0_downsamplers_0_conv.lora_up.weight
lora_unet_down_blocks_0_resnets_0_conv1.alpha
lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight
lora_unet_down_blocks_0_resnets_0_conv1.lora_up.weight
lora_unet_down_blocks_0_resnets_0_conv2.alpha
lora_unet_down_blocks_0_resnets_0_conv2.lora_down.weight
lora_unet_down_blocks_0_resnets_0_conv2.lora_up.weight
lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha
lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight
lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight
lora_unet_down_blocks_0_resnets_1_conv1.alpha
lora_unet_down_blocks_0_resnets_1_conv1.lora_down.weight
lora_unet_down_blocks_0_resnets_1_conv1.lora_up.weight
lora_unet_down_blocks_0_resnets_1_conv2.alpha
lora_unet_down_blocks_0_resnets_1_conv2.lora_down.weight
lora_unet_down_blocks_0_resnets_1_conv2.lora_up.weight
lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha
lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight
lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_proj_in.alpha
lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight
lora_unet_down_blocks_1_attentions_0_proj_in.lora_up.weight
lora_unet_down_blocks_1_attentions_0_proj_out.alpha
lora_unet_down_blocks_1_attentions_0_proj_out.lora_down.weight
lora_unet_down_blocks_1_attentions_0_proj_out.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_1_proj_in.alpha
lora_unet_down_blocks_1_attentions_1_proj_in.lora_down.weight
lora_unet_down_blocks_1_attentions_1_proj_in.lora_up.weight
lora_unet_down_blocks_1_attentions_1_proj_out.alpha
lora_unet_down_blocks_1_attentions_1_proj_out.lora_down.weight
lora_unet_down_blocks_1_attentions_1_proj_out.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_downsamplers_0_conv.alpha
lora_unet_down_blocks_1_downsamplers_0_conv.lora_down.weight
lora_unet_down_blocks_1_downsamplers_0_conv.lora_up.weight
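The three-entries-per-layer structure of this list can be checked with a few lines of plain Python. The helper name group_lora_keys is made up for this sketch, and the keys are a small sample copied from the list above.

```python
from collections import defaultdict

# A few keys copied from the list above.
keys = [
    "lora_unet_down_blocks_0_resnets_0_conv1.alpha",
    "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight",
    "lora_unet_down_blocks_0_resnets_0_conv1.lora_up.weight",
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.alpha",
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight",
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight",
]


def group_lora_keys(keys):
    """Group keys by module name; the part before the first '.' names the module."""
    groups = defaultdict(set)
    for key in keys:
        module, suffix = key.split(".", 1)
        groups[module].add(suffix)
    return dict(groups)


groups = group_lora_keys(keys)
expected = {"alpha", "lora_down.weight", "lora_up.weight"}
complete = all(parts == expected for parts in groups.values())
for module in sorted(groups):
    print(module, sorted(groups[module]))
```

In practice the keys would come from the loaded LoRA state dict instead of a hard-coded list.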

Loading LoRA weights

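Take a LoRA that is loaded into the UNet as an example. First obtain the UNet's original weights W0 (state_dict_unet) and the LoRA weights BA (state_dict_lora). Then update each matching layer with $W = W_0 + BA$, that is, state_dict_unet[name] += state_dict_lora[name].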

    def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
        # W0: the UNet's current weights.
        state_dict_unet = unet.state_dict()
        # Collapse each lora_up / lora_down pair into one full-size BA tensor.
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
        # Map diffusers-style key names to SDUNet's names (as the converter's name suggests).
        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                # W = W0 + BA
                state_dict_unet[name] += state_dict_lora[name].to(device=device)
            unet.load_state_dict(state_dict_unet)
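The same merge pattern can be tried on a toy model without downloading a UNet. In this sketch the two-layer nn.Sequential stands in for the UNet and the random low-rank product stands in for a converted LoRA tensor; both are assumptions made only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for the UNet: any nn.Module works the same way.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

# Stand-in for the converted LoRA state dict: one full-size delta (B @ A) per target weight.
rank = 2
state_dict_lora = {"0.weight": torch.randn(8, rank) @ torch.randn(rank, 8)}

state_dict_model = model.state_dict()
before = state_dict_model["0.weight"].clone()
untouched = state_dict_model["1.weight"].clone()

for name in state_dict_lora:
    # W = W0 + BA; the shapes must match exactly.
    assert state_dict_model[name].shape == state_dict_lora[name].shape
    state_dict_model[name] += state_dict_lora[name]
model.load_state_dict(state_dict_model)

changed = torch.allclose(model[0].weight, before + state_dict_lora["0.weight"])
kept = torch.equal(model[1].weight, untouched)
```

Only the weights named in the LoRA state dict change; every other parameter keeps its pretrained value. Note that the tensors returned by state_dict() share storage with the parameters, so the in-place += already modifies the model and the final load_state_dict call mainly makes the update explicit.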

Here convert_state_dict multiplies each up and down pair to obtain BA:
$$BA = \text{alpha} \cdot (\text{upMatrix} \mathbin{@} \text{downMatrix})$$

Note: the LoRA handled here includes not only Linear LoRA layers but also Conv2d LoRA layers.

    # Method of the LoRA loader class; assumes `import torch` at module level.
    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
        # Sub-names that contain "_" and must be restored after the "_" -> "." replacement.
        special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        }
        state_dict_ = {}
        for key in state_dict:
            # Visit each LoRA module once, through its lora_up key.
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
            # LCM LoRA stores a per-module alpha; fall back to the argument if it is absent.
            alpha_key = key.replace(".lora_up.weight", ".alpha")
            if alpha_key in state_dict:
                scale = state_dict[alpha_key].to(device=device, dtype=torch.float16)
            else:
                scale = alpha
            print(key, "@", key.replace(".lora_up", ".lora_down"))
            print(scale, weight_up.shape, weight_down.shape)
            if len(weight_up.shape) == 4:
                # Conv2d LoRA: squeeze drops the spatial dims only where they are 1x1.
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                if len(weight_up.shape) == len(weight_down.shape):  # 1x1 conv: both are now 2-D
                    lora_weight = scale * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:  # larger kernel: weight_down is still 4-D
                    lora_weight = scale * torch.einsum('a b, b c h w -> a c h w', weight_up, weight_down)
            else:
                # Linear LoRA: plain matrix product B @ A.
                lora_weight = scale * torch.mm(weight_up, weight_down)
            # "lora_unet_down_blocks_0_..." -> "down_blocks.0....weight"
            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
            for special_key in special_keys:
                target_name = target_name.replace(special_key, special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_
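The key renaming at the end of convert_state_dict is easy to miss, so here it is in isolation as plain Python. The function name lora_key_to_target is made up for this sketch; the replacement table and the string operations are the same as in the method above.

```python
SPECIAL_KEYS = {
    "down.blocks": "down_blocks",
    "up.blocks": "up_blocks",
    "mid.block": "mid_block",
    "proj.in": "proj_in",
    "proj.out": "proj_out",
    "transformer.blocks": "transformer_blocks",
    "to.q": "to_q",
    "to.k": "to_k",
    "to.v": "to_v",
    "to.out": "to_out",
}


def lora_key_to_target(key, lora_prefix="lora_unet_"):
    """Map a LoRA key to the name of the weight it modifies."""
    module = key.split(".")[0]                      # drop ".lora_up.weight"
    name = module.replace("_", ".")[len(lora_prefix):] + ".weight"
    for old, new in SPECIAL_KEYS.items():           # restore names that contain "_"
        name = name.replace(old, new)
    return name


attn = lora_key_to_target(
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_up.weight"
)
conv = lora_key_to_target("lora_unet_down_blocks_0_downsamplers_0_conv.lora_up.weight")
print(attn)  # down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight
print(conv)  # down_blocks.0.downsamplers.0.conv.weight
```

Because every "_" is first turned into ".", any sub-name that legitimately contains an underscore has to appear in the replacement table to be restored; names missing from the table stay dotted.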