[联邦学习]FedProx算法&&工作原理&&步骤

[联邦学习]FedProx

FedProx(Federalized Proximal Algorithm)是一种在联邦学习(Federated Learning, FL)环境下设计的优化算法,旨在处理数据在不同客户端之间可能存在的不均匀分布(Non-IID Data)的问题。联邦学习是一种机器学习设置,允许多个客户端协作训练一个共享的模型,同时保持数据的隐私和安全,因为数据不需要集中存储或处理。

FedProx是Li Tian等人于2018年(论文链接)所提出的一种针对系统异构性鲁棒的联邦优化算法,发表于MLSys 2020上。它相较于FedAvg主要做出了两点改进:

采样阶段 使用了按数据集大小比例,可放回采样,并直接平均聚合(无加权)来获得无偏梯度估计
本地训练阶段 基于近端项优化的思路,魔改了本地训练的目标函数为

L + μ 2 ∣ ∣ w k , i t − w g l o b a l ∣ ∣ 2 L + \frac{\mu}{2}||w^t_{k,i} - w_{global}||^2 L+2μ∣∣wk,itwglobal2

"采样"指的是服务器从参与方(客户端)的数据集中选择样本进行模型更新。因此,在FedProx中,采样是服务器在每轮迭代中从参与方的数据集中按照每个参与方数据集大小的比例进行选择的过程。具体来说,如果某个参与方的数据集更大,则它在采样中被选中的概率更高。

因此,这里的"采样"是指服务器在联邦学习中选择参与方的过程,而不是指参与方选择自己的数据的过程。

背景和问题

在标准的联邦学习模型中,如FedAvg(Federated Averaging),每个客户端独立地在本地数据上训练模型,然后将更新的模型发送给中央服务器。服务器将这些更新平均合并,以更新全局模型。然而,当不同客户端的数据分布差异很大时(即Non-IID),这种简单的平均可能导致模型性能下降,因为它没有考虑到各客户端更新的差异性。

FedProx的工作原理

FedProx在FedAvg的基础上增加了一个正则化项,这个正则化项惩罚模型参数与全局模型参数之间的偏差。具体来说,FedProx的目标是最小化以下目标函数:

L ( w ) = ∑ k = 1 K n k n ( F k ( w ) + μ 2 ∣ w − w t ∣ 2 ) L(w) = \sum_{k=1}^K \frac{n_k}{n} \left( F_k(w) + \frac{\mu}{2} |w - w^t|^2 \right) L(w)=k=1Knnk(Fk(w)+2μwwt2)

其中:

  • ( w ) 是模型参数。 ( w ) 是模型参数。 (w)是模型参数。
  • ( K ) 是客户端的数量。 ( K ) 是客户端的数量。 (K)是客户端的数量。
  • ( n k ) 是第 ( k ) 个客户端的数据点数。 ( n_k ) 是第 ( k ) 个客户端的数据点数。 (nk)是第(k)个客户端的数据点数。
  • ( n ) 是所有客户端的数据点总数。 ( n ) 是所有客户端的数据点总数。 (n)是所有客户端的数据点总数。
  • ( F k ( w ) ) 是第 ( k ) 个客户端上的损失函数。 ( F_k(w) ) 是第 ( k ) 个客户端上的损失函数。 (Fk(w))是第(k)个客户端上的损失函数。
  • ( μ ) 是正则化参数,控制本地更新与全局模型之间一致性的强度。 ( \mu ) 是正则化参数,控制本地更新与全局模型之间一致性的强度。 (μ)是正则化参数,控制本地更新与全局模型之间一致性的强度。
  • ( w t ) 是当前全局模型的参数。 ( w^t ) 是当前全局模型的参数。 (wt)是当前全局模型的参数。

目标函数的第一部分 ( ∑ k = 1 K p k F k ( w ) ) (\sum_{k=1}^K p_k F_k(w)) (k=1KpkFk(w))代表了在所有客户端上的加权平均损失,这保证了模型优化不只是在某个特定客户端上表现良好,而是在所有参与的客户端上都尽可能有效。

