TensorFlow: check whether a tensor contains a given subsequence
如果已知 tensor 的长度，可以用滑动窗口逐段与目标序列比较，实现比较简单：
import tensorflow as tf
# Sliding-window check for whether one_vector contains the subsequence
# `pattern`. Every window of len(pattern) consecutive elements is compared
# element-wise with `pattern`; tmp_list collects one int32 match mask per
# window, and total sums all of those masks.
#
# This uses the TF 1.x graph-mode API (tf.Session).
# NOTE(review): under TF 2.x, tf.Session does not exist at top level; port
# via tf.compat.v1 (with eager execution disabled) or evaluate the tensors
# eagerly if this needs to run on TF 2.
one_vector = tf.constant([0, 111, 222, 333, 0])
pattern = tf.constant([111, 222, 333])  # built once, not once per window

# Number of windows = len(one_vector) - len(pattern) + 1 (here 5 - 3 + 1 = 3).
# Derived from the static shapes instead of a hard-coded range(0, 3) so the
# loop stays correct if either constant is edited.
pattern_len = int(pattern.shape[0])
num_windows = int(one_vector.shape[0]) - pattern_len + 1

tmp_list = [
    tf.cast(tf.math.equal(one_vector[i:i + pattern_len], pattern), tf.int32)
    for i in range(num_windows)
]

# Total number of element-wise matches over all windows.
# Caveat: total == pattern_len does NOT prove containment, because partial
# matches in different windows add up. For example, [111, 222, 0, 0, 333]
# also yields 3. An exact test needs one window that matches completely:
#   tf.reduce_any(tf.reduce_all(tf.equal(tf.stack(tmp_list), 1), axis=1))
total = tf.reduce_sum(tmp_list)

# The graph holds only constants, so no variable initializer is needed.
# The `with` block guarantees the session is closed after use.
with tf.Session() as sess:
    print(sess.run(tmp_list))
    print(sess.run(total))
输出结果：
[array([0, 0, 0], dtype=int32), array([1, 1, 1], dtype=int32), array([0, 0, 0], dtype=int32)]
3