bert模型取last_hidden_state[:, 0]
问题
看代码时,第一行跑完target_pred的last_hidden_state的shape为(32,100,768),第二行跑完target_pred的shape为(32,768) ,不理解。
注:32为batch_size,100为max_length,768为hidden_size。
target_pred = model(target_input_ids, target_attention_mask, target_token_type_ids, output_hidden_states=True, return_dict=True)
target_pred = target_pred.last_hidden_state[:, 0]
解决
涉及的知识点是三维数组切片,测试代码如下:
a=np.arange(0, 24, 1).reshape(3,2,4)
print(a)
# 输出
[[[ 0 1 2 3]
[ 4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]
[[16 17 18 19]
[20 21 22 23]]]
b=a[:,0]
print(b)
print(b.shape)
# 输出
[[ 0 1 2 3]
[ 8 9 10 11]
[16 17 18 19]]
(3, 4)
以上面的例子类推,last_hidden_state[:, 0]的含义:每个batch有32个样本,每个样本的shape为(100,768)的二维数组,对每个样本取第一行,即(1,768),因此last_hidden_state[:, 0]的shape为(32,768),相当于降维。并且调用了cls。
扩展阅读
1.【bert】: 在eval时pooler、last_hiddent_state、cls的区分
2.索引与切片,玩转数组之七十二变
本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
二维码