当前位置:   article > 正文

联邦学习实战

联邦学习实战

使用pySyft进行一次简单的联邦平均FedAVG


  1. import torch
  2. from torch import nn
  3. from torch import optim
  4. import syft as sy
  5. # 扩展pytorch功能使其满足联邦学习训练
  6. hook = sy.TorchHook(torch)
  7. # 建立工作机和安全工作机,工作机作为客户端,用来训练模型
  8. # 安全工作机作为服务器,用于数据的聚合和交流
  9. Li = sy.VirtualWorker(hook, id='li')
  10. Zhang = sy.VirtualWorker(hook, id='zhang')
  11. secure_worker = sy.VirtualWorker(hook, id='secure_worker')
  12. data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
  13. target = torch.tensor([[0], [0], [0], [1.]], requires_grad=True)
  14. dataLi = data[0:2].send(Li)
  15. targetLi = target[0:2].send(Li)
  16. dataZhang = data[2:].send(Zhang)
  17. targetZhang = target[2:].send(Zhang)
  18. # 建立模型
  19. model = nn.Linear(2, 1)
  20. # 训练函数
  21. def train():
  22. # 设置迭代次数
  23. interations = 20
  24. workerInters = 5
  25. for inter in range(interations):
  26. # 将服务器上全局模型发给两个参与方
  27. LiModel = model.copy().send(Li)
  28. ZhangModel = model.copy().send(Zhang)
  29. liOpt = optim.SGD(params=LiModel.parameters(), lr=0.1)
  30. ZhangOpt = optim.SGD(params=ZhangModel.parameters(), lr=0.1)
  31. for wi in range(workerInters):
  32. # li训练一次
  33. # 消除之前的梯度
  34. liOpt.zero_grad()
  35. # 预测
  36. liPre = LiModel(dataLi)
  37. # 计算损失
  38. liLoss = ((liPre - targetLi) ** 2).sum()
  39. # 回传损失
  40. liLoss.backward()
  41. # 更新参数
  42. liOpt.step()
  43. liLoss = liLoss.get().data
  44. # Zhang训练一次
  45. ZhangOpt.zero_grad()
  46. ZhangPre = ZhangModel(dataZhang)
  47. ZhangLoss = ((ZhangPre - targetZhang) ** 2).sum()
  48. ZhangLoss.backward()
  49. ZhangOpt.step()
  50. ZhangLoss = ZhangLoss.get().data
  51. # 将更新的局部模型发送给安全工作机
  52. LiModel.move(secure_worker)
  53. ZhangModel.move(secure_worker)
  54. # 模型平均
  55. with torch.no_grad():
  56. model.weight.set_(((ZhangModel.weight.data + LiModel.weight.data) / 2).get())
  57. model.bias.set_(((ZhangModel.bias.data + LiModel.bias.data) / 2).get())
  58. print('第' + str(inter+1) + '轮')
  59. print('Li: ' + str(liLoss) + ' zhang: ' + str(ZhangLoss))
  60. pass
  61. pass
  62. # 开始训练
  63. train()
  64. # 用全局模型预测训练结果
  65. preSecure = model(data)
  66. loss = ((preSecure-target)**2).sum()
  67. print(target)
  68. print(preSecure)
  69. print(loss.data)

运行结果:

 第1轮
Li: tensor(6.9580e-05)    zhang: tensor(0.2173)
第2轮
Li: tensor(0.0061)    zhang: tensor(0.1492)
第3轮
Li: tensor(0.0168)    zhang: tensor(0.1112)
第4轮
Li: tensor(0.0268)    zhang: tensor(0.0893)
第5轮
Li: tensor(0.0347)    zhang: tensor(0.0762)
第6轮
Li: tensor(0.0406)    zhang: tensor(0.0681)
第7轮
Li: tensor(0.0448)    zhang: tensor(0.0631)
第8轮
Li: tensor(0.0477)    zhang: tensor(0.0599)
第9轮
Li: tensor(0.0498)    zhang: tensor(0.0579)
第10轮
Li: tensor(0.0512)    zhang: tensor(0.0566)
第11轮
Li: tensor(0.0522)    zhang: tensor(0.0557)
第12轮
Li: tensor(0.0529)    zhang: tensor(0.0552)
第13轮
Li: tensor(0.0534)    zhang: tensor(0.0548)
第14轮
Li: tensor(0.0537)    zhang: tensor(0.0546)
第15轮
Li: tensor(0.0540)    zhang: tensor(0.0544)
第16轮
Li: tensor(0.0542)    zhang: tensor(0.0544)
第17轮
Li: tensor(0.0543)    zhang: tensor(0.0543)
第18轮
Li: tensor(0.0545)    zhang: tensor(0.0543)
第19轮
Li: tensor(0.0546)    zhang: tensor(0.0542)
第20轮
Li: tensor(0.0546)    zhang: tensor(0.0542)
tensor([[0.],
        [0.],
        [0.],
        [1.]], requires_grad=True)
tensor([[-0.1793],
        [ 0.3207],
        [ 0.1649],
        [ 0.6649]], grad_fn=<AddmmBackward>)
tensor(0.2745)

Process finished with exit code 0

参考文献

  • 王健宗,李泽远,何安珣. 《深入浅出联邦学习:原理与实践》. 机械工业出版社. 2021年5月
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号