P7:神经网络的基本骨架 — nn.Module的使用
1、创建神经网络
在创建的时候,我们需要继承torch.nn.Module这个类,之后我们需要创建两个函数:
(1)__init__()函数;
(2)forward()函数。
代码如下:
import torch
from torch import nn
class Tudui(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, input):
output = input + 1
return output
2、调用神经网络
tudui = Tudui()
x = torch.tensor(1.0)
output = tudui(x)
print(output)