赞
踩
看代码时,第一行跑完target_pred的last_hidden_state的shape为(32,100,768),第二行跑完target_pred的shape为(32,768) ,不理解。
注:32为batch_size,100为max_length,768为hidden_size。
target_pred = model(target_input_ids, target_attention_mask, target_token_type_ids, output_hidden_states=True, return_dict=True)
target_pred = target_pred.last_hidden_state[:, 0]
涉及的知识点是三维数组切片,测试代码如下:
a=np.arange(0, 24, 1).reshape(3,2,4) print(a) # 输出 [[[ 0 1 2 3] [ 4 5 6 7]] [[ 8 9 10 11] [12 13 14 15]] [[16 17 18 19] [20 21 22 23]]] b=a[:,0] print(b) print(b.shape) # 输出 [[ 0 1 2 3] [ 8 9 10 11] [16 17 18 19]] (3, 4)
以上面的例子类推,last_hidden_state[:, 0]的含义:每个batch有32个样本,每个样本的shape为(100,768)的二维数组,对每个样本取第一行,即(1,768),因此last_hidden_state[:, 0]的shape为(32,768),相当于降维。并且调用了cls。
1.【bert】: 在eval时pooler、last_hiddent_state、cls的区分
2.索引与切片,玩转数组之七十二变
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。