Using the BERT model from Hugging Face transformers to analyze and count the model's parameters
import torch
from transformers import BertModel

# Load the pre-trained Chinese BERT (hidden states / attentions exposed for later inspection)
bertModel = BertModel.from_pretrained("bert-base-chinese", output_hidden_states=True, output_attentions=True)
total = sum(p.numel() for p in bertModel.parameters())
print("total param:", total)
The output is:
total param: 102267648
The code above counts the total number of parameters in the model; the result is 102,267,648.
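Before breaking that number down by hand, it helps to group the count by top-level submodule (a small sketch reusing the bertModel loaded above; the grouping key is just the first component of each parameter name):

from collections import defaultdict

# Group parameter counts by top-level submodule name
by_module = defaultdict(int)
for name, p in bertModel.named_parameters():
    by_module[name.split(".")[0]] += p.numel()
print(dict(by_module))  # {'embeddings': 16622592, 'encoder': 85054464, 'pooler': 590592}

The sections below account for each of these three figures.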
BERT has three kinds of embeddings: word embeddings, position embeddings, and sentence (token type) embeddings.
In the bert-base-chinese model the vocabulary size is 21128, the embedding dimension is 768, and the maximum sequence length L is 512.
word embedding parameters: 21128*768
position embedding parameters: 512*768
sentence embedding parameters: 2*768
At the end of the embedding block there is a LayerNorm layer, which contributes 768 + 768 parameters: the scale α and shift β in the LayerNorm formula y = α ⊙ x̂ + β.
The embedding layer therefore contains
21128*768 + 512*768 + 2*768 + 768 + 768 = 16622592 parameters.
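As a quick sanity check, the same figure can be read straight off the embeddings submodule of the loaded model:

# Sum only the parameters of the embedding block
emb_total = sum(p.numel() for p in bertModel.embeddings.parameters())
print(emb_total)  # 16622592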
The self-attention stack has 12 layers; each layer consists of two parts, a multi-head attention block and a LayerNorm layer.
Multi-head attention contains three projection matrices Q, K, V and one output (concatenation) matrix. Q, K and V each have 768*12*64 + 768 parameters: the first 768 is the embedding dimension, 12 is the number of heads, 64 is the per-head dimension, and the trailing 768 is the bias. After the Q, K, V projections the per-head results are concatenated, which requires an additional output matrix of 768*768 + 768.
LayerNorm parameters: 768 + 768
One self-attention layer therefore contains:
(768*12*64 + 768)*3 + 768*768 + 768 + 768 + 768 = 2363904
Across all 12 layers: 2363904 * 12 = 28366848
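This can again be checked against the loaded model; in transformers, encoder.layer[0].attention covers the Q/K/V projections, the output matrix, and the attention LayerNorm:

# Parameters of one attention block: Q/K/V + output dense + LayerNorm
attn = bertModel.encoder.layer[0].attention
attn_total = sum(p.numel() for p in attn.parameters())
print(attn_total)       # 2363904
print(attn_total * 12)  # 28366848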
The feed-forward stack also has 12 layers; each layer consists of two parts, a feed-forward network and a LayerNorm layer.
The feed-forward network has the form W_2(W_1X + b_1) + b_2, i.e. two linear layers: W_1 maps 768 -> 768*4 = 3072, and W_2 maps 3072 -> 768. W_1 has 768*768*4 + 768*4 parameters, and W_2 has 768*4*768 + 768 parameters.
LayerNorm parameters: 768 + 768
One feed-forward layer therefore contains:
(768*768*4 + 768*4) + (768*4*768 + 768) + 768 + 768 = 4723968
Across all 12 layers: 4723968 * 12 = 56687616
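In transformers these two linear layers live in the intermediate and output submodules of each encoder layer, so the figure can be verified the same way:

# Parameters of one feed-forward block: intermediate dense (768->3072),
# output dense (3072->768), and the trailing LayerNorm
layer0 = bertModel.encoder.layer[0]
ffn_total = sum(p.numel() for m in (layer0.intermediate, layer0.output) for p in m.parameters())
print(ffn_total)        # 4723968
print(ffn_total * 12)   # 56687616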
To summarize:
embedding: 16622592
self-attention: 28366848
feedforward: 56687616
After the last encoder layer there is also a pooler layer with a 768*768 weight matrix and a 768 bias, i.e. 768*768 + 768 parameters. It extracts the vector of the first special token [CLS] in each training example, which is then used to compute the loss of BERT's NSP (next sentence prediction) task.
total = 16622592 + 28366848 + 56687616 + 768*768 + 768 = 102267648
This matches the count reported by PyTorch.
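The whole total can also be written as a closed-form function of the configuration, which makes each hyperparameter's contribution explicit (a minimal sketch; V, L, H, I are my own shorthand for vocab size, max positions, hidden size, and intermediate size):

def bert_param_count(V=21128, L=512, H=768, I=3072, layers=12):
    emb = V*H + L*H + 2*H + 2*H             # three embeddings + LayerNorm
    attn = 3*(H*H + H) + (H*H + H) + 2*H    # Q,K,V + output matrix + LayerNorm
    ffn = (H*I + I) + (I*H + H) + 2*H       # two linear layers + LayerNorm
    pooler = H*H + H                        # pooler dense
    return emb + layers*(attn + ffn) + pooler

print(bert_param_count())  # 102267648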
If anything above is unclear, you can inspect the shape of every parameter in the BERT model:
for name, param in bertModel.named_parameters():
    print(name, param.shape)

Output:

embeddings.word_embeddings.weight torch.Size([21128, 768])
embeddings.position_embeddings.weight torch.Size([512, 768])
embeddings.token_type_embeddings.weight torch.Size([2, 768])
embeddings.LayerNorm.weight torch.Size([768])
embeddings.LayerNorm.bias torch.Size([768])
encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias torch.Size([768])
encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias torch.Size([768])
encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias torch.Size([768])
encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.0.intermediate.dense.bias torch.Size([3072])
encoder.layer.0.output.dense.weight torch.Size([768, 3072])
encoder.layer.0.output.dense.bias torch.Size([768])
encoder.layer.0.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.output.LayerNorm.bias torch.Size([768])
(encoder.layer.1 through encoder.layer.11 repeat exactly the same shapes)
pooler.dense.weight torch.Size([768, 768])
pooler.dense.bias torch.Size([768])