赞
踩
使用pySyft进行一次简单的联邦平均FedAVG
- import torch
- from torch import nn
- from torch import optim
- import syft as sy
-
- # 扩展pytorch功能使其满足联邦学习训练
- hook = sy.TorchHook(torch)
- # 建立工作机和安全工作机,工作机作为客户端,用来训练模型
- # 安全工作机作为服务器,用于数据的聚合和交流
- Li = sy.VirtualWorker(hook, id='li')
- Zhang = sy.VirtualWorker(hook, id='zhang')
- secure_worker = sy.VirtualWorker(hook, id='secure_worker')
- data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
- target = torch.tensor([[0], [0], [0], [1.]], requires_grad=True)
- dataLi = data[0:2].send(Li)
- targetLi = target[0:2].send(Li)
- dataZhang = data[2:].send(Zhang)
- targetZhang = target[2:].send(Zhang)
- # 建立模型
- model = nn.Linear(2, 1)
-
-
- # 训练函数
- def train():
- # 设置迭代次数
- interations = 20
- workerInters = 5
- for inter in range(interations):
- # 将服务器上全局模型发给两个参与方
- LiModel = model.copy().send(Li)
- ZhangModel = model.copy().send(Zhang)
- liOpt = optim.SGD(params=LiModel.parameters(), lr=0.1)
- ZhangOpt = optim.SGD(params=ZhangModel.parameters(), lr=0.1)
- for wi in range(workerInters):
- # li训练一次
- # 消除之前的梯度
- liOpt.zero_grad()
- # 预测
- liPre = LiModel(dataLi)
- # 计算损失
- liLoss = ((liPre - targetLi) ** 2).sum()
- # 回传损失
- liLoss.backward()
- # 更新参数
- liOpt.step()
- liLoss = liLoss.get().data
- # Zhang训练一次
- ZhangOpt.zero_grad()
- ZhangPre = ZhangModel(dataZhang)
- ZhangLoss = ((ZhangPre - targetZhang) ** 2).sum()
- ZhangLoss.backward()
- ZhangOpt.step()
- ZhangLoss = ZhangLoss.get().data
-
- # 将更新的局部模型发送给安全工作机
- LiModel.move(secure_worker)
- ZhangModel.move(secure_worker)
-
- # 模型平均
- with torch.no_grad():
- model.weight.set_(((ZhangModel.weight.data + LiModel.weight.data) / 2).get())
- model.bias.set_(((ZhangModel.bias.data + LiModel.bias.data) / 2).get())
- print('第' + str(inter+1) + '轮')
- print('Li: ' + str(liLoss) + ' zhang: ' + str(ZhangLoss))
- pass
-
-
- pass
-
-
- # 开始训练
- train()
- # 用全局模型预测训练结果
- preSecure = model(data)
- loss = ((preSecure-target)**2).sum()
- print(target)
- print(preSecure)
- print(loss.data)

运行结果:
第1轮
Li: tensor(6.9580e-05) zhang: tensor(0.2173)
第2轮
Li: tensor(0.0061) zhang: tensor(0.1492)
第3轮
Li: tensor(0.0168) zhang: tensor(0.1112)
第4轮
Li: tensor(0.0268) zhang: tensor(0.0893)
第5轮
Li: tensor(0.0347) zhang: tensor(0.0762)
第6轮
Li: tensor(0.0406) zhang: tensor(0.0681)
第7轮
Li: tensor(0.0448) zhang: tensor(0.0631)
第8轮
Li: tensor(0.0477) zhang: tensor(0.0599)
第9轮
Li: tensor(0.0498) zhang: tensor(0.0579)
第10轮
Li: tensor(0.0512) zhang: tensor(0.0566)
第11轮
Li: tensor(0.0522) zhang: tensor(0.0557)
第12轮
Li: tensor(0.0529) zhang: tensor(0.0552)
第13轮
Li: tensor(0.0534) zhang: tensor(0.0548)
第14轮
Li: tensor(0.0537) zhang: tensor(0.0546)
第15轮
Li: tensor(0.0540) zhang: tensor(0.0544)
第16轮
Li: tensor(0.0542) zhang: tensor(0.0544)
第17轮
Li: tensor(0.0543) zhang: tensor(0.0543)
第18轮
Li: tensor(0.0545) zhang: tensor(0.0543)
第19轮
Li: tensor(0.0546) zhang: tensor(0.0542)
第20轮
Li: tensor(0.0546) zhang: tensor(0.0542)
tensor([[0.],
[0.],
[0.],
[1.]], requires_grad=True)
tensor([[-0.1793],
[ 0.3207],
[ 0.1649],
[ 0.6649]], grad_fn=<AddmmBackward>)
tensor(0.2745)Process finished with exit code 0
参考文献
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。