赞
踩
联邦学习以轮为单位,每轮包含设备选择、参数分发、本地更新和全局更新这4个步骤
服务器端选择该轮参与训练的设备,在设备选择阶段,有的设备可能处于离线状态,需要选择在线的设备,并且设备的电量、网络状况等符合一定的要求,被选中的设备参与该轮训练。
服务器将当前的模型参数分发给选中的设备。
本地设备下载到服务器发来的新的模型参数后,在此基础上,用本地的数据训练更新模型。
设备把本地更新了的模型发到服务器端,服务器按照一定规则进行聚合,对全局模型进行更新
&emspl&emspl;在满足一定的收敛规则时,停止迭代
联邦学习的算法,源自分布式机器学习。2015年CCS中提出了分布选择随机梯度下降(DSSGD)算法,它是一个异步的协议,分为下载、训练和上传三个阶段。下载阶段,客户端可以选择一部分参数来进行更新本地模型;训练阶段,客户端在本地进行训练;上传阶段,客户端可以选择本地模型的一部分参数上传给服务器。每个客户端完成一次训练,马上将最新的参数选择一部分进行上传,服务器立即更新全局模型,然后进行广播。
之前的DSSGD存在的问题是通信量巨大,而且是异步的,不能用于很多用户的场景。2017年谷歌的Mcmahan等提出了FedAvg算法,是一个同步的协议,全局更新的每一轮可以有上百个客户端,进行加权平均,是目前主流的联邦学习算法。
在上面的算法中,所有的梯度都是以明文的形式给出的,然而,从梯度会泄露用户的个人信息,在最新的NeurIPS 2019中,《Deep Leakage from Gradients》一文指出,从梯度可以推断出原始的训练数据,包括图像和文本数据。谷歌的Bonawitz等人,提出了安全聚合SMPC加密方案,服务器只能看到聚合完成之后的梯度,不能知道每个用户的私有的真实梯度值。
在说安全聚合SMPC前,先说说一个常用的密钥交换协议——DH密钥交换。DH密钥交换的目的,是让想要通信的Alice、Bob双方,他们之间能够拥有一个私密的密钥,这个密钥只有A和B两个人知道。DH密钥交换包含如下步骤:
1. 首先,Alice和Bob商量好DH的参数,一个大数素数
P
\mathcal{P}
P,和
z
p
\mathbb{z}_p
zp上的一个生成元
G
G
G(
1
<
G
<
P
1<G<\mathcal{P}
1<G<P,一种比较特殊的素数)
2. Alice和Bob都各自产生一个随机数,A和B是Alice和Bob的私钥
3. Alice和Bob分别计算
G
A
=
G
A
(
m
o
d
P
)
G^A=G^A(mod \mathcal{\;P})
GA=GA(modP)和
G
B
=
G
B
(
m
o
d
P
)
G^B=\mathcal{G}^B(mod \mathcal{\;P})
GB=GB(modP),
G
A
G^A
GA和
G
B
G^B
GB是Alice和Bob的公钥。(由公钥推导出私钥是困难的)
4. Alice和Bob分别将公钥发送给对方
5. Alice收到Bob发来的他的公钥
G
B
G^B
GB,计算出用来和Bob秘密通信的密钥
s
A
B
=
(
G
B
)
A
(
m
o
d
P
)
s_{AB}=(G^B)^A(mod \mathcal{\;P)}
sAB=(GB)A(modP);同理Bob收到alice发来的他的公钥
G
A
G^A
GA,计算出用来和Bob秘密通信的密钥
s
B
A
=
(
G
A
)
B
(
m
o
d
P
)
s_{BA}=(G^A)^B(mod \mathcal{\;P)}
sBA=(GA)B(modP)。显然
s
A
B
=
s
B
A
s_{AB}=s_{BA}
sAB=sBA是相等的,他们在公开环境中,可以通过密钥建立私有通信通道,使用该密钥来加密消息。
联邦学习中的安全聚合是基于安全多方计算的,安全多方计算是基于秘密分享的,秘密分享由1978年被Shamir提出(RSA中的S)。
L e m m a r 1 \mathcal{Lemmar \;1} Lemmar1:一个二维平面上, 给出任意 k k k个点 ( x 1 , y 1 ) , . . . , ( x k , y k ) (x_1,y_1), ... ,(x_k,y_k) (x1,y1),...,(xk,yk)的坐标,有且仅有一个 k − 1 k-1 k−1次的多项式 q ( x ) q(x) q(x),对于所有给定的 x i x_i xi,使得 q ( x i ) = y i q(x_i)=y_i q(xi)=yi。
假设秘密 s = f ( 0 ) s=f(0) s=f(0), s s s被分享给 n = 3 n=3 n=3个用户,阈值 t = 2 t=2 t=2。
在原始的FedAVG中,用户
u
u
u发送更新值
y
u
y_u
yu给服务器,
y
u
y_u
yu是其真实的模型更新值,服务器进行聚合,再按照聚合规则取平均:
用户u和v之间通过DH建立秘密通信通道,他们之间知道一个秘密随机数 s u v s_{uv} suv。用户1发送给服务器的更新值 y 1 y_1 y1是真实值 x 1 + s 12 + s 13 x_1+s_{12}+s_{13} x1+s12+s13,这样服务器收到 y 1 y_1 y1时,并不知道 x 1 x_1 x1是多少。服务器对收到的所有值进行聚合以后,它们正负才会抵消,相当于真实值的聚合,等同于FedAVG。
上面的方案是存在问题的,假设用户2在上传
y
2
y_2
y2时掉线了,没有把
y
2
y_2
y2发送给服务器,那么这一轮全局更新中,服务端的聚合值
∑
y
i
\sum{y_i}
∑yi是没有意义的。
因为上面的方案存在用户掉线后,聚合值失效的问题,所以考虑带恢复的方案。当用户2掉线时,在恢复阶段,它的值
s
12
s_{12}
s12和
s
23
s_{23}
s23用户1和用户3是知道的,服务器询问用户1和用户3,用户1和用户3进行报告。在恢复阶段结束后,服务器完成聚合。
完整的SMPC的方案,可以在谷歌的论文中读到,可以参阅《Practical Secure Aggregation for Privacy-Preserving Machine Learning》。
图大部分引用自论文:
[1] Jack Sullivan. “Secure Analytics: Federated Learning and Secure Aggregation”. Jan 2020. URL: https://inst.eecs.berkeley.edu/~cs261/fa18/scribe/10_15_revised.pdf
谷歌的方案:
[2] Bonawitz K, Ivanov V, Kreuter B, et al. Practical Secure Aggregation for Privacy-Preserving Machine Learning[C]. CCS, 2017: 1175-1191.
谷歌FL论文作者演讲PPT:
[3] Jakub Konečný. “Federated Learning-Privacy-Preserving Collaborative Machine Learning without Centralized Training Data”. Jan 2020. URL: http://jakubkonecny.com/files/2018-01_UW_Federated_Learning.pdf
[4] Jakub Konečný. “Federated Learning”. FL-IJCAI’19 ppts. Jan 2020. URL: http://fml2019.algorithmic-crowdsourcing.com/
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。