赞
踩
点击蓝字
关注我们
AI TIME欢迎每一位AI爱好者的加入!
本文介绍的是我们的一篇收录于 AAAI 2024 的论文,主要考虑的是数据异质和模型异构场景下的联邦学习框架。在异构联邦学习中,由于模型架构不同,传统联邦学习中的参数聚合方法不再适用,取而代之的是基于知识蒸馏的知识共享方法。
在这些方法中,我们关注不引入额外数据集的(data-free)这一类方法。这类方法普遍通过共享类别表征向量(prototype)实现,但在模型架构差异较大的场景,每个客户机生成的表征向量差异悬殊,直接在服务器端聚合表征向量会造成表征能力的下降。于是,我们提出一种在服务器端基于自适应间距强化的对比学习来提高表征向量的表征能力的方法 FedTGP,进一步提升客户端模型的分类能力。
论文标题:
FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning
论文链接:
https://arxiv.org/abs/2401.03230
代码链接:
https://github.com/TsingZ0/FedTGP(含有PPT和Poster)
运行实验所需仓库-个性化联邦学习算法库:
https://github.com/TsingZ0/PFLlib
运行实验所需仓库-异构联邦学习算法库:
https://github.com/TsingZ0/HtFLlib
异构联邦学习背景
传统联邦学习通过在每一次迭代中传递模型参数的方式实现知识共享,但该方式存在局限,无法适应更广泛的场景,尤其是不易寻找到参与联邦学习的客户机。客户机在参与联邦学习之前,有自己本地的模型训练任务,也有自研的模型架构和训练得到的模型参数。每个客户机参加联邦学习的动机是为了通过联邦学习增强自己模型的表现能力。若强制要求参与的客户机都使用相同的模型结构且进行模型参数共享,则需要每个客户机重新训练模型。
另一方面,每个客户机训练得到的模型参数也是一种数字资产,尤其是在大模型时代保护模型参数的知识产权尤为重要。此外,共享模型参数也有通讯量大的问题。通过允许异构模型参与联邦学习,并共享轻量化的知识载体,异构联邦学习拓展了传统联邦学习的边界,变得更加实用。
▲ 图1:异构联邦学习技术
目前异构联邦学习技术还未形成统一的知识共享机制,我们考虑一种轻量化且不需要额外数据的知识共享机制:共享 prototype。本文考虑的是面向图像的多分类任务,其 prototype 的定义就是每个类别的代表性特征向量,可通过平均该类所有的特征向量获得。现有工作中,FedProto [1] 是这方面最具代表性的方法之一,如下图所示。
▲ 图2:异构联邦学习中使用prototype作为知识载体
FedProto的局限性
虽然 FedProto 得到了广泛使用,但之前的工作要么将其用在传统联邦学习场景(异构联邦学习技术在传统场景也都适用),要么采用异构性不强的异构模型(比如增减全连接层数和改变 CNN 网络的卷积核等)。在这些场景下,通过加权平均聚合 prototype 的方式确实具有不错的表现。
但当我们考虑更一般的场景:参与联邦学习的客户机训练的模型的架构差异巨大,比如两层 CNN 模型和 ResNet-152 模型。此时 FedProto 的 prototype 聚合方法就出现了一些问题。我们观察到,由于模型架构相差巨大,不同模型的特征提取能力也天差地别,它们生成的 prototype 也天差地别。
当我们通过加权平均去计算全局 prototype(global prototype)时,具有较好表征能力(不同 prototype 之间的间距(margin)较大)的 prototype 会被较差表征能力的 prototype 影响,导致最终得到的 global prototype 表征能力弱于最好的客户机模型。我们称这种现象为间距收缩(margin shrink),如下图所示。进一步地,当这个特征提取能力最好的客户机模型使用了 global prototype 之后,其表征能力则会下降。
▲ 图3:FedProto在模型异构性较大场景下的间距收缩现象(Cifar10)
自适应间距强化的对比学习(ACL)
为了解决上述间距收缩的问题,我们提出了一种自适应间距强化的对比学习方法(ACL),如下图所示。
▲ 图4:FedProto与FedTGP的对比。其中圆形代表客户机上传的prototype,三角形代表global prototype。
该方法的核心思想是训练一个 global prototype,使其能够最大限度地保留最强客户机模型生成的 prototype 的表征能力,同时也汲取来自其他客户机的 prototype 信息。为了实现这一点,我们首先给传统对比学习方法加上一个间距限制,即尽可能保证 prototype 之间的间距不低于所设置的阈值 。考虑类别 对应的 trainable global prototype(TGP),我们定义其训练时候的损失函数为:
其中, 是在第 轮参与联邦学习的客户机集合, 是客户机 上生成的对应类别 的 prototype, 是间距计算函数。
但在联邦学习的过程中,各个客户机模型的特征提取能力不断变化,若设置一个固定的阈值,则会导致间距过大或过小。于是我们考虑将 设置为一个自适应的值,其计算细节如下,其描述的就是每一轮不同类别之间的最大间距,且具有最大值 。
从而我们得到最终的对比学习目标:
使用 ACL 之后,我们便可以消除间距收缩的问题:
▲ 图5:我们的FedTGP在使用ACL之后,消除了间距收缩的问题(Cifar10)
部分实验
由于篇幅原因,我们只展示部分实验结果,更多实验结果和分析详见论文。
▲ 表1:在4个数据集和8种异构模型场景下的测试准确率
▲ 表2:在Cifar100数据集和不同模型异构级别情况下的测试准确率
参考文献
[1] Tan Y, Long G, Liu L, et al. Fedproto: Federated prototype learning across heterogeneous clients[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2022.
往期精彩文章推荐
青年之声,未来之问 | WAIC 2024“未知边界”大模型青年学者说论坛圆满落幕
记得关注我们呀!每天都有新知识!
关于AI TIME
AI TIME源起于2019年,旨在发扬科学思辨精神,邀请各界人士对人工智能理论、算法和场景应用的本质问题进行探索,加强思想碰撞,链接全球AI学者、行业专家和爱好者,希望以辩论的形式,探讨人工智能和人类未来之间的矛盾,探索人工智能领域的未来。
迄今为止,AI TIME已经邀请了1800多位海内外讲者,举办了逾600场活动,超700万人次观看。
我知道你
在看
提出观点,表达想法,欢迎
留言
点击 阅读原文 查看更多!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。