Personalized Federated Learning using Hypernetworks 论文阅读笔记+代码解读

论文地址点这里

一. 介绍

联邦学习是在多个不相交的本地数据集上学习模型的任务,由于隐私、存储等问题而无法共享本地数据集。但当数据分布在不同的客户端时,学习单个全局模型可能会失败。为了处理这种跨客户端的异构性,个性化联邦学习使每一个客户端能自行调整。
pFedHN使用超网络(Hypernetwork)来为每个输入产生身影网络参数。每个客户端都有一个唯一的嵌入向量,该向量作为输入传递给超网络,因此绝大多书参数在客户机之间共享。
使用超网络的另一个好处是,超级网络的训练参数向量永远不会被传输,每一个客户端只需要接受自己的网络参数进行预测和梯度计算,超网络只需要接受梯度即可优化自身参数。

二. 相关工作

联邦学习: 针对数据隐私问题、通信问题等方面,各个客户端无法进行数据传输,只能进行参数等传输。联邦学习以FedAvg最为出名,但是这样所有客户端将会学到一个全局的模型,而无法满足个性化。
联邦个性化学习: 联邦学习设置提出了许多挑战,包括数据异构,设备异构。特别是数据异构使得学习一个适用于所有客户端的共享全局模型变得非常困难。这有很多方法,例如基于MAML的元学习个性化学习,个性化层实现。
超网络: 超网络由Klein等人提出的深度神经网络,其输出为学习任务的另一目标网络的权值。其思想是输出的权值随超网络的输入而变化。

三. 方法

3.1 联邦个性化问题表示

对于一个个性化联邦学习来说,每一个客户端有自己的参数

θ

i

theta_i

θi,数据集分布

P

i

mathcal{P}_i

Pi对应m个数据例子

S

i

=

{

(

x

j

(

i

)

,

y

j

(

i

)

)

}

i

=

1

m

i

mathcal{S_i}={(x_j^{(i)},y_j^{(i)})}_{i=1}^{m_i}

Si={(xj(i),yj(i))}i=1mi,因此我们使用

L

i

(

θ

i

)

=

1

m

i

j

l

i

(

x

i

,

y

i

;

θ

i

)

mathcal{L}_i(theta_i) = frac{1}{m_i}sum_jl_i(x_i,y_i;theta_i)

Li(θi)=mi1jli(xi,yi;θi),来表示一个客户端的损失,那么我们的联邦个性化学习的优化目标为:

Θ

=

a

r

g

min

Θ

1

n

i

=

1

n

E

x

,

y

P

i

[

l

i

(

x

j

,

y

j

;

θ

i

)

]

Theta^*=argmin_{Theta}frac{1}{n}sum_{i=1}^n mathbb{E}_{x,y~mathcal{P_i}}[l_i(x_j,y_j;theta_i)]

Θ=argΘminn1i=1nEx,yPi[li(xj,yj;θi)]
对于训练来说优化目标为:

a

r

g

min

θ

1

n

i

=

1

n

L

i

(

θ

i

)

=

a

r

g

min

θ

1

n

i

=

1

n

1

m

i

j

=

1

m

i

[

l

i

(

x

j

,

y

j

;

θ

i

)

]

arg min_{theta}frac{1}{n}sum_{i=1}^nmathcal{L}_i(theta_i)=arg min_{theta}frac{1}{n}sum_{i=1}^nfrac{1}{m_i}sum_{j=1}^{m_i}[l_i(x_j,y_j;theta_i)]

argθminn1i=1nLi(θi)=argθminn1i=1nmi1j=1mi[li(xj,yj;θi)]

3.2 联邦超网络

超网络根据他的输入输出另一个网络的权值。我们以

h

(

.

;

φ

)

h(.;varphi)

h(.;φ)表示我们的超网络。以

f

(

.

;

θ

)

f(.;theta)

f(.;θ)表示我们的目标网络(也就是分类的网络)。超网络在服务器部署,每一个客户端通过传给服务器嵌入向量v来获取到对应的参数,训练完之后将自己网络的参数

θ

theta

θ的 梯度传回给服务端,具体图片如下:
在这里插入图片描述
因此我么变化我们对应的优化目标为:

a

r

g

min

φ

,

v

1

,

.

.

.

.

v

n

1

n

i

=

1

n

L

i

(

h

(

v

i

;

φ

)

)

arg min_{varphi,v_1,....v_n}frac{1}{n}sum_{i=1}^nmathcal{L}_i(h(v_i;varphi))

argφ,v1,....vnminn1i=1nLi(h(vi;φ))
通过使用链式法则可以计算出

φ

L

i

=

(

φ

θ

i

)

T

θ

i

L

i

nabla_{varphi}mathcal{L}_i = (nabla_{varphi}theta_i)^Tnabla_{theta_i}mathcal{L_i}

φLi=(φθi)TθiLi,因此我们只需要在客户端计算

θ

theta

θ的梯度传回给服务端即可。
作者这里用一种更通用的规则来进行更新,通过使用

Δ

θ

i

=

θ

i

~

θ

Delta theta_i = widetilde{theta_i}-theta

Δθi=θi

θ替换掉

θ

theta

θ的梯度(这里

θ

i

~

widetilde{theta_i}

θi

对应为在客户端经过训练k个epoch后更新的结果)。
训练过程如下图所示:
在这里插入图片描述
最后作者还说了,利用超网络的时候应该只对每个客户端目标网络使用特征层的权值输出,避免客户端任务之间的无关性等问题。

