当前位置:   article > 正文

【NLP | 自然语言处理】BERT Prompt文本分类(含源代码)_promptbert

promptbert

一、Prompt 介绍

Prompt 是 NLP 的一个新领域。在 Prompt 中任务的描述被嵌入到输入中,提供了一种新的方式来控制机器学习模型输出。

Prompt 是利用预训练语言模型在大量文本数据上获得的知识,来解决各种下游任务。Prompt 的优势是它可以减少或者避免对预训练模型进行微调,节省计算资源和时间,同时保持或者提高模型的性能和泛化能力。

Prompt 的方法是根据不同的任务和数据,设计合适的输入格式,包括问题,上下文,前缀,后缀,分隔符等。

二、BERT 与 Prompt 使用

Prompt 可用于提高 BERT 的句子表示能力,通过在 BERT 的输入中加入一些特定的词语作为 Prompt,引导 BERT 生成更好的句子向量:

  • 方法1:在句子的开头或结尾加入 Prompt;
  • 方法2:在句子的中间加入 Prompt。

三、Prompt 搜索方法

Prompt 的搜索方法找到最优的 Prompt,能最大化 BERT 表示能力的 Prompt。目前有三种主要的搜索方法:

  • 随机搜索:随机生成一些 Prompt,然后用它们作为 BERT 的输入,计算 BERT 的输出向量与目标向量的相似度,选择相似度最高的 Prompt 作为最优的 Prompt。
  • 贪心搜索:从一个空的 Prompt 开始,每次在 Prompt 的末尾加入一个词,然后用它作为 BERT 的输入,计算 BERT 的输出向量与目标向量的相似度,选择相似度最高的词作为 Prompt 的一部分,直到达到一个预设的长度或者相似度阈值。
  • 强化学习搜索:将 Prompt 的生成视为一个序列决策问题,使用强化学习的算法,来优化一个策略网络,根据一个奖励函数来更新网络的参数。

四、Prompt 方法局限性

BERT + Prompt 的优势是能够利用 Prompt 来引导 BERT 生成更好的句子向量,从而提高句子表示的质量和多样性。

句子相似度,文本分类,文本检索等,BERT + Prompt 可能会比原始 BERT 模型有效。文本生成的任务,如文本摘要,文本复述,文本续写等,BERT + Prompt 可能不一定比原始 BERT 模型有效。

Prompt 适合进行多任务进行建模,比如多个文本任务一起进行训练。因此在单个任务中,Prompt 并不会增加模型精度。在现有文本分类比赛中暂时还没看到 Prompt 的使用案例。

五、案例:Prompt 文本分类

输入文本:

It was [mask]. 文本输入样例
  • 1

将 [MASK] 输出接全连接层,进行分类。

5.1 步骤1:定义模型

class Bert_Model(nn.Module):
    def __init__(self,  bert_path ,config_file ):
        super(Bert_Model, self).__init__()
        self.bert = BertForMaskedLM.from_pretrained(bert_path,config=config_file)  # 加载预训练模型权重
 
 
    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids, attention_mask, token_type_ids) #masked LM 输出的是 mask的值 对应的ids的概率 ,输出 会是词表大小,里面是概率 
        logit = outputs[0]  # 池化后的输出 [bs, config.hidden_size]

        return logit 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

5.2 步骤2:定义数据集

class MyDataSet(Data.Dataset):
    def __init__(self, sen , mask , typ ,label ):
        super(MyDataSet, self).__init__()
        self.sen = torch.tensor(sen,dtype=torch.long)
        self.mask = torch.tensor(mask,dtype=torch.long)
        self.typ =torch.tensor( typ,dtype=torch.long)
        self.label = torch.tensor(label,dtype=torch.long)
 
    def __len__(self):
        return self.sen.shape[0]
 
    def __getitem__(self, idx):
        return self.sen[idx], self.mask[idx],self.typ[idx],self.label[idx]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

5.3 步骤3:对文本加入Prompt

prefix = 'It was [mask]. '

for i in range(len(x_train)):
    text_ = prefix+x_train[i][0]
    encode_dict = tokenizer.encode_plus(text_,max_length=60,padding="max_length",truncation=True)
  • 1
  • 2
  • 3
  • 4
  • 5

5.4 步骤4:模型训练与预测

optimizer = AdamW(model.parameters(),lr=2e-5,weight_decay=1e-4)  #使用Adam优化器
loss_func = nn.CrossEntropyLoss(ignore_index=-1)

for idx,(ids,att_mask,type,y) in enumerate(train_dataset):
    ids,att_mask,type,y = ids.to(device),att_mask.to(device),type.to(device),y.to(device)
    out_train = model(ids,att_mask,type)
    loss = loss_func(out_train.view(-1, tokenizer.vocab_size),y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss_sum += loss.item()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/450100
推荐阅读
相关标签
  

闽ICP备14008679号