当前位置:   article > 正文

【联邦学习实战】

联邦学习实战

模组分类

整体架构

主要分为主函数,服务端模块和客户端模块。
在客户端模块中
->训练本地模型
在服务端模块
->使用FedAvg聚合函数更新全局模型
->定义模型评估函数
在主函数中
->获取Json数据中的配置信息,
->初始化服务端模块和客户端模块
->进行全局模型训练:->初始化空模型参数weight_accumulator->遍历客户端,每个客户端本地训练模型并根据客户端的参数差值字典更新总体权重->进行模型参数聚合

主函数模块

import argparse, json
import datetime
import os
import logging
import torch, random




from server import *
from client import *
import models, datasets

if __name__ == '__main__':

    # 设置命令行程序
    parser = argparse.ArgumentParser(description='Federated Learning')
    parser.add_argument('-c', '--conf', dest='conf')
    # 获取所有的参数
    args = parser.parse_args()

    conf_path="./utils/conf.json"


    # # 读取配置文件
    # with open(args.conf, 'r') as f:
    #     conf = json.load(f)

    # 读取配置文件
    with open(conf_path, 'r') as f:
        conf = json.load(f)

    # 获取数据集, 加载描述信息
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])   # "type" : "cifar",

    # 开启服务器
    server = Server(conf, eval_datasets)
    # 客户端列表
    clients = []

    # 添加10个客户端到列表
    for c in range(conf["no_models"]):      # "no_models" : 10,
        clients.append(Client(conf, server.global_model, train_datasets, c))

    print("\n\n")

    # 全局模型训练
    for e in range(conf["global_epochs"]):                # "global_epochs" : 20,
        # 每次训练都是从clients列表中随机采样k个进行本轮训练
        candidates = random.sample(clients, conf["k"])    # "k" : 5,

        # 权重累计
        weight_accumulator = {}

        # 初始化空模型参数weight_accumulator
        for name, params in server.global_model.state_dict().items():
            # 生成一个和参数矩阵大小相同的0矩阵
            weight_accumulator[name] = torch.zeros_like(params)

        # 遍历客户端,每个客户端本地训练模型
        for c in candidates:
            diff = c.local_train(server.global_model)

            # 根据客户端的参数差值字典更新总体权重
            for name, params in server.global_model.state_dict().items():
                weight_accumulator[name].add_(diff[name])

        # 模型参数聚合
        server.model_aggregate(weight_accumulator)

        # 模型评估
        acc, loss = server.model_eval()

        print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))








  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82

服务端模块

import torch
import models
"""
服务端类Server的工作:
1.将配合信息拷贝到服务端中
2.按照配置中的模型信息获取模型,
下载模型后将其作为全局初始模型
"""

class Server(object):

    def __init__(self, conf, eval_dataset):
        self.conf = conf
        self.global_model = models.get_model(self.conf["model_name"])                     # "model_name" : "resnet18",
        self.eval_loader=torch.utils.data.DataLoader(eval_dataset,
                                                     batch_size=self.conf["batch_size"],  # 单次用以训练的数据(样本)个数, "batch_size" : 32,
                                                     shuffle=True)                        # 用于打乱数据集,每次都会以不同的顺序返回

    # 使用FedAvg聚合函数更新全局模型
    def model_aggregate(self, weight_accumulator):                         # weight_accumulator 存储每一个客户端的上传参数变化值
        for name, data in self.global_model.state_dict().items():          # state_dict作为python的字典对象将每一层的参数映射成tensor张量
            update_per_layer=weight_accumulator[name]*self.conf["lambda"]  # "lambda" : 0.1 Lambda取值方式为:1/客户端总数
            if data.type() != update_per_layer.type():
                data.add_(update_per_layer.to(torch.int64))                # 返回指定类型的张量
            else:
                data.add_(update_per_layer)

    """
    服务端的模型评估函数对当前聚合后的全局模型进行分析,判断模型训练需要下一轮迭代、还是提前终止、或者模型
    是否出现发散退化的现象。根据不同的结果,服务端可以采取不同的措施。
    """
    # 定义模型评估函数
    def model_eval(self):
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        for batch_id, batch in enumerate(self.eval_loader):   # 将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标
            data, target = batch
            dataset_size += data.size()[0]
            if torch.cuda.is_available():                                           # 使数据在GPU上计算
                data = data.cuda()
                target = target.cuda()
            output = self.global_model(data)
            total_loss += torch.nn.functional.cross_entropy(output, target,
                                                            reduction='sum'.item())  # 把损失值聚合起来
            pred = output.data.max(1)[1]                                             # 获取最大的对数概率的索引值
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        acc = 100.0 * (float(correct)/float(dataset_size))                           # 计算准确率
        total_l = total_loss / dataset_size                                          # 计算损失值
        return acc, total_l





  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56

客户端模块

在这里插入图片描述

import models, torch, copy

"""
客户端主要功能:接收服务端的下发指令和全局模型,并利用本地数据进行局部模型训练
客户端工作包括:
1.将配置信息拷贝到客户端中
2.按照配置中的模型信息获取模型,
通常有服务端将模型参数传递给客户端,客户端用全局模型覆盖本地模型
3.配置本地训练数据,不同客户端使用不同子数据集,互相之间无交集
"""

class Client(object):
    def __init__(self, conf, model, train_dataset, id=1):
        self.conf = conf                          # 配置文件
        self.local_model = model                  # 客户端本地模型
        self.client_id = id                       # 客户端ID
        self.train_dataset = train_dataset        # 客户端本地数据集
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset) / self.conf['no_models'])       # "no_models" : 10,
        indices = all_range[id * data_len : (id+1) * data_len]
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset,
                                                        batch_size=conf["batch_size"],
                                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))

    """
    模型本地训练函数:作为一个图像分类例子,使用交叉熵作为本地模型的损失函数,
    利用梯度下降来求解并更新参数值,实现如下
    """
    def local_train(self, model):
        for name, param in model.state_dict().items():                # 客户端用服务器下发的全局模型覆盖本地模型
            self.local_model.state_dict()[name].copy_(param.clone())  # 当多个设备或进程上的模型需要进行同步时,可以使用这种方式将参数从主模型复制到本地模型中
        optimizer = torch.optim.SGD(self.local_model.parameters(),    # 定义最优化函数器,用于本地模型训练
                                    lr=self.conf['lr'],
                                    momentum=self.conf['momentum'])   # 避免损失函数在训练的过程中出现局部最小值的情况,而没有达到全局最优的状态。
        # 本地训练模型
        self.local_model.train()
        for e in range(self.conf["global_epochs"]):
            for batch_id, batch in enumerate(self.train_loader):
                data, target = batch
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                optimizer.zero_grad()                                         # 清空过往梯度;
                output = self.local_model(data)
                loss = torch.nn.functional.cross_entropy(output,target)
                loss.backward()                                               # 反向传播,计算当前梯度;
                optimizer.step()                                              # 根据梯度更新网络参数
            print("Epoch %d done." % e)
        diff = dict()                                                         # 创建一个字典
        for name, data in self.local_model.state_dict.items():
            diff[name] = (data - model.state_dict()[name])
        return diff


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

数据集模块


from torchvision import datasets, transforms

# Transforms是指常见的图像变换功能

#用到Pytorch框架中的一些方法
def get_dataset(dir,name):
    if name=='mnist':
        train_dataset=datasets.MNIST(dir, train=True, download=True,   # dir-->root: str,
                                     transform=transforms.ToTensor())  # ToTensor将其转换为张量格式
        eval_dataset=datasets.MNIST(dir, train=False, transform=transforms.ToTensor())

    elif name=="cifar":
        # 首先使用transforms工具包对CIFAR中的数据进行加工,转换为tensor形式
        transform_train=transforms.Compose([transforms.RandomCrop(32,padding=4), # 依据给定的size对图片随机裁剪(32*32)
                                            transforms.RandomHorizontalFlip(),   # 以给定的概率水平(随机)翻转图像,p的默认值是0.5
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914,0.4822,0.4465),    # 将获取张量图像,并使用平均值和标准差对其进行归一化
                                                                 (0.2023, 0.1994, 0.2010))  # 3个参数:mean, std, inplace默认是false
                                            ])
        transform_test=transforms.Compose([transforms.ToTensor(),
                                           transforms.Normalize((0.4914,0.4822,0.4465),
                                                                 (0.2023, 0.1994, 0.2010))
                                           ])
        train_dataset=datasets.CIFAR10(dir, train=True, download= True, transform=transform_train)
        eval_dataset=datasets.CIFAR10(dir, train=False, transform=transform_test)

    return train_dataset, eval_dataset
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

模型模块

import torch
from torchvision import models


def get_model(name="vgg16", pretrained=True):
    if name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
    elif name == "resnet50":
        model = models.resnet50(pretrained=pretrained)
    elif name == "densenet121":
        model = models.densenet121(pretrained=pretrained)
    elif name == "alexnet":
        model = models.alexnet(pretrained=pretrained)
    elif name == "vgg16":
        model = models.vgg16(pretrained=pretrained)
    elif name == "vgg19":
        model = models.vgg19(pretrained=pretrained)
    elif name == "inception_v3":
        model = models.inception_v3(pretrained=pretrained)
    elif name == "googlenet":
        model = models.googlenet(pretrained=pretrained)

    if torch.cuda.is_available():
        return model.cuda()
    else:
        return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

Json数据

{
  "model_name" : "resnet18",
  "no_models" : 10,
  "type" : "cifar",
  "global_epochs" : 20,
  "local_epochs" : 3,
  "k" : 6,
  "batch_size" : 32,
  "lr" : 0.001,
  "momentum" : 0.0001,
  "lambda" : 0.1
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

model_name:模型名称
no_models:客户端总数量
type:数据集信息
global_epochs:全局迭代次数,即服务端与客户端的通信迭代次数
local_epochs:本地模型训练迭代次数
k:每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
batch_size:本地训练每一轮的样本数
lr,momentum,lambda:本地训练的超参数设置

联邦训练与集中式训练的效果对比

左图为精准度对比,右图为损失值对比

联邦模型与单点训练模型对比

联邦训练后的模型与单点训练的模型在推断阶段的性能比较

从图中可以看出,每一轮参与联邦训练的客户端数目(K值)不同,其性能也不同。K值越大,即每一轮参与训练的客户端数目越多,训练的模型性能越好。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/221341
推荐阅读
相关标签
  

闽ICP备14008679号