赞
踩
联邦学习原始论文中给出的FedAvg的算法框架为:
参数介绍:
K
K
K表示客户端的个数,
B
B
B表示每一次本地更新时的数据量,
E
E
E表示本地更新的次数,
η
\eta
η表示学习率。
首先是服务器执行以下步骤:
对每一个本地客户端来说,要做的就是更新本地参数,具体来讲:
我们仔细观察server的最后一步:
w
t
+
1
=
∑
k
=
1
K
n
k
n
w
t
+
1
k
w_{t+1}=\sum_{k=1}^{K}\frac{n_k}{n}w_{t+1}^k
wt+1=k=1∑Knnkwt+1k
也就是说,虽然我们只是对
m
m
m个客户端进行本地训练更新得到了其对应的
w
t
+
1
k
w_{t+1}^k
wt+1k,但最终我们却对所有
K
K
K个客户端进行了聚合。
那么针对聚合,就有以下两种情况。
服务器端每次将新的全局模型发送给全部客户端,并且聚合全部客户端的模型参数。如果客户端未被选中,那么一轮通信结束后,该客户端的模型为一轮通信开始时从服务器获得的初始模型。
设当前全局模型为
w
t
w_t
wt,服务器选中了
m
m
m个客户端(集合
V
V
V),
m
m
m个客户端本地更新完毕后,服务器端的聚合公式为:
w
t
+
1
=
∑
k
∈
V
n
k
n
w
t
+
1
k
+
∑
k
∉
V
n
k
n
w
t
w_{t+1}=\sum_{k \in V}\frac{n_k}{n}w_{t+1}^k+\sum_{k\notin V}\frac{n_k}{n}w_t
wt+1=k∈V∑nnkwt+1k+k∈/V∑nnkwt
也就是说,每一次聚合时服务器端都将所有客户端的模型考虑在内。
服务器每次只是将当前新的参数传递给被选中的模型,并且只是聚合被选中客户端的模型参数。
设当前全局模型为
w
t
w_t
wt,服务器选中了
m
m
m个客户端(集合
V
V
V),然后将
w
t
w_t
wt只发送给这
m
m
m个客户端。
m
m
m个客户端训练完毕后,服务器端的聚合公式为:
w
t
+
1
=
∑
k
∈
V
n
k
n
w
t
+
1
k
w_{t+1}=\sum_{k \in V}\frac{n_k}{n}w_{t+1}^k
wt+1=k∈V∑nnkwt+1k
虽然原始论文中对所有 K K K个客户端都进行了聚合,但在真正实现时,感觉用第二种会更好一点,因为如果客户端数量很庞大,每一次通信都会有不小的代价,用第二种会明显降低通信成本。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。