赞
踩
p-tuning-v2的主要贡献是在原本的输入前添加自定义长度的layer prompts,在后续针对下游任务的训练中冻结BERT模型的所有参数而只训练这些prompt给模型带来影响的参数。
这些参数在源码中具体为:prompt对应的embedding层,两个线性层(对应key和value),一个全连接层。
- class PrefixEncoder(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.embedding = torch.nn.Embedding(seq_len, dim_ebd)
- self.trans = torch.nn.Sequential(
- torch.nn.Linear(dim_ebd, dim_ebd),
- torch.nn.Tanh(),
- torch.nn.Linear(dim_ebd, num_layer * 2 * dim_ebd)
- ).to(device)
- def forward(self, prefix):
- prefix_tokens = self.embedding(prefix)
- past_key_values = self.trans(prefix_tokens)
- return past_key_values
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。