# 莫烦TensorFlow学习笔记（10-12）——构建简单的神经网络及其可视化

## 二、代码详解

### 1、定义添加神经层的函数 add_layer

``````def add_layer(inputs,in_size,out_size,activation_function=None):
Weights=tf.Variable(tf.random_normal([in_size,out_size]))
biases=tf.Variable(tf.zeros([1,out_size])+0.1)
Wx_plus_b=tf.matmul(inputs,Weights)+biases
if activation_function is None:
outputs=Wx_plus_b
else:
outputs=activation_function(Wx_plus_b)

return outputs``````

### 2、生成数据并搭建神经网络

``````x_data=np.linspace(-1,1,300)[:,np.newaxis]
noise=np.random.normal(0,0.05,x_data.shape)
y_data=np.square(x_data)-0.5+noise

xs=tf.placeholder(tf.float32,[None,1])
ys=tf.placeholder(tf.float32,[None,1])
#activation function is a relu function

loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1]))
#loss function is based on MSE(mean square error)

init=tf.global_variables_initializer() #initialize all variables
sess=tf.Session()
sess.run(init) #very important ``````

### 3、神经网络训练可视化

``````%matplotlib #在导入库的时候声明

fig=plt.figure()
ax.scatter(x_data,y_data)
plt.ion()
plt.show()

for i in range(2000):
sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
if i%50==0:
try:
plt.pause(0.5)
except Exception:
pass

try:
ax.lines.remove(lines[0])
plt.show()
except Exception as e:
pass

prediction_value=sess.run(prediction,feed_dict={xs:x_data})
lines=ax.plot(x_data,prediction_value,'r-',lw=10)
``````

THE END