Press "Enter" to skip to content

联邦学习理解及实现

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

FedAvg以及FedSDG的实现

 

谈到联邦学习,最重要的一个点是梯度融合时的策略。如何巧妙地融合各个客户端之间的梯度,以保证最优的模型效果,是联邦学习中需要着重研究的问题。FedAvg是一个经典且简单的策略。

 

一句话描述:所有client的梯度取平均值得到最终梯度。具体证明如下。

 

机器学习问题建模

 

首先,对于一切损失函数非凸的机器学习问题(例如神经网络),都可以表示为以下式子:

是目标函数,
是参数。

是第i个数据的损失,或者说第i个数据的代价。上面式子的意思是:一切非凸机器学习问题,都是一个最小化目标函数的问题,这个目标函数是由每个数据的损失平均贡献的。

 

进一步细化,

,这表示
在参数
贡献了
的损失。如果我们有
个用户,每个用户贡献的数据为
,其集合内元素数量

,那幺我们可以改写上面的式子为:

 

如果数据呈IID,即独立同分布,那幺根据期望的性质可以得到:

 

当然,这个独立同分布不一定能实现,所以需要后面进一步讨论。

 

SGD算法

 

SGD算法全称是:stochastic gradient descent,随机梯度下降法,是一种常见的梯度下降算法。

 

和SGD相对的算法是BGD:batch gradient descent,批量梯度下降法BGD。BGD在每次更新时,要用到所有的样本来得到一个标准梯度,然后沿着这个梯度更新。因此对于凸优化问题BGD肯定可以得到一个收敛的解。而SGD则是每次取一个样本,来代替整个样本集合进行梯度下降,这样虽然不是每次迭代得到的损失函数都向着全局最优方向, 但是大的整体的方向是向全局最优解发展的。还有一种mini-batch梯度下降,是这两个方法的折中。

 

根据一些证明,SGD和BGD都能收敛,所以都是可用的。因为联邦学习是分布式的,所以肯定只能用SGD。

 

基线算法FederatedSGD

 

定义一个值C:每次参与联邦学习聚合的client数量占总client数量的比例。当C=1时,代表全员参与聚合。FedSGD就是在C=1时的一个基线算法,也就是每次让所有client参与,把本地所有的数据进行训练,在本地只训练一次,然后进行聚合(说实话我很迷惑,这不应该叫FedBGD吗)。

 

聚合时的操作是这样的:

 

首先要有一个固定学习律

,然后每个用户
计算自己的损失变化:
,其中
代表客户端此时的模型参数。服务端收收集损失变化,利用

的平均值对整个模型进行更新:

 

因为

 

的变化可以这幺表示:

 

所以,这个式子也可以写为:

 

然后把所有全加起来,再把

 

那一项换掉,可以得到:

 

上面这个式子就很清晰了,新的模型=每个设备的权重*每个设备的模型。在聚合之前,每个设备自己还可以自娱自乐,自己迭代多轮:

 

至此,FedSGD的操作就介绍完毕了,其实就是个求均值。

 

Federated Averaging

 

上面介绍了FedSGD,然而FedSGD其实是FedAvg的特殊情况。

 

我们定义三个参数:

:每轮参与联邦学习聚合的client数量占总client数量的比例。
:每个client在本地的训练次数(即自娱自乐的次数)。

:每个client在本地训练时的BatchSize。

 

然后FedAvg可以表示为如下伪代码:

 

Server executes:
initialize

for each round

do

 

for each client

 

in parallel do

 

Client

// Run on client
split
into batches of size

for each local epoch

from 1 to
do for batch

do

 

return

 

to server

 

可以看出,当C=B=1,且B为无穷大时,FedAvg与FedSGD一样。

 

至此,FedAvg的原理也介绍完毕了。

 

HierFAVG策略

 

具体请见这篇论文:
https://arxiv.org/abs/1905.06641

 

目前来说,FedAvg是基于云的,client连到云上,然后进行FL。这里有一个巨大缺陷:设备连接时产生了巨大的网络资源消耗,与云服务器的连接也不一定稳定,一旦断掉很麻烦。所以,可以引入边缘计算来解决这个问题。但是边缘节点毕竟接入量有限,可能导致训练性能的大量损失。

 

因此,上面那篇论文提出了一种边云协同策略,具体来说就是:

 

 

    1. 首先,client先训练,然后把参数上传到边缘

 

    1. 边缘进行聚合,当边缘聚合k轮后,把聚合好的数据上传到云。

 

    1. 云总共聚合n轮。

 

 

所以总共训练n*k论,根据那个文章,训练性能还不错。

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注