# 在用 PyTorch 复现 TensorFlow（1.5）模型的过程中，遇到了一些 TensorFlow 方面的问题，在此整理记录。

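## 权重初始化

下面的代码在第二个 `tf.Session` 中执行 `R_low.eval()` 时会报错（`FailedPreconditionError: Attempting to use uninitialized value ...`），原因是 `DecomNet` 中 `tf.layers.conv2d` 创建的权重变量从未被初始化。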

```python
import tensorflow as tf

# Random input tensor with a fixed op-level seed; values are uniform in
# [0.0, 4.0). DecomNet treats the last axis as channels, so this reads as
# 3 images of 2x3 pixels with 4 channels.
A = tf.random_uniform(shape=(3, 2, 3, 4), minval=0.0, maxval=4.0, seed=0)
# All-ones tensor of the same shape; not used in this snippet.
B = tf.ones(shape=(3, 2, 3, 4))

def concat(layers):
    """Concatenate the tensors in ``layers`` along axis 3.

    For the 4-D channels-last tensors used in this snippet, axis 3 is the
    channel axis.
    """
    channel_axis = 3
    return tf.concat(values=layers, axis=channel_axis)

def DecomNet(input_im, layer_num, channel=64, kernel_size=3):
    """Run the decomposition network and return two sigmoid maps (R, L).

    The per-pixel maximum over channels is prepended to ``input_im`` as an
    extra channel. The result then goes through:

    * one linear conv layer with kernel size ``kernel_size * 3``;
    * ``layer_num`` ReLU conv layers;
    * a 4-channel linear reconstruction conv.

    Args:
        input_im: 4-D channels-last tensor (axis 3 is treated as channels).
        layer_num: number of intermediate ReLU conv layers.
        channel: number of filters in the hidden conv layers.
        kernel_size: kernel size of the hidden and reconstruction layers.

    Returns:
        ``(R, L)``: sigmoid of output channels 0-2 (R) and channel 3 (L).
        Presumably reflectance / illumination, as in Retinex-style
        decomposition; confirm against the original model.

    Variables are created under the ``DecomNet`` scope with
    ``tf.AUTO_REUSE``, so repeated calls share the same weights. Those
    variables must be initialized, e.g. via
    ``tf.global_variables_initializer()``, before the outputs are evaluated.
    """
    channel_max = tf.reduce_max(input_im, axis=3, keepdims=True)
    features = concat([channel_max, input_im])

    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        features = tf.layers.conv2d(
            features, channel, kernel_size * 3,
            padding='same', activation=None,
            name="shallow_feature_extraction",
        )
        for layer_idx in range(layer_num):
            features = tf.layers.conv2d(
                features, channel, kernel_size,
                padding='same', activation=tf.nn.relu,
                name='activated_layer_%d' % layer_idx,
            )
        features = tf.layers.conv2d(
            features, 4, kernel_size,
            padding='same', activation=None,
            name='recon_layer',
        )

    # Split the 4 output channels: first 3 -> R, last 1 -> L.
    R = tf.sigmoid(features[:, :, :, 0:3])
    L = tf.sigmoid(features[:, :, :, 3:4])
    return R, L

# Session 1: evaluate the input tensor A. A involves no variables, so this
# runs without any initialization.
with tf.Session():
    print(A.eval())
# init = tf.global_variables_initializer()
# NOTE: even if enabled, the line above would not help. It runs before
# DecomNet is called, and global_variables_initializer() only covers the
# variables that already exist in the graph when it is created.
with tf.Session() as sess:

    # Building DecomNet adds its conv2d kernel/bias variables to the graph.
    [R_low, I_low] = DecomNet(A, layer_num=5)
    # sess.run(tf.global_variables_initializer())
    # The initializer above is commented out. As a result, this eval fails
    # with tf.errors.FailedPreconditionError ("Attempting to use
    # uninitialized value ...") because the DecomNet variables were never
    # initialized. This is the problem the snippet demonstrates.
    print(R_low.eval())

```

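## 解决方法

在调用 `DecomNet` 构建好网络之后、对输出求值之前，在会话中运行 `tf.global_variables_initializer()` 初始化所有全局变量即可。注意这个初始化 op 必须在变量创建之后再创建，因为它只覆盖创建时图中已经存在的变量。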

```python
import tensorflow as tf

# Random input tensor with a fixed op-level seed; values are uniform in
# [0.0, 4.0). DecomNet treats the last axis as channels, so this reads as
# 3 images of 2x3 pixels with 4 channels.
A = tf.random_uniform(shape=(3, 2, 3, 4), minval=0.0, maxval=4.0, seed=0)
# All-ones tensor of the same shape; not used in this snippet.
B = tf.ones(shape=(3, 2, 3, 4))

def concat(layers):
    """Concatenate the tensors in ``layers`` along axis 3.

    For the 4-D channels-last tensors used in this snippet, axis 3 is the
    channel axis.
    """
    channel_axis = 3
    return tf.concat(values=layers, axis=channel_axis)

def DecomNet(input_im, layer_num, channel=64, kernel_size=3):
    """Run the decomposition network and return two sigmoid maps (R, L).

    The per-pixel maximum over channels is prepended to ``input_im`` as an
    extra channel. The result then goes through:

    * one linear conv layer with kernel size ``kernel_size * 3``;
    * ``layer_num`` ReLU conv layers;
    * a 4-channel linear reconstruction conv.

    Args:
        input_im: 4-D channels-last tensor (axis 3 is treated as channels).
        layer_num: number of intermediate ReLU conv layers.
        channel: number of filters in the hidden conv layers.
        kernel_size: kernel size of the hidden and reconstruction layers.

    Returns:
        ``(R, L)``: sigmoid of output channels 0-2 (R) and channel 3 (L).
        Presumably reflectance / illumination, as in Retinex-style
        decomposition; confirm against the original model.

    Variables are created under the ``DecomNet`` scope with
    ``tf.AUTO_REUSE``, so repeated calls share the same weights. Those
    variables must be initialized, e.g. via
    ``tf.global_variables_initializer()``, before the outputs are evaluated.
    """
    channel_max = tf.reduce_max(input_im, axis=3, keepdims=True)
    features = concat([channel_max, input_im])

    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        features = tf.layers.conv2d(
            features, channel, kernel_size * 3,
            padding='same', activation=None,
            name="shallow_feature_extraction",
        )
        for layer_idx in range(layer_num):
            features = tf.layers.conv2d(
                features, channel, kernel_size,
                padding='same', activation=tf.nn.relu,
                name='activated_layer_%d' % layer_idx,
            )
        features = tf.layers.conv2d(
            features, 4, kernel_size,
            padding='same', activation=None,
            name='recon_layer',
        )

    # Split the 4 output channels: first 3 -> R, last 1 -> L.
    R = tf.sigmoid(features[:, :, :, 0:3])
    L = tf.sigmoid(features[:, :, :, 3:4])
    return R, L

# Session 1: evaluate the input tensor A. A involves no variables, so this
# runs without any initialization.
with tf.Session():
    print(A.eval())
# init = tf.global_variables_initializer()
# NOTE: the line above would be too early. It runs before DecomNet creates
# its variables, and global_variables_initializer() only covers variables
# that already exist in the graph when it is created.
with tf.Session() as sess:

    # Build the DecomNet graph first; this creates its conv2d variables.
    [R_low, I_low] = DecomNet(A, layer_num=5)
    sess.run(tf.global_variables_initializer()) # Key line that avoids the error: initialize all global variables (now including DecomNet's) in this session before evaluating.
    # `sess` is the default session inside this `with`, so .eval() runs in it
    # and sees the initialized variables.
    print(R_low.eval())

```

THE END