MAML论文阅读笔记–回归实验

1.背景

   MAML是元学习领域的一篇经典文章。元学习(Meta-learning)与机器学习算法不同,不是先人为调参,然后在特定训练任务下训练模型,而是希望模型获取一种学会学习调参的能力,使其在新任务的小样本集上快速学习新任务。所以,深度学习模型有哪些需要人为确定的元素(初始化参数、网络结构、优化器等),不同的元学习就可以不同的元素,相应在元学习领域就有不同的研究领域。
   目前元学习可以学习预处理数据集 ,学习初始化参数,学习网络结构,学习选择优化器。MAML是学习初始化模型参数的一篇经典论文,其中包含三个实验,从监督回归、监督分类和强化学习的角度验证MAML在不同领域初始化参数的能力。

2.算法

   少样本学习在监督任务领域得到了很好的研究,其目标是从该任务的少数输入/输出对中学习一个新函数,使用来自类似任务的输入数据进行元学习。同样地,在少样本回归中,目标是在对许多具有相似统计特性的函数进行训练后,仅从该函数采样的少数数据点预测连续值函数的输出。
   用于监督分类和回归的两个常见损失函数是交叉熵和均方误差(MSE),公式(2)为均方误差损失函数,公式(3)为交叉熵损失函数。
均方误差
交叉熵
   MAML监督回归和分类算法详情见算法2:

在这里插入图片描述

3.回归实验

3.1 问题分析

  目标是利用少量样本回归一个正弦函数。每个任务都涉及到从输入回归到一个正弦波的输出,其中正弦波的振幅和相位在不同的任务之间是不同的。振幅在[0.1,5.0]范围内变化,相位在[0,π]范围内变化,输入和输出的维数都为1。在训练和测试过程中,数据点x从[−5.0,5.0]中均匀采样。

3.2 参数设置

  损失函数是预测值f(x)和真实值之间的均方误差。回归模型是一个神经网络模型,有2个神经元为40的隐藏层,使用ReLU激活函数。当使用MAML进行训练时,我们使用K=10 (在K-shot回归任务中,为每个任务提供K个输入/输出对进行学习) 示例,固定步长α=0.01,并使用Adam作为元优化器。baseline(最普遍的情况)也同样使用Adam训练。为了评估性能,我们通过变化不同数量的K值调整一个元学习模型,并比较性能两个baseline:(a)预训练的任务,这需要训练网络回归随机正弦函数,然后在测试时,对提供的K点使用自动调整步长通过梯度下降进行微调;(b)甲骨文接收真正的振幅和相位作为输入。在附录C中,我们展示了与其他多任务和自适应方法的比较。(见参考文献1)

3.3 实验结果

  我们在K为{5,10,20}数据点上评估了通过MAML微调模型和预训练模型的性能。在微调过程中,使用相同的K个数据点计算每个梯度步骤。定性结果如图2所示,
在这里插入图片描述
进一步拓展在附录B显示,只有5数据点时学习模型能够快速适应,图中显示为紫色三角形,而在所有任务上使用标准监督学习预训练的模型不能在保证没有过拟合的情况下使用如此少的数据点充分适应。至关重要的是,当K个数据点都在输入范围的一半时,用MAML训练的模型仍然可以推断出另一半范围内的振幅和相位,说明用MAML训练的模型f已经学会了模拟正弦波的周期性。此外,我们在定性和定量结果(图3和附录B)中观察到,
在这里插入图片描述

尽管训练在一个梯度步后获得了很好的性能,用MAML学习到的模型随着额外的梯度步骤继续改进。这一改进表明,MAML优化了参数,使它们位于易于快速适应的区域,并且对p(T)的损失函数敏感,如2.2节中讨论的,而不是对仅在一步后改善的参数θ进行过拟合。

附录B
在图6中,我们展示了经过10次学习训练的MAML模型的完整定量结果,并在5、10和20次进行评估。在图7中,我们展示了MAML和预训练的基线在随机采样的正弦曲线上的定性性能。
在这里插入图片描述
在这里插入图片描述

参考

[1] MAML论文

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇

)">
下一篇>>