赞
踩
题目: 更快更好的联邦学习:一种特征融合方法
会议: IEEE ICIP 2019
论文地址:https://ieeexplore.ieee.org/abstract/document/8803001
本文将解读清华大学孙立峰教授团队在2019 IEEE International Conference on Image Processing (ICIP)上发表的论文《Towards Faster and Better Federated Learning: A Feature Fusion Approach》。该论文提出了一种特征融合方法来减少联邦学习中通讯的成本,并提升了模型性能:通过聚合来自本地和全局模型的特征,以更少的通信成本实现了更高的精度。此外,特征融合模块为新来的客户端提供更好的初始化,从而加快收敛过程。
联邦学习能够在由大量现代智能设备(如智能手机和物联网设备)组成的分布式网络上进行模型训练。然而,FedAvg算法通常需要很大的通信成本,并且性能也是一个很大的挑战,特别是当本地数据以非IID方式分布时。
因此,本文提出了一种特殊的特征融合机制来解决上述问题:通过聚合来自本地和全局模型的特征,以更少的通信成本实现了更高的精度。此外,特征融合模块为新来的客户端提供更好的初始化,从而加快收敛过程。
为了充分利用设备上的数据,传统的机器学习策略需要从客户端收集数据,然后在服务器上集中训练模型,然后将模型分发给客户端,这给通信网络带来了沉重的负担并且暴露于高隐私风险(所有客户端需要暴露自己的数据)。
2016年,谷歌提出了联邦学习(Federated Learning)的概念,并首次提出了FedAvg算法,它使用本地数据对客户端执行分布式培训,并将这些模型汇总到中央服务器中以避免数据共享。 通过这种方式,减轻了隐私暴露问题。然而,进一步的研究指出,与其他因素相比,通信成本仍然是FL的主要制约因素,例如计算成本,如果模型接受非IID数据训练,则FedAvg的准确性将显着下降。
在本文中,提出了一种新的具有特征融合机制(FedFusion)的FL算法来解决上述问题。通过引入特征融合模块,在特征提取阶段之后聚合来自局部和全局模型的特征,而几乎没有额外的计算成本。这些模块使每个客户端的训练过程更加高效,并且更有针对性地处理非IID数据,因为每个客户端将为自己学习最合适的特征融合模块。
本文贡献:
考虑到通信成本是限制FL的主要因素,目前已经有一些学者做了相关的研究工作。比如Konecny等人在客户端到服务器通信的背景下提出了结构化和草图更新;Yao等人对设备上的培训程序引入了额外的限制,旨在拟合本地数据的同时整合来自其他客户的更多知识;Caldas等人提出federated dropout来训练客户端的子集,并将有损压缩扩展到服务器到客户端的通信。
在本节中,首先介绍所提出的特征融合模块,然后给出具有特征融合机制(FedFusion)的联邦学习算法。
如下图所示:
其中蓝色的部分表示local模型提取的两通道特征,灰色部分表示global模型提取到的两通道特征。图1给出了三种特征融合方式:Conv, Multi和Single。特征的提取在CNN中可以理解为经过卷积和池化操作后得到的图片信息。
每一个输入的图像
x
x
x都会分别被局部特征提取器
E
l
E_l
El和全局特征提取器
E
g
Eg
Eg映射到
R
C
×
H
×
W
R^{C\times H\times W}
RC×H×W。
随后,特征融合算子 F F F将两个特征提取器提取到的特征进行融合: F ( E l ( x ) , E g ( x ) ) F(E_l(x),E_g(x)) F(El(x),Eg(x)),两个特征融合后被映射到 R C × H × W R^{C\times H\times W} RC×H×W。
F
c
o
n
v
(
E
l
(
x
)
,
E
g
(
x
)
)
=
W
c
o
n
v
(
E
g
(
x
)
∥
E
l
(
x
)
)
F_{c o n v}\left(E_{l}(x), E_{g}(x)\right)=W_{c o n v}\left(E_{g}(x) \| E_{l}(x)\right)
Fconv(El(x),Eg(x))=Wconv(Eg(x)∥El(x))
其中
W
c
o
n
v
W_{c o n v}
Wconv表示shape为
2
C
×
C
2C\times C
2C×C的可学习的权重矩阵。具体操作就是将global特征和local特征进行concat(||)后进行卷积操作。
关于卷积中通道C、高度H以及宽度W的解释可见:DL入门(1):卷积神经网络(CNN)
Multi算子:用一个
λ
\lambda
λ权重矩阵来对local和global进行一个加权求和。
Single算子:用一个标量
λ
\lambda
λ来对local和global进行一个加权求和。
经过上述操作后,global特征提取器提取到的特征和local特征提取器提取到的特征将融合成为一个新的特征,特征shape为 R C × H × W R^{C\times H\times W} RC×H×W。
本节讲述带有特征融合机制的联邦学习策略!
本文所提出的FedFusion的典型训练迭代如下图所示:
具体来讲:
客户端在第 i i i轮训练时,将会保留服务器发来的全局的特征提取器 E g E_g Eg,在本地分类器进行迭代更新时,会考虑将 E g E_g Eg和 E l E_l El进行融合。
在训练期间, E g E_g Eg被冻结并且引入了3.1中描述的附加特征融合模块。
在客户端上进行训练后,将与特征融合模块结合的本地模型发送到中央服务器进行模型聚合,这里使用指数移动平均策略来平滑更新。
算法伪代码:
对中央服务器:
对客户端t的第r轮训练来说:
在实验中使用MNIST和CIFAR10作为基本数据集。
对于MNIST数字识别任务,使用与FedAvg相同的模型:具有两个5×5卷积层的CNN(第一个具有32个通道,第二个具有64个通道,每个之后是ReLU激活和2×2最大池化),512个节点的完全连接层(ReLU+Random Dropout),softmax输出层。
对于CIFAR10,使用具有两个5×5卷积层的CNN(均具有64个通道,每个通道之后是ReLU激活和3×3最大池化,stride为2),两个完全连接层(第一个具有384个节点,第二个具有192个节点,每个之后是ReLU+Random Dropout)和最终的softmax输出层。
数据分割方式:
a和b表述了在人工形成的非IID场景下, FedFusion和FedAvg的收敛图。可以看到,在相同的通讯轮数下,不进行特征融合,也就是FedAvg的表现是最差的,其精度最低。
(图有些看不清),具体的数据如下表所示:
可以看到进行特征融合后(无论哪一种特征融合),模型的精度都有所提升。
Multi融合方式的效果最好,Conv融合方式次之。
为了模拟用户特定的非IID分区,对每个客户端的MNIST应用不同的排列,这就是之前几项研究中所谓的置换MNIST。
表2列出了达到某些精度(此处为94%和95%)的通信轮数以及通信轮数相对于FedAvg的减少:
从上表可以看出,FedFusion+Conv实现了通讯轮数最大幅度的降低。
值得注意的是,用户特定的“非IID分区更接近现实的FL场景,因此在这种情况下改进更有意义。
如下图所示:
在IID场景下,使用Multi和Conv进行融合可以以较低的通信成本实现更高的精度。
对特征融合算子做出如下简要概括:
联邦学习巨大的通讯成本是一个需要解决的紧急问题。 在本文中,尝试从减少沟通轮次的角度进行一些改进:提出了一种新的具有特征融合模块的FL算法,并在当前较为流行的FL设置中对其进行评估。实验结果表明,该方法具有较高的精度,同时将通信轮次减少了60%以上。
未来的工作可能包括将目前的算法扩展到更复杂的模型和场景,以及将通信轮次减少策略与其他类型的方法(例如梯度估计和压缩)相结合。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。