Pytorch 中的 with torch.no_grad(): 详解

1 python 中的 with 用法

with 语句适用于对资源进行访问的场合, 它会自动进行类似于 '清理内存' 的操作, 进行资源的释放, 类似于 file.close(), 以防止在程序运行中, 由于忘记关闭文件而导致程序占用的内存越来越大, 程序的运行时间越来越长

with 用法:
with ...:
    ...

2 torch.no_grad()

2.1 背景知识

  • 在pytorch中,tensor有一个requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。tensor的requires_grad的属性默认为False,若一个 tensor自己创建的tensor requires_grad被设置为True,那么所有依赖它的节点requires_grad都为True(即使其他相依赖的tensor的requires_grad = False),当requires_grad设置为False时,反向传播时就不会自动求导了,因此大大节约了显存或者说内存

  • 除此之外, 当进行神经网络的参数更新的过程中, 我们不希望将有些的参数进行梯度下降进行更新, 也可以使用 torch.no_grad()

  • 使用 with torch.no_grad(): 条件下计算的新tensor, 即使以前的 tensor requires_grad() = True, 计算的结果也是 requires_grad() = False

import torch

x = torch.randn(2, 2, requires_grad = True)
y = torch.randn(2, 2, requires_grad = True)
with torch.no_grad():
    z = x + y

print(z.requires_grad)
print(z.grad)
print(z.grad_fn)

>>False
>>None
>>None

由上述结果我们可以发现, 在 z.gradz.grad_fn 的结构都为 None, 因此也解释了为什么在代码中我们经常看到以下的条件句:

if parameter.grad is not None:
    ...
本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>