bert模型取last_hidden_state[:, 0]

问题

看代码时,第一行跑完target_pred的last_hidden_state的shape为(32,100,768),第二行跑完target_pred的shape为(32,768) ,不理解。
注:32为batch_size,100为max_length,768为hidden_size。

target_pred = model(target_input_ids, target_attention_mask, target_token_type_ids, output_hidden_states=True, return_dict=True)
target_pred = target_pred.last_hidden_state[:, 0]

解决

涉及的知识点是三维数组切片,测试代码如下:

a=np.arange(0, 24, 1).reshape(3,2,4)
print(a)
# 输出
[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]]
  
b=a[:,0]
print(b)
print(b.shape)
# 输出
[[ 0  1  2  3]
 [ 8  9 10 11]
 [16 17 18 19]]
(3, 4)

以上面的例子类推,last_hidden_state[:, 0]的含义:每个batch有32个样本,每个样本的shape为(100,768)的二维数组,对每个样本取第一行,即(1,768),因此last_hidden_state[:, 0]的shape为(32,768),相当于降维。并且调用了cls。

扩展阅读

1.【bert】: 在eval时pooler、last_hiddent_state、cls的区分
2.索引与切片,玩转数组之七十二变

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>