第二部分 ( μ 2 ∣ w − w t ∣ 2 ) (\frac{\mu}{2} |w - w^t|^2) (2μwwt2) 则确保了在每一轮训练中,本地更新的模型参数不会偏离过大于初始的全局模型参数 ( w t ) ( w^t ) (wt),这在通信受限或客户端数据不均匀的情况下尤其重要。

通过这种方式,FedProx算法旨在改善标准联邦学习方法(如FedAvg)在处理非独立同分布(non-IID)数据或客户端之间计算能力、数据量不一致时的效率和效果。

步骤

FedProx的核心是引入了一个正则化项,这个正则化项惩罚模型参数偏离初始全局模型参数的程度,以此来控制模型的本地更新。具体步骤如下:

1.初始化模型:

  • 中央服务器初始化全局模型参数 ( w 0 ) 。 中央服务器初始化全局模型参数 ( w^0 )。 中央服务器初始化全局模型参数(w0)

2.广播模型:

  • 选择一部分客户端,并向这些客户端发送当前的全局模型参数 ( w t ) 。 选择一部分客户端,并向这些客户端发送当前的全局模型参数 ( w^t )。 选择一部分客户端,并向这些客户端发送当前的全局模型参数(wt)

3.本地更新:

  • 每个客户端使用其本地数据集对模型进行训练。与 F e d A v g 不同,训练过程中加入了正则化项 ( μ 2 ∣ w − w t ∣ 2 ) ,其中 ( w ) 是本地更新后的模型参数, ( w t ) 是接收到的全局模型参数,而 ( μ ) 是控制本地更新与全局模型偏离程度的超参数。 每个客户端使用其本地数据集对模型进行训练。与FedAvg不同,训练过程中加入了正则化项 (\frac{\mu}{2} |w - w^t|^2),其中 ( w ) 是本地更新后的模型参数,( w^t ) 是接收到的全局模型参数,而 ( \mu ) 是控制本地更新与全局模型偏离程度的超参数。 每个客户端使用其本地数据集对模型进行训练。与FedAvg不同,训练过程中加入了正则化项(2μwwt2),其中(w)是本地更新后的模型参数,(wt)是接收到的全局模型参数,而(μ)是控制本地更新与全局模型偏离程度的超参数。

  • 这个正则化项有助于控制在数据分布非常不均匀或客户端计算能力不一的情况下,模型更新的稳定性和一致性。 这个正则化项有助于控制在数据分布非常不均匀或客户端计算能力不一的情况下,模型更新的稳定性和一致性。 这个正则化项有助于控制在数据分布非常不均匀或客户端计算能力不一的情况下,模型更新的稳定性和一致性。

4.上传更新:

  • 客户端将本地更新后的模型参数或与全局模型参数的差异发送回中央服务器。 客户端将本地更新后的模型参数或与全局模型参数的差异发送回中央服务器。 客户端将本地更新后的模型参数或与全局模型参数的差异发送回中央服务器。

5.聚合更新(可能有加权):

  • 服务器根据各客户端的贡献(可能是数据量或其他指标)加权聚合这些更新,更新全局模型。 服务器根据各客户端的贡献(可能是数据量或其他指标)加权聚合这些更新,更新全局模型。 服务器根据各客户端的贡献(可能是数据量或其他指标)加权聚合这些更新,更新全局模型。

6.重复:

  • 重复步骤 2 到 5 ,直到模型收敛或完成足够的训练轮次。 重复步骤2到5,直到模型收敛或完成足够的训练轮次。 重复步骤25,直到模型收敛或完成足够的训练轮次。

通过这种方式,FedProx不仅处理了数据分布问题,还通过正则化项减少了模型在训练过程中的波动,使得模型更加健壮,特别是在客户端环境多样化的情况下。这使得FedProx在实际应用中比传统的联邦学习方法更有优势。

主要优势

FedProx的主要优势在于它能够更好地处理客户端间的数据异质性。通过引入正则化项,它鼓励客户端向全局模型更加“温和”地更新,从而减少了因数据非独立同分布(Non-IID)带来的模型波动和性能下降问题。

应用场景

FedProx适用于数据分布极不均匀的联邦学习场景,例如医疗健康数据分析、移动设备上的个性化推荐系统等领域,这些领域中的数据隐私性要求高,同时数据分布在不同设备或组织中往往是非均匀的。