当前位置:   article > 正文

【论文解读】GPT Understands, Too_gpt understands,too

gpt understands,too

一.论文

1.1 P-tuning

区别于之前的工作,这篇工作认为promote可以在句子中的任意位置起到作用,可以将它们插入上下文或目标中

上图中,左图是不使用任何操作,右图是选择在居首和目标前插入promote的embedding,插入promote的过程可以表示为

其中x代表一系列离散的输入令牌,y代表目标(可以理解为希望模型想要给你的回答),e()表示对应的embedding,其实就是将其参数化映射成为伪tokens,即

通过最小化这些参数

1.2 promote生成

嵌入的promote实际上可以理解为不一定离散不相互关联的,而实际上的promote其实应该是高度离散的且具有关联性的,因此作者选择使用双向长短期记忆网络(LSTM),激活函数和MLP来建模这种关系

在推理中,我们只需要输出嵌入h,并且可以丢弃LSTM头

二.代码

本质上是使用一个PromptEncoder来生成伪的embedding添加到原先的embedding中

2.1 训练

训练过程只更新promote_encoder中的参数

 2.1.1 PromptEncoder

PTuneForLAMA中实例化了PromptEncoder

 PromptEncoder本质上是一个(嵌入 + LSTM + MLP)

  1. import torch
  2. import torch.nn as nn
  3. class PromptEncoder(torch.nn.Module):
  4. def __init__(self, template, hidden_size, tokenizer, device, args):
  5. super().__init__()
  6. self.device = device
  7. self.spell_length = sum(template)
  8. self.hidden_size = hidden_size
  9. self.tokenizer = tokenizer
  10. self.args = args
  11. # ent embedding
  12. self.cloze_length = template
  13. self.cloze_mask = [
  14. [1] * self.cloze_length[0] # first cloze
  15. + [1] * self.cloze_length[1] # second cloze
  16. + [1] * self.cloze_length[2] # third cloze
  17. ]
  18. self.cloze_mask = torch.LongTensor(self.cloze_mask).bool().to(self.device)
  19. self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0])))).to(self.device)
  20. # embedding
  21. self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size).to(self.device)
  22. # LSTM
  23. self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size,
  24. hidden_size=self.hidden_size // 2,
  25. num_layers=2,
  26. dropout=self.args.lstm_dropout,
  27. bidirectional=True,
  28. batch_first=True)
  29. self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size),
  30. nn.ReLU(),
  31. nn.Linear(self.hidden_size, self.hidden_size))
  32. print("init prompt encoder...")
  33. def forward(self):
  34. input_embeds = self.embedding(self.seq_indices).unsqueeze(0)
  35. output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()
  36. return output_embeds

2.1.2 调用

在PTuneForLAMA的forward函数中调用了embed_input来实现

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/858058
推荐阅读
  

闽ICP备14008679号