赞
踩
实现继承BERT预训练模型的分类任务类
- import torch.nn as nn
- from transformers import BertPreTrainedModel, BertModel, BertConfig
-
- # 构建基于BERT的微调模型类
- class Model(nn.Module):
- def __init__(self, config):
- super(Model, self).__init__()
-
- # 导入参数设置对象
- model_config = BertConfig.from_pretrained(config.bert_path,
- num_labels=config.num_classes)
- # 导入基于bert-base-chinese的预训练模型
- self.bert = BertModel.from_pretrained(config.bert_path, config=model_config)
-
- # 此处用于调节是否将BERT纳入微调训练, 建议数据量+算力充足的情况下置为True
- # 如果设置为False, 则保持整个BERT网络参数不变, 微调仅仅针对最后的全连接层进行训练
- for param in self.bert.parameters():
- param.requires_grad = True
-
- # 全连接层的出口维度, 取决于具体的任务
- self.fc = nn.Linear(config.hidden_size, config.num_classes)
-
- def forward(self, x):
- # x[0]是输入的具体文本信息
- context = x[0]
- # x[1]是经过tokenizer处理后返回的attention mask张量
- # mask的尺寸size和输入相同, padding部分用0遮掩, 比如[1, 1, 1, 0, 0]
- mask = x[1]
- # x[2]是字符类型id
- token_type_ids = x[2]
- # 利用BERT模型得到输出张量, 并且只保留BertPooler的输出, 即第一个字符CLS对应的输出张量
- _, pooled = self.bert(context, attention_mask=mask, token_type_ids=token_type_id)
- # 再利用微调网络进一步提取特征, 并利用全连接层对特征张量进行维度变换
- out = self.fc(pooled)
- return out

对BERT模型的参数执行微调
展示BERT模型中的参数命名:
- class Model(nn.Module):
- def __init__(self, config):
- super(Model, self).__init__()
- self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
-
- # 将BERT中所有的参数层名字打印出来
- for name, param in self.bert.named_parameters():
- print(name)
-
- self.fc = nn.Linear(config.hidden_size, config.num_classes)
针对BERT模型中的embedding层, 让其中的参数不参与微调
- class Model(nn.Module):
- def __init__(self, config):
- super(Model, self).__init__()
- self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
-
- # 希望锁定embeddings层的参数, 不参与更新
- for name, param in self.bert.embeddings.named_parameters():
- print(name)
- param.requires_grad = False
-
- self.fc = nn.Linear(config.hidden_size, config.num_classes)
BERT中的全连接层, 让其中的weight参数不参与微调
- class Model(nn.Module):
- def __init__(self, config):
- super(Model, self).__init__()
- self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
-
- # 希望将全连接层中的.weight部分参数锁定
- for name, param in self.bert.named_parameters():
- if name.endswith('weight'):
- print(name)
- param.requires_grad = False
-
- self.fc = nn.Linear(config.hidden_size, config.num_classes)
BERT中指定的若干层, 让其中的参数不参与微调
- class Model(nn.Module):
- def __init__(self, config):
- super(Model, self).__init__()
- self.bert = BertModel.from_pretrained(config.bert_path,config=config.bert_config)
-
- # 封闭BERT中的第1, 3, 5层参数, 不参与微调
- index_array = [1, 3, 5]
- for name, param in self.bert.named_parameters():
- new_x = name.split('.')[2]
- if new_x in index_array:
- print(name)
- param.requires_grad = False
-
- self.fc = nn.Linear(config.hidden_size, config.num_classes)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。