四. 关键代码解读

作者的github代码点这里,我们针对上面的流程进行讲解。
首先是构造hypernetwork和target network
对于hypernetwork,我们接受每一个客户端传入的向量v,生成target network对应的权重,如下:

class CNNHyperPC(nn.Module):
    def __init__(
            self, n_nodes, embedding_dim, in_channels=3, out_dim=10, n_kernels=16, hidden_dim=100,
            spec_norm=False, n_hidden=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        layers = [
            spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
        ]
        for _ in range(n_hidden):
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
            )

        self.mlp = nn.Sequential(*layers)

        self.c1_weights = nn.Linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
        self.c1_bias = nn.Linear(hidden_dim, self.n_kernels)
        self.c2_weights = nn.Linear(hidden_dim, 2 * self.n_kernels * self.n_kernels * 5 * 5)
        self.c2_bias = nn.Linear(hidden_dim, 2 * self.n_kernels)
        self.l1_weights = nn.Linear(hidden_dim, 120 * 2 * self.n_kernels * 5 * 5)
        self.l1_bias = nn.Linear(hidden_dim, 120)
        self.l2_weights = nn.Linear(hidden_dim, 84 * 120)
        self.l2_bias = nn.Linear(hidden_dim, 84)

        if spec_norm:
            self.c1_weights = spectral_norm(self.c1_weights)
            self.c1_bias = spectral_norm(self.c1_bias)
            self.c2_weights = spectral_norm(self.c2_weights)
            self.c2_bias = spectral_norm(self.c2_bias)
            self.l1_weights = spectral_norm(self.l1_weights)
            self.l1_bias = spectral_norm(self.l1_bias)
            self.l2_weights = spectral_norm(self.l2_weights)
            self.l2_bias = spectral_norm(self.l2_bias)

    def forward(self, idx):
        emd = self.embeddings(idx)
        features = self.mlp(emd)

        weights = {
            "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
            "conv1.bias": self.c1_bias(features).view(-1),
            "conv2.weight": self.c2_weights(features).view(2 * self.n_kernels, self.n_kernels, 5, 5),
            "conv2.bias": self.c2_bias(features).view(-1),
            "fc1.weight": self.l1_weights(features).view(120, 2 * self.n_kernels * 5 * 5),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(84, 120),
            "fc2.bias": self.l2_bias(features).view(-1),
        }
        return weights

首先经过nn.Embedding层,这一层主要是进行向量编码工作(具体大家可以取搜一搜,也就是给每一个词向量对应编码一次),之后经过基层Linear隐藏层(这里作者有写到spectral_norm对应为频谱归一化,但训练时并未使用,应该是更好的拟合),经过这些后对我们target的每一层的参数对应的weight和bias进行输出。
接下来对应的是目标模型的特征提取部分,和上面的输出的参数一一对应

class CNNTargetPC(nn.Module):
    def __init__(self, in_channels=3, n_kernels=16):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, 5)
        self.fc1 = nn.Linear(2 * n_kernels * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

最后是一个输出层,这一层是每一个模型自带的,不参与超网络

class LocalLayer(nn.Module):

    def __init__(self, n_input=84, n_output=2, nonlinearity=False):
        super().__init__()
        self.nonlinearity = nonlinearity
        layers = []
        if nonlinearity:
            layers.append(nn.ReLU())

        layers.append(nn.Linear(n_input, n_output))
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

网络定义好后就可以开始训练
首先是每一个客户端产生自己的标识符传到hypernetwork获得参数。其实很简单,每个客户端产生的标号即为标识符(例如1号客户端被选中训练,则传入tensor[1]到hypernetwork)

node_id = random.choice(range(num_nodes))
# produce & load local network weights
weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
net.load_state_dict(weights)

之后客户端自己训练更新自己的参数
(clip_grad_norm_为修剪梯度的,防止梯度爆炸)

for i in range(inner_steps):
    net.train()
    inner_optim.zero_grad()
    optimizer.zero_grad()
    nodes.local_optimizers[node_id].zero_grad()

    batch = next(iter(nodes.train_loaders[node_id]))
    img, label = tuple(t.to(device) for t in batch)

    net_out = net(img)
    pred = nodes.local_layers[node_id](net_out)

    loss = criteria(pred, label)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
    inner_optim.step()
    nodes.local_optimizers[node_id].step()

这个时候我们来求超网络的参数梯度进行更新

Δ

θ

i

=

θ

i

~

θ

Delta theta_i = widetilde{theta_i}-theta

Δθi=θi

θ来作为目标网络的更新
这里用到torch,autograd.grad,这个函数是计算梯度的,由于我们的输出为一个向量而不是标量,因此grad_outputs不能为None,而这里对应的就是

Δ

θ

i

=

θ

i

~

θ

Delta theta_i = widetilde{theta_i}-theta

Δθi=θi

θ

final_state = net.state_dict()
delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})

# calculating phi gradient
hnet_grads = torch.autograd.grad(
    list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
)

# update hnet weights
for p, g in zip(hnet.parameters(), hnet_grads):
    p.grad = g

torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
optimizer.step()

到这里训练过程就介绍完了,整体上也不算复杂,当然对应上面的算法发现v咋没变化,其实是变化了的,这里的v实际对应每个标识符输入到hypernetwork的第一个Embedding得出来的值,而Embedding作为参数自然是更新了。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码

)">
< <上一篇
下一篇>>