Because the BERT model has 12 layers and roughly 100 million parameters, fine-tuning it sometimes means training only a subset of those parameters. The rest must be frozen (held fixed), which still lets us fine-tune the model while improving training efficiency. The tool for this is the `requires_grad` attribute of each `Parameter`, which is used to freeze and unfreeze parameters.
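As a minimal sketch of the mechanism (using a plain `nn.Linear` as a stand-in for a BERT sub-module), setting `requires_grad = False` on a parameter tells autograd not to compute gradients for it:

```python
import torch.nn as nn

layer = nn.Linear(768, 2)

# every Parameter exposes requires_grad; setting it to False
# excludes that tensor from gradient computation
for param in layer.parameters():
    param.requires_grad = False

print(all(not p.requires_grad for p in layer.parameters()))  # True
```

Setting the attribute back to `True` unfreezes the parameter again.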
First, let's look at what the BERT model's parameters actually are:
```
bert.embeddings.word_embeddings.weight                  torch.Size([21128, 768])
bert.embeddings.position_embeddings.weight              torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight            torch.Size([2, 768])
bert.embeddings.LayerNorm.weight                        torch.Size([768])
bert.embeddings.LayerNorm.bias                          torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight        torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias          torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight          torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias            torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight        torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias          torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight      torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias        torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight  torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias    torch.Size([768])
bert.encoder.layer.0.intermediate.dense.weight          torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias            torch.Size([3072])
bert.encoder.layer.0.output.dense.weight                torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias                  torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.weight            torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.bias              torch.Size([768])
...
(layers 1 through 10 repeat exactly the same parameter pattern as layer 0)
...
bert.encoder.layer.11.attention.self.query.weight       torch.Size([768, 768])
bert.encoder.layer.11.attention.self.query.bias         torch.Size([768])
bert.encoder.layer.11.attention.self.key.weight         torch.Size([768, 768])
bert.encoder.layer.11.attention.self.key.bias           torch.Size([768])
bert.encoder.layer.11.attention.self.value.weight       torch.Size([768, 768])
bert.encoder.layer.11.attention.self.value.bias         torch.Size([768])
bert.encoder.layer.11.attention.output.dense.weight     torch.Size([768, 768])
bert.encoder.layer.11.attention.output.dense.bias       torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.bias   torch.Size([768])
bert.encoder.layer.11.intermediate.dense.weight         torch.Size([3072, 768])
bert.encoder.layer.11.intermediate.dense.bias           torch.Size([3072])
bert.encoder.layer.11.output.dense.weight               torch.Size([768, 3072])
bert.encoder.layer.11.output.dense.bias                 torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.weight           torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.bias             torch.Size([768])
bert.pooler.dense.weight                                torch.Size([768, 768])
bert.pooler.dense.bias                                  torch.Size([768])
out.weight                                              torch.Size([2, 768])
out.bias                                                torch.Size([2])
```
For example, suppose we want to unfreeze the 11th and 12th layers (i.e. `layer.10` and `layer.11`, since layer indices start at 0) together with the `bert.pooler` and `out` parameters, and freeze everything else. How do we implement that?
PyTorch provides `model.named_parameters()` and the `requires_grad` attribute, so a simple loop over the parameters is all it takes. Concrete implementation:
```python
import torch
import torch.nn as nn
from transformers import BertModel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('pretrained_models/Chinese-BERT-wwm')
        self.out = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask=None):
        # classify from the pooled [CLS] representation
        pooled = self.bert(input_ids, attention_mask=attention_mask).pooler_output
        return self.out(pooled)


if __name__ == '__main__':
    model = Model()
    unfreeze_layers = ['layer.10', 'layer.11', 'bert.pooler', 'out.']

    for name, param in model.named_parameters():
        print(name, param.size())
    print("*" * 30)
    print('\n')

    # freeze everything, then unfreeze only the selected layers
    for name, param in model.named_parameters():
        param.requires_grad = False
        for ele in unfreeze_layers:
            if ele in name:
                param.requires_grad = True
                break

    # verify: print only the trainable parameters
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.size())

    # filter out the parameters with requires_grad = False
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=0.00001)
```
Finally, when training, be sure to add the `filter` in the optimizer to strip out the parameters with `requires_grad = False`, so that those parameters are not updated during training.
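To see that the filter does what we want, here is a small self-contained sketch. A tiny two-layer `nn.Sequential` stands in for BERT (the first layer plays the role of the frozen encoder layers, the second the unfrozen head); the names and sizes are illustrative only:

```python
import torch
import torch.nn as nn

# tiny stand-in model: model[0] plays the frozen lower layers,
# model[1] plays the unfrozen classification head
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for param in model[0].parameters():
    param.requires_grad = False  # freeze the first layer

# pass only the trainable parameters to the optimizer
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

before = model[0].weight.clone()
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# the frozen layer's weights are unchanged after the update step
print(torch.equal(before, model[0].weight))  # True
```

Freezing `requires_grad` alone already stops gradients from being computed; filtering the optimizer additionally guarantees those parameters can never be touched by the update rule (e.g. by weight decay or momentum in some optimizers).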
References:
Bert模型冻结指定参数进行训练 (Freezing specified parameters of a BERT model for training)