Pytorch 中的 with torch.no_grad(): 详解
1 python 中的 with 用法
with 语句适用于对资源进行访问的场合, 它会自动进行类似于 '清理内存' 的操作, 进行资源的释放, 类似于 file.close(), 以防止在程序运行中, 由于忘记关闭文件而导致程序占用的内存越来越大, 程序的运行时间越来越长
with 用法:
with ...:
...
2 torch.no_grad()
2.1 背景知识
-
在pytorch中,tensor有一个requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。tensor的requires_grad的属性默认为False,若一个 tensor自己创建的tensor requires_grad被设置为True,那么所有依赖它的节点requires_grad都为True(即使其他相依赖的tensor的requires_grad = False),当requires_grad设置为False时,反向传播时就不会自动求导了,因此大大节约了显存或者说内存
-
除此之外, 当进行神经网络的参数更新的过程中, 我们不希望将有些的参数进行梯度下降进行更新, 也可以使用 torch.no_grad()
-
使用 with torch.no_grad(): 条件下计算的新tensor, 即使以前的 tensor requires_grad() = True, 计算的结果也是 requires_grad() = False
import torch
x = torch.randn(2, 2, requires_grad = True)
y = torch.randn(2, 2, requires_grad = True)
with torch.no_grad():
z = x + y
print(z.requires_grad)
print(z.grad)
print(z.grad_fn)
>>False
>>None
>>None
由上述结果我们可以发现, 在 z.grad 和 z.grad_fn 的结构都为 None, 因此也解释了为什么在代码中我们经常看到以下的条件句:
if parameter.grad is not None:
...