当前位置:   article > 正文

构建基于BERT微调的多标签分类模型_bertpretrainedmodel

bertpretrainedmodel

实现继承BERT预训练模型的分类任务类

  1. import torch.nn as nn
  2. from transformers import BertPreTrainedModel, BertModel, BertConfig
  3. # 构建基于BERT的微调模型类
  4. class Model(nn.Module):
  5. def __init__(self, config):
  6. super(Model, self).__init__()
  7. # 导入参数设置对象
  8. model_config = BertConfig.from_pretrained(config.bert_path,
  9. num_labels=config.num_classes)
  10. # 导入基于bert-base-chinese的预训练模型
  11. self.bert = BertModel.from_pretrained(config.bert_path, config=model_config)
  12. # 此处用于调节是否将BERT纳入微调训练, 建议数据量+算力充足的情况下置为True
  13. # 如果设置为False, 则保持整个BERT网络参数不变, 微调仅仅针对最后的全连接层进行训练
  14. for param in self.bert.parameters():
  15. param.requires_grad = True
  16. # 全连接层的出口维度, 取决于具体的任务
  17. self.fc = nn.Linear(config.hidden_size, config.num_classes)
  18. def forward(self, x):
  19. # x[0]是输入的具体文本信息
  20. context = x[0]
  21. # x[1]是经过tokenizer处理后返回的attention mask张量
  22. # mask的尺寸size和输入相同, padding部分用0遮掩, 比如[1, 1, 1, 0, 0]
  23. mask = x[1]
  24. # x[2]是字符类型id
  25. token_type_ids = x[2]
  26. # 利用BERT模型得到输出张量, 并且只保留BertPooler的输出, 即第一个字符CLS对应的输出张量
  27. _, pooled = self.bert(context, attention_mask=mask, token_type_ids=token_type_id)
  28. # 再利用微调网络进一步提取特征, 并利用全连接层对特征张量进行维度变换
  29. out = self.fc(pooled)
  30. return out

BERT模型的参数执行微调

        展示BERT模型中的参数命名:

  1. class Model(nn.Module):
  2. def __init__(self, config):
  3. super(Model, self).__init__()
  4. self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
  5. # 将BERT中所有的参数层名字打印出来
  6. for name, param in self.bert.named_parameters():
  7. print(name)
  8. self.fc = nn.Linear(config.hidden_size, config.num_classes)

针对BERT模型中的embedding层, 让其中的参数不参与微调

  1. class Model(nn.Module):
  2. def __init__(self, config):
  3. super(Model, self).__init__()
  4. self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
  5. # 希望锁定embeddings层的参数, 不参与更新
  6. for name, param in self.bert.embeddings.named_parameters():
  7. print(name)
  8. param.requires_grad = False
  9. self.fc = nn.Linear(config.hidden_size, config.num_classes)

BERT中的全连接层, 让其中的weight参数不参与微调

  1. class Model(nn.Module):
  2. def __init__(self, config):
  3. super(Model, self).__init__()
  4. self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
  5. # 希望将全连接层中的.weight部分参数锁定
  6. for name, param in self.bert.named_parameters():
  7. if name.endswith('weight'):
  8. print(name)
  9. param.requires_grad = False
  10. self.fc = nn.Linear(config.hidden_size, config.num_classes)

BERT中指定的若干层, 让其中的参数不参与微调

  1. class Model(nn.Module):
  2. def __init__(self, config):
  3. super(Model, self).__init__()
  4. self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
  5. # 封闭BERT中的第1, 3, 5层参数, 不参与微调
  6. index_array = [1, 3, 5]
  7. for name, param in self.bert.named_parameters():
  8. new_x = name.split('.')[2]
  9. if new_x in index_array:
  10. print(name)
  11. param.requires_grad = False
  12. self.fc = nn.Linear(config.hidden_size, config.num_classes)

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号