【pytorch】冻结网络踩坑

普通conv和fc层的冻结方式:

# 冻结参数
for i, p in enumerate(self.model.parameters()):
    if i <= 66:
        p.requires_grad = False


# 验证一下是否成功冻结参数
for k, v in self.model.named_parameters():
    print("k:{} v:{} ".format(k, v.requires_grad))

注意:model.parameters()都在梯度回传的更新过程中,所以可以用param.requires_grad = False的方式冻结,但是对于一些BN层的参数,比如BN层的runing_mean和runing_var,这两个值是前向计算统计得来的,并没有在梯度回传的更新过程中。所以,param.requires_grad=False对它们不起任何作用!

踩坑:

我的目的:在共用一个主干网络的多任务学习中,完全冻结其中一个表现较好的任务1分支,只训练其他两个任务:任务2分支和任务3分支。

结果:我以为用 “param.requires_grad=False” 的方式可以冻结任务1分支的所有参数,然后我发现我错了,冻结完,在验证过程中,我发现任务1的表现居然变差了。

验证:打印参数值,发现任务1的卷积层和全连接层参数不变(被成功冻结),只有BN层的runing_mean和runing_var发生了改变(未被冻结),应该就是他们的问题。

冻结BN层的runing_mean和runing_var的方法可以参照:

def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

model = models.resnet50(pretrained=True)
model.cuda()
model.train()
model.apply(fix_bn)  # fix batchnorm

总结:当需要保持某一分支的分类性能不变时(我的意思是后续都不对这个分支进行训练了,只拿来验证和测试),除了要冻结可回传梯度的权重值,还要冻结上述BN层的值。当然如果网络还要继续训练的话,也可以不冻结BN层的runing_mean和runing_var。如果冻结网络后某一分支的性能突然变差,可以考虑一下试试冻结BN层的runing_mean和runing_var~

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>