当前位置:   article > 正文

BERT参数量计算_bert模型参数量怎么算

bert模型参数量怎么算

在这里插入图片描述

模型概况:

BERT-Base: L = 12 , H = 768 , A = 12 L = 12, H = 768, A = 12 L=12,H=768,A=12

参数计算:
PART 01:input embedding
  • Token Embedding: 30522 × 768 30522 \times 768 30522×768
  • Position Embedding: (max_length) 512 × 768 512 \times 768 512×768
  • Segment Embedding: 2 × 768 2 \times 768 2×768
  • 总参数量 ( 30522 + 512 + 2 ) × 768 = 23 , 835 , 648 (30522 + 512 + 2) \times 768 = 23,835,648 (30522+512+2)×768=23,835,648
PART 02:Multi-Head Attention
  • 基本信息

    • 12个head
    • 生成 Q K V 3个向量
  • 单个 head 的参数量

    • 768 × 768 / 12 × 3 768 \times 768/12 \times 3 768×768/12×3
      在这里插入图片描述
  • 多头拼接的参数

    • 12 × 768 / 12 × 768 12 \times 768/12 \times 768 12×768/12×768
  • 总参数量 ( 768 × 768 / 12 × 3 ) × 12 + 12 × 768 / 12 × 768 = 2 , 359 , 296 (768 \times 768/12 \times 3)\times {\color{red}12} + 12 \times 768/12 \times 768 = 2,359,296 (768×768/12×3)×12+12×768/12×768=2,359,296

PART 03:Add & Norm (第一次)
  • 基本信息
    • 针对多头注意力的输出,这里使用的是 L a y e r N o r m ( x + S u b l a y e r ( x ) ) LayerNorm(x + Sublayer(x)) LayerNorm(x+Sublayer(x))在这里插入图片描述

      进行层标准化需要计算同一层隐层单元中的如上两个参数。

  • 总参数量 768 × 2 = 1 , 536 768 \times 2 = 1,536 768×2=1,536
PART 04:Feed Forward
  • 公式 F F N ( x ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 FFN(x)=max(0, xW_{1}+b_{1})W_{2}+b_{2} FFN(x)=max(0,xW1+b1)W2+b2
  • 论文指明,feed-forward/filter size 设置为 4H(即 4 × 768 = 3072 4 \times 768 = 3072 4×768=3072
  • 第一层参数: 768 × 3072 + 3072 768 \times 3072 + 3072 768×3072+3072
  • 第二层参数: 3072 × 768 + 768 3072 \times 768 + 768 3072×768+768
  • 总参数量 ( 768 × 3072 + 3072 ) + ( 3072 × 768 + 768 ) = 4 , 722 , 432 (768 \times 3072 + 3072)+ (3072 \times 768 + 768)= 4,722,432 (768×3072+3072)+(3072×768+768)=4,722,432
PART 05:Add & Norm (第二次)
  • 与第一次相同,参数量为 768 × 2 = 1 , 536 768 \times 2 = 1,536 768×2=1,536
计算结果:
  • 由于 PART 02-05 在 BERT-Base 模型中共有 12 个 Encoder
  • 因此,参数总量为:
    23 , 835 , 648 + ( 2 , 359 , 296 + 1 , 536 + 4 , 722 , 432 + 1 , 536 ) × 12 = 108 , 853 , 248 23,835,648 + (2,359,296 + 1,536 + 4,722,432 + 1,536) \times 12 = 108,853,248 23,835,648+(2,359,296+1,536+4,722,432+1,536)×12=108,853,248
参考论文

Transformer: Attention is all you need
Layer Normalization: Layer Normalization
BERT: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/826132
推荐阅读
相关标签
  

闽ICP备14008679号