Links: official documentation, official GitHub page
Run the following command to install the required packages, PyTorch (torch and torchvision) and Flower (flwr):
pip install -q flwr[simulation] torch torchvision matplotlib
A federated learning system consists of a server and multiple clients. In Flower, a client is implemented by subclassing flwr.client.Client or flwr.client.NumPyClient.
Taking flwr.client.NumPyClient as an example, a subclass needs to implement three methods: get_parameters, fit, and evaluate.
Method | What it must do | Inputs | Outputs |
---|---|---|---|
get_parameters | Return the current local model parameters | self, config | List[np.ndarray] |
fit | Receive the global model parameters from the server, train them on local data, and return the updated parameters | self, parameters (NDArrays), config (Dict[str, Scalar]) | parameters (NDArrays), num_examples (int), metrics (Dict[str, Scalar]) |
evaluate | Receive the global model parameters from the server, evaluate them on local data, and return the result | self, parameters (NDArrays), config (Dict[str, Scalar]) | loss (float), num_examples (int), metrics (Dict[str, Scalar]) |
Example code:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
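The class above relies on helper functions (get_parameters, set_parameters, train, test) that are defined elsewhere in the tutorial. For reference, here is a minimal sketch of the two parameter helpers for a PyTorch model, modeled on the official Flower PyTorch tutorial; the exact implementation is an assumption:
from collections import OrderedDict
from typing import List
import numpy as np
import torch

def get_parameters(net) -> List[np.ndarray]:
    # Export all model weights as a list of NumPy arrays (the format NumPyClient expects)
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net, parameters: List[np.ndarray]):
    # Load a list of NumPy arrays back into the model's state_dict
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)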
When simulating many clients on a single machine, they have to share the CPU, GPU, memory, and other resources, which can quickly exhaust the available memory.
Flower therefore provides a dedicated simulation feature that creates FlowerClient instances only when they are actually needed for training or evaluation. To let Flower create clients on demand, we provide a function called client_fn. Whenever Flower needs a particular client instance to call fit or evaluate, it calls client_fn to create it; the instance is usually discarded after use, so it should not hold any local state. Clients are identified by a client ID (cid).
def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""
    # Load model
    net = Net().to(DEVICE)
    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    # Create a single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader)
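The trainloaders and valloaders lists used by client_fn are assumed to have been created beforehand by partitioning CIFAR-10 across NUM_CLIENTS clients. A minimal sketch of such a partitioning (the equal split, the 90/10 train/validation split, and the batch size are assumptions, not requirements):
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

def load_datasets(num_clients: int, batch_size: int = 32):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
    # Split the training set into num_clients partitions of equal size
    # (assumes len(trainset) is divisible by num_clients)
    partition_size = len(trainset) // num_clients
    datasets = random_split(trainset, [partition_size] * num_clients,
                            generator=torch.Generator().manual_seed(42))
    trainloaders, valloaders = [], []
    for ds in datasets:
        # Hold out 10% of each partition for local validation
        len_val = len(ds) // 10
        ds_train, ds_val = random_split(ds, [len(ds) - len_val, len_val],
                                        generator=torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    return trainloaders, valloaders

trainloaders, valloaders = load_datasets(NUM_CLIENTS)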
The flwr.server.strategy module encapsulates the federated learning methods/algorithms, such as FedAvg or FedAdagrad.
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)
Flower can evaluate the aggregated model either on the server side or on the client side, i.e., centralized evaluation or federated evaluation:
Centralized evaluation works like evaluation in centralized machine learning: if a server-side evaluation dataset is available, the aggregated model does not need to be sent to the clients for evaluation. In Flower, this is done by implementing an evaluate function and passing it to the strategy's evaluate_fn parameter.
Federated evaluation is more complex but also more powerful: it lets us evaluate the model on a much larger, distributed dataset and obtain more realistic results. This power comes at a cost: if the clients are not always available, the evaluation dataset changes from round to round, and the data held by each client may also change across rounds. In Flower, federated evaluation is implemented by giving FlowerClient an evaluate method.
# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    net = Net().to(DEVICE)
    valloader = valloaders[0]
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, valloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    evaluate_fn=evaluate,  # Pass the evaluation function
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
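On the federated side, each client's evaluate result is aggregated by the strategy; the losses are combined automatically, while custom metrics such as accuracy need an aggregation function. A minimal sketch, assuming the clients return an "accuracy" entry as in the examples above (the weighted_average helper is not part of the text above and is wired in via FedAvg's evaluate_metrics_aggregation_fn parameter):
from typing import List, Tuple
from flwr.common import Metrics

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Weight each client's accuracy by the number of examples it evaluated on
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# Passed to the strategy, e.g.:
# strategy = fl.server.strategy.FedAvg(..., evaluate_metrics_aggregation_fn=weighted_average)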
Sometimes we want the server to configure the clients' fit and evaluate calls, for example to tell them how many local epochs to train. The config parameter of fit (or evaluate) receives a configuration dictionary, and the client adjusts its local run by reading values from it. On the server side, write a function that takes the round number as input and pass it to the strategy's on_fit_config_fn and on_evaluate_config_fn parameters.
Using the configuration on the client side:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]
        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

The server-side function that provides the configuration:
def fit_config(server_round: int):
    """Return training configuration dict for each round.
    Perform two rounds of training with one local epoch, increase to two local
    epochs afterwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 2 else 2,  # One local epoch early on, two afterwards
    }
    return config
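To make the server actually send this configuration to the clients, fit_config is passed to the strategy's on_fit_config_fn (an analogous on_evaluate_config_fn exists for evaluate). A minimal sketch, reusing the FedAvg settings from the earlier example:
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(Net())),
    on_fit_config_fn=fit_config,  # Send the fit_config dict to clients each round
)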