Tensor

Tensors and Functions are interconnected and together build an acyclic graph that encodes the complete history of the computation. Each tensor has a `grad_fn` attribute referencing the Function that created it (except for tensors created directly by the user, whose `grad_fn` is `None`).
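A minimal sketch of this distinction (the example tensors are my own, not from the original text):

```python
import torch

# A leaf tensor created directly by the user has no grad_fn
a = torch.ones(2, requires_grad=True)
print(a.grad_fn)                 # None

# A tensor produced by an operation references the Function that created it
b = a + 1
print(type(b.grad_fn).__name__)  # AddBackward0
```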

```
import torch

x = torch.ones(2, 3, requires_grad=True)  # reconstructed: the outputs below imply a 2x3 tensor of ones
print(x)
```

```
tensor([[1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
```

```
y = x + 2
print(y)
```

```
tensor([[3., 3., 3.],
        [3., 3., 3.]], grad_fn=<AddBackward0>)
```

```
z = y * 2
out = z.mean()

print(z, out)
```

```
tensor([[6., 6., 6.],
        [6., 6., 6.]], grad_fn=<MulBackward0>) tensor(6., grad_fn=<MeanBackward0>)
```

`.requires_grad_()` changes an existing tensor's `requires_grad` flag in place. If not specified, the flag defaults to `False`.

```
a = torch.randn(2, 2)
a = ((a * 2) / (a - 2))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a * a).sum()
print(b.grad_fn)
```

```
False
True
<SumBackward0 object at 0x0000021F46A2AA00>
```

Gradients

```
out.backward()
```

```
print(x.grad)
```

```
tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]])
```
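The value 0.3333 follows directly: `out = mean(2 * (x + 2))` over six elements, so each partial derivative is 2/6 ≈ 0.3333. A quick self-contained check (assuming, as the outputs above suggest, that `x` is a 2×3 tensor of ones):

```python
import torch

# out = mean(2 * (x + 2)); with six elements, d(out)/dx_ij = 2/6 = 1/3
x = torch.ones(2, 3, requires_grad=True)
out = ((x + 2) * 2).mean()
out.backward()
print(x.grad)  # every entry is 1/3 ≈ 0.3333
```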

`torch.autograd` is an engine for computing vector-Jacobian products. This property makes it very convenient to feed external gradients into a model with a non-scalar output.
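A minimal sketch of a vector-Jacobian product (the example values are my own): for `y = 2 * x` the Jacobian is `2 * I`, so backpropagating a vector `v` yields `x.grad = 2 * v`.

```python
import torch

# For y = 2x, the Jacobian J is 2 * I, so v^T J = 2v
x = torch.randn(3, requires_grad=True)
y = x * 2
v = torch.tensor([1.0, 1.0, 1.0])
y.backward(v)      # computes the vector-Jacobian product v^T J
print(x.grad)      # tensor([2., 2., 2.])
```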

```
x = torch.randn(2, requires_grad=True)
print(x)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2
print(y)
```

```
tensor([-1.2877, -0.5659], requires_grad=True)
```

```
v = torch.tensor([1, 1], dtype=torch.float)
y.backward(v)

print(x.grad)
```

```
tensor([1024., 1024.])
```

```
print(x)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)
```

```
tensor([-1.2877, -0.5659], requires_grad=True)
True
True
False
```

Saved Tensors

In the forward pass, autograd saves the tensors it will need to compute gradients in the backward pass; they can be inspected through the `_saved_*` attributes of a tensor's `grad_fn`.

For `y = x.pow(2)`, the saved input is the same object as `x`:

```
import torch

x = torch.randn(2, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))
print(x is y.grad_fn._saved_self)
```

```
True
True
```

For `y = x.exp()`, autograd saves the output; unpacking a saved output returns a tensor with the same content but a different Python object:

```
x = torch.randn(2, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))
print(y is y.grad_fn._saved_result)
```

```
True
False
```

Locally disabling gradient computation

There are several mechanisms available in Python to locally disable gradient computation:

Evaluation mode (`nn.Module.eval()`) is not a mechanism for disabling gradient computation, but it is often confused with one because of its name.
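The following sketch illustrates the distinction (the `nn.Linear` module and its sizes are my own example): `eval()` only changes the behavior of layers such as Dropout and BatchNorm, while `torch.no_grad()` is what actually disables gradient tracking.

```python
import torch
from torch import nn

m = nn.Linear(4, 2)          # example module; sizes are arbitrary
m.eval()                     # evaluation mode: affects layers like Dropout/BatchNorm only
x = torch.randn(1, 4)
print(m(x).requires_grad)    # True: autograd still tracks the computation

with torch.no_grad():        # this is what actually disables gradient tracking
    print(m(x).requires_grad)  # False
```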

`requires_grad` is a flag, defaulting to `False` unless wrapped in an `nn.Parameter`, that allows fine-grained exclusion of subgraphs from gradient computation. It takes effect in both the forward and the backward pass.
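Both behaviors can be sketched in a few lines (the tensors here are my own example values):

```python
import torch
from torch import nn

t = torch.zeros(3)
p = nn.Parameter(torch.zeros(3))
print(t.requires_grad)   # False: plain tensors default to requires_grad=False
print(p.requires_grad)   # True: nn.Parameter defaults to requires_grad=True

# A leaf that does not require grad is excluded from the backward pass
frozen = torch.randn(3)                    # requires_grad=False
live = torch.randn(3, requires_grad=True)
(frozen * live).sum().backward()
print(live.grad is None)     # False: a gradient reached `live`
print(frozen.grad is None)   # True: `frozen` was excluded
```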

Autograd is also thread-safe: backward passes may be driven concurrently from several Python threads. The example below is reconstructed to be runnable (the `threading` scaffolding and the 5×5 input tensor are filled in; the original omits them):

```
import threading

import torch

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)  # assumed input; the original omits it
    # forward
    y = (x + 5) * (x + 5) * 0.1
    # backward
    y.sum().backward()
    # an optimizer update would go here

# Write thread code to drive train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn)
    p.start()
    threads.append(p)

for p in threads:
    p.join()
print(threads)
```

```
[<Thread(Thread-5, stopped 11328)>, <Thread(Thread-6, stopped 13548)>, <Thread(Thread-7, stopped 14440)>, <Thread(Thread-8, stopped 12720)>, <Thread(Thread-9, stopped 2416)>, <Thread(Thread-10, stopped 3820)>, <Thread(Thread-11, stopped 10688)>, <Thread(Thread-12, stopped 4620)>, <Thread(Thread-13, stopped 2200)>, <Thread(Thread-14, stopped 14916)>]
```