当前位置:   article > 正文

bert模型取last_hidden_state[:, 0]_bert输出lasthidd

bert输出lasthidd

问题

看代码时,第一行跑完target_pred的last_hidden_state的shape为(32,100,768),第二行跑完target_pred的shape为(32,768) ,不理解。
注:32为batch_size,100为max_length,768为hidden_size。

target_pred = model(target_input_ids, target_attention_mask, target_token_type_ids, output_hidden_states=True, return_dict=True)
target_pred = target_pred.last_hidden_state[:, 0]
  • 1
  • 2

解决

涉及的知识点是三维数组切片,测试代码如下:

a=np.arange(0, 24, 1).reshape(3,2,4)
print(a)
# 输出
[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]]
  
b=a[:,0]
print(b)
print(b.shape)
# 输出
[[ 0  1  2  3]
 [ 8  9 10 11]
 [16 17 18 19]]
(3, 4)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

以上面的例子类推,last_hidden_state[:, 0]的含义:每个batch有32个样本,每个样本的shape为(100,768)的二维数组,对每个样本取第一行,即(1,768),因此last_hidden_state[:, 0]的shape为(32,768),相当于降维。并且调用了cls。

扩展阅读

1.【bert】: 在eval时pooler、last_hiddent_state、cls的区分
2.索引与切片,玩转数组之七十二变

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/780762
推荐阅读
相关标签
  

闽ICP备14008679号