赞
踩
通俗来讲,联邦学习(Federated Learning)结构由Server和若干Client组成,在联邦学习方法过程中,没有任何用户数据被发送到Server端,通过这种方式保护了用户的数据隐私。另外,通信中传输的参数是特定于改进当前模型的,因此一旦应用了他们,Server就没有理由存储它们,这进一步提高了安全性。
联邦学习的整体思路是“数据不动 模型动”。Server提供全局共享的模型,Client下载模型并训练自己的数据集,同时更新模型参数。在Server和Client的每一次通信中,Server将当前的模型参数分发给各个Client(或者说Client下载服务端的模型参数),经过Client的训练之后,将更新后的模型参数返回给Server,Server通过某种方法将聚合得到的N个模型参数融合成一个作为更新后的Server模型参数。以此循环。
本文实战联邦学习算法,梳理其方法流程,完成pytorch的代码实现。
首先我们要为每个客户端分配数据,在实际场景中,每个客户端有自己独有的数据,这里为了模拟场景,手动划分数据集给每个客户端。
客户端之间的数据可能是独立分布IID,也可能是非独立同分布Non_IID的
以Minist数据集为例——0~9的手写数字数据集,独立分布IID的意思是每个客户端都拥有0~9的完整数据集,而非独立同分布Non_IID就是每个客户端可能只拥有一部分数据集,比如说一个只拥有0、1的数据集,一个只拥有2、3的数据集
可以参考提升联邦学习的效率和效果 - 雪琪的文章 - 知乎 https://zhuanlan.zhihu.com/p/108163485
def cifar_iid(dataset, num_users): num_items = int(len(dataset)/num_users) # num_items = 500 # 测试时使用 #print(num_items) dict_users, all_idxs = { }, [i for i in range(len(dataset))] for i in range(num_users): dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) all_idxs = list(set(all_idxs) - dict_users[i]) print('cifar_iid is ok!') return dict_users def cifar_noniid(dataset, num_users): num_items = int(len(dataset) / num_users) # 每个节点的图片总数 # print(num_items) num_labels = 2 # 每个节点只包含两类图片 num_pics = (int)(num_items / num_labels) # 每个节点每类所包含的图片总数 # print(num_pics) dict_users, idxs, per_labal_idxs = { }, [i for i in range(10)], { } for i in range(10): per_labal_idxs[i] = [i for i in range(i * 5000,(i+1) * 5000)] for i in range(num_users): # print(idxs) random_labels = np.random.choice(idxs
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。