1.引言

2.神经网络搭建

2.1 准备工作

2.2 搭建网络

2.3 训练网络

3.效果

4. 完整代码

# 2.神经网络搭建

## 2.1 准备工作

``````import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())``````

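既然是拟合，我们当然需要一些数据啦。我在区间 $[-5, 5]$ 内选取了 100 个等间距点作为输入 $x$，令目标值 $y = x^3$ 再加上一点 $[0, 0.2)$ 范围内的均匀随机噪声，得到一组近似三次函数的数据。`torch.unsqueeze(..., dim=1)` 把 $x$ 的形状从 `(100,)` 变为 `(100, 1)`，即每行一个样本、每个样本一个特征，这正是 `torch.nn.Linear` 所期望的输入格式。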

## 2.2 搭建网络

``````class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)

def forward(self, x):
x = F.relu(self.hidden(x))
return self.predict(x)

net = Net(1, 20, 1)
print(net)
loss_func = torch.nn.MSELoss()``````

## 2.3 训练网络

``````for t in range(2000):
prediction = net(x)
loss = loss_func(prediction, y)
loss.backward()
optimizer.step()``````

# 3.效果

``````for t in range(2000):
prediction = net(x)
loss = loss_func(prediction, y)
loss.backward()
optimizer.step()
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()``````

# 4. 完整代码

``````import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)

def forward(self, x):
x = F.relu(self.hidden(x))
return self.predict(x)
net = Net(1, 20, 1)
print(net)
loss_func = torch.nn.MSELoss()
plt.ion()
for t in range(2000):
prediction = net(x)
loss = loss_func(prediction, y)
loss.backward()
optimizer.step()
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
``````

THE END