链接:官方文档、官方 GitHub 主页
在命令行输入如下命令,安装所需的包:PyTorch(torch 和 torchvision)、Flower(flwr)以及 matplotlib:
pip install -q flwr[simulation] torch torchvision matplotlib
以 flwr.client.NumPyClient 为例,其子类需要实现三个方法:get_parameters、fit 和 evaluate。
| 函数名 | 需要实现的功能 | 输入 | 输出 |
| --- | --- | --- | --- |
| get_parameters | 返回当前客户端的模型参数 | self, config | List[np.ndarray] |
| fit | 从服务器接收模型参数并在本地训练,返回更新后的模型参数 | self, parameters (NDArrays), config (Dict[str, Scalar]) | parameters (NDArrays), num_examples (int), metrics (Dict[str, Scalar]) |
| evaluate | 从服务器接收模型参数,用本地数据进行评估,返回评估结果 | self, parameters (NDArrays), config (Dict[str, Scalar]) | loss (float), num_examples (int), metrics (Dict[str, Scalar]) |
class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates one model on its local data.

    Relies on the helpers ``get_parameters``, ``set_parameters``, ``train``
    and ``test``, which are defined elsewhere.
    """

    def __init__(self, net, trainloader, valloader):
        """Keep the model and this client's training/validation loaders.

        NOTE(review): the loaders are assumed to be
        ``torch.utils.data.DataLoader`` objects (they must expose
        ``.dataset``) -- confirm against where they are built.
        """
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the local model's parameters; ``config`` is not used."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the received parameters, train for one epoch, return the result.

        Returns:
            The updated parameters, the number of local training examples,
            and an empty metrics dict.
        """
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # len(DataLoader) is the number of batches; num_examples has to be
        # the number of samples, so count the underlying dataset instead.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the received parameters and evaluate them on the local data.

        Returns:
            The loss as a float, the number of local validation examples,
            and a metrics dict holding ``"accuracy"``.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Same reasoning as in fit(): report samples, not batches.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}
当在一台机器上仿真多个客户端时,它们需要共享 CPU、GPU、内存等资源;如果同时把所有客户端都创建出来,可用内存可能很快被耗尽。因此 Flower 的仿真通过 client_fn 函数按需创建客户端:该函数根据客户端编号 cid 构造并返回对应的 FlowerClient 实例。
def client_fn(cid: str) -> FlowerClient:
    """Build the Flower client that represents the organization ``cid``."""
    # A fresh model instance for this client, moved to the configured device.
    model = Net().to(DEVICE)

    # The client id doubles as the index of the data partition (CIFAR-10):
    # every client gets its own trainloader/valloader, so each one trains
    # and evaluates on its own unique data.
    partition = int(cid)
    return FlowerClient(model, trainloaders[partition], valloaders[partition])
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU):
# request one GPU per client only when running on CUDA.
client_resources = {"num_gpus": 1} if DEVICE.type == "cuda" else None

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_resources=client_resources,
)
Flower 既可以在服务器端、也可以在客户端评估聚合后的模型,分别称为集中评估(Centralized Evaluation)和联合评估(Federated Evaluation)。下面是在服务器端进行集中评估的示例:
# Flower will call the `evaluate` function after every round
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Evaluate the given global parameters on the server side.

    ``server_round`` and ``config`` are part of the signature but are not
    read here.

    Returns:
        The loss together with a metrics dict holding ``"accuracy"``.
    """
    model = Net().to(DEVICE)
    # The first validation loader serves as the server-side evaluation set.
    loader = valloaders[0]
    # Update model with the latest parameters before measuring it.
    set_parameters(model, parameters)
    loss, accuracy = test(model, loader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    metrics = {"accuracy": accuracy}
    return loss, metrics
# FedAvg with server-side (centralized) evaluation.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_evaluate=0.3,
    min_fit_clients=3,
    min_evaluate_clients=3,
    min_available_clients=NUM_CLIENTS,
    # Initial global parameters come from a freshly constructed Net.
    initial_parameters=fl.common.ndarrays_to_parameters(
        get_parameters(Net())
    ),
    evaluate_fn=evaluate,  # Pass the evaluation function
)

# NOTE: `client_resources` is not assigned in this snippet; it has to exist
# already when this runs.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),  # Just three rounds
    strategy=strategy,
    client_resources=client_resources,
)
有时,我们希望服务器端能对客户端的 fit、evaluate 过程进行配置,例如要求客户端在本地训练指定数量的 epoch。客户端可以通过 fit(或 evaluate)函数的 config 参数接收服务器下发的配置字典,并读取其中的值来调整本地的运行。在服务器端,需要编写一个以当前轮数(round)为输入、返回配置字典的函数,并把该函数传给策略的 on_fit_config_fn 及 on_evaluate_config_fn 参数。
class FlowerClient(fl.client.NumPyClient):
    """Flower client whose local training is configured by the server.

    Relies on the helpers ``get_parameters``, ``set_parameters``, ``train``
    and ``test``, which are defined elsewhere.
    """

    def __init__(self, cid, net, trainloader, valloader):
        """Keep the client id, the model and this client's data loaders.

        NOTE(review): the loaders are assumed to be
        ``torch.utils.data.DataLoader`` objects (they must expose
        ``.dataset``) -- confirm against where they are built.
        """
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Log the call and return the local model's parameters."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Train locally for the number of epochs requested in ``config``.

        Reads ``"server_round"`` (used for logging only) and
        ``"local_epochs"`` from ``config``.

        Returns:
            The updated parameters, the number of local training examples,
            and an empty metrics dict.
        """
        # Read values from config. The dict is empty when the strategy has
        # no on_fit_config_fn, so fall back to a single local epoch rather
        # than failing with KeyError.
        server_round = config.get("server_round")
        local_epochs = config.get("local_epochs", 1)

        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        # len(DataLoader) is the number of batches; num_examples has to be
        # the number of samples, so count the underlying dataset instead.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the received parameters and evaluate them on the local data.

        Returns:
            The loss as a float, the number of local validation examples,
            and a metrics dict holding ``"accuracy"``.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Same reasoning as in fit(): report samples, not batches.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    """Build the client for ``cid`` from a fresh model and its own loaders."""
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)
def fit_config(server_round: int) -> dict:
    """Return training configuration dict for each round.

    Perform two rounds of training with one local epoch, increase to two local
    epochs afterwards.

    Args:
        server_round: The current round of federated learning (rounds are
            numbered starting at 1).

    Returns:
        A dict with the keys ``"server_round"`` and ``"local_epochs"``.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        # Rounds start at 1, so `<= 2` is what keeps BOTH of the first two
        # rounds at one local epoch (`< 2` would cover round 1 only).
        "local_epochs": 1 if server_round <= 2 else 2,
    }
    return config
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。