tensorflow set contain

如果知道tensor的长度,比较简单

import tensorflow as tf

one_vector = tf.constant([0,111,222,333,0])
tmp_list = []
for tmp_index in range(0, 3):
    tmp_list.append(tf.cast(tf.math.equal(one_vector[tmp_index:tmp_index+3], 
    tf.constant([111,222,333])),tf.int32))
    
total = tf.reduce_sum(tmp_list)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

print(sess.run(tmp_list))
print(sess.run(total))

print结果:
[array([0, 0, 0], dtype=int32), array([1, 1, 1], dtype=int32), array([0, 0, 0], dtype=int32)]
3

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇

)">
下一篇>>