当前位置:   article > 正文

联邦学习算法FedAvg实现(PyTorch)_pytorch fedavg

pytorch fedavg

联邦学习方法FedAvg实现(PyTorch

通俗来讲,联邦学习(Federated Learning)结构由Server和若干Client组成,在联邦学习方法过程中,没有任何用户数据被发送到Server端,通过这种方式保护了用户的数据隐私。另外,通信中传输的参数是特定于改进当前模型的,因此一旦应用了他们,Server就没有理由存储它们,这进一步提高了安全性。

联邦学习的整体思路是“数据不动 模型动”。Server提供全局共享的模型,Client下载模型并训练自己的数据集,同时更新模型参数。在Server和Client的每一次通信中,Server将当前的模型参数分发给各个Client(或者说Client下载服务端的模型参数),经过Client的训练之后,将更新后的模型参数返回给Server,Server通过某种方法将聚合得到的N个模型参数融合成一个作为更新后的Server模型参数。以此循环。

本文实战联邦学习算法,梳理其方法流程,完成pytorch的代码实现。

数据集划分

首先我们要为每个客户端分配数据,在实际场景中,每个客户端有自己独有的数据,这里为了模拟场景,手动划分数据集给每个客户端。

客户端之间的数据可能是独立分布IID,也可能是非独立同分布Non_IID的

以Minist数据集为例——0~9的手写数字数据集,独立分布IID的意思是每个客户端都拥有0~9的完整数据集,而非独立同分布Non_IID就是每个客户端可能只拥有一部分数据集,比如说一个只拥有0、1的数据集,一个只拥有2、3的数据集

可以参考提升联邦学习的效率和效果 - 雪琪的文章 - 知乎 https://zhuanlan.zhihu.com/p/108163485

def cifar_iid(dataset, num_users):

    num_items = int(len(dataset)/num_users)
    # num_items = 500 # 测试时使用

    #print(num_items)
    dict_users, all_idxs = {
   }, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    print('cifar_iid is ok!')
    return dict_users

def cifar_noniid(dataset, num_users):
  
    num_items = int(len(dataset) / num_users) # 每个节点的图片总数
    # print(num_items)
    num_labels = 2  # 每个节点只包含两类图片
    num_pics = (int)(num_items / num_labels) # 每个节点每类所包含的图片总数
    # print(num_pics)
    dict_users, idxs, per_labal_idxs = {
   }, [i for i in range(10)], {
   }

    for i in range(10):
        per_labal_idxs[i] = [i for i in range(i * 5000,(i+1) * 5000)]

    for i in range(num_users):
        # print(idxs)
        random_labels = np.random.choice(idxs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/190592?site
推荐阅读
相关标签
  

闽ICP备14008679号