当前位置:   article > 正文

【联邦学习新手必看】手把手教你读懂FedAvg代码,并顺利运行

fedavg代码

代码来源:GitHub - shaoxiongji/federated-learning: A PyTorch Implementation of Federated Learning http://doi.org/10.5281/zenodo.4321561

总览

image-20230614155506480

main_fed.py

这段代码的作用是根据传入的参数,加载指定的数据集(MNIST或CIFAR10),并根据参数设置选择相应的数据转换操作。然后,根据用户划分的方式(IID或Non-IID)生成对应的用户字典。最后,获取训练集中第一个样本的图像大小。

快速了解版

# 1.根据数据集 选择建立模型
if args.model == 'cnn' and args.dataset == 'cifar/mnist/mlp'

# 2.复制当前全局模型net_glob的权重
w_glob = net_glob.state_dict()

# 3.进行本地更新,w_locals和loss_locals分别存储本地权重和本地损失
LocalUpdate()

# 4.进行联邦平均 更新全局权重
w_glob = FedAvg(w_locals)

# 5.保存权重
net_glob.load_state_dict(w_glob)

#6. 打印损失曲线,测试准确率...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

逐行讲解详细版

    # load dataset and split users
    #如果args.dataset为'mnist',进入MNIST数据集的处理流程。
    if args.dataset == 'mnist':
        #定义了一系列数据转换操作,并将其组合成一个转换管道。其中包括将图像转换为张量(transforms.ToTensor())和进行归一化操作(transforms.Normalize())。这些转换操作将应用于MNIST数据集。
        #transforms.Compose()是一个组合多个数据转换操作的函数,将两个数据转换操作transforms.ToTensor()和transforms.Normalize()组合在一起,形成一个转换管道trans_mnist。
        #transforms.ToTensor()是一个数据转换操作,它将图像数据转换为张量格式
        #transforms.Normalize()是另一个数据转换操作,用于数据归一化。它通过减去均值并除以标准差的方式对图像数据进行归一化,通过指定(0.1307,)和(0.3081,)作为均值和标准差,对MNIST图像进行归一化操作。
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        #加载MNIST训练集,并设置了数据的存储路径、是否下载以及应用的转换操作。
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        #加载MNIST测试集
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        
        # sample users
        #如果args.iid为True,表示采用独立同分布(IID)的方式划分用户,调用了mnist_iid函数来生成用户字典dict_users。mnist_iid函数接受MNIST训练集和用户数量作为参数,返回一个用户字典,其中包含了每个用户的数据。
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        #表示采用非独立同分布(Non-IID)的方式划分用户
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        #也是同理,Normalize中的三个0.5分别对应图像的三个通道(红色、绿色、蓝色),通过减去0.5并除以0.5的方式将像素值范围缩放到-1到1之间,以提高模型训练的效果
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
        
    # 获取训练集中第一个样本的图像大小,并将其赋值给变量img_size。
    img_size = dataset_train[0][0].shape
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

这段代码的作用是根据命令行参数选择合适的模型类型和数据集类型,构建相应的模型net_glob,并将其设置为训练模式。这个步骤通常是在开始训练前的模型初始化阶段进行的,确保选择合适的模型结构和参数设置。

    # build model
    # 根据命令行参数args.model和args.dataset来选择模型类型和数据集类型
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    # 调用net_glob.train()来启用模型的训练模式。
    net_glob.train()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

**核心部分:**实现了联邦学习的训练过程。它首先将全局模型的权重复制到每个客户端进行局部训练,然后根据一定的策略聚合客户端的权重,更新全局模型,并打印每轮训练的平均损失值。

    # copy weights
    # 复制当前全局模型net_glob的权重
    w_glob = net_glob.state_dict()

    # training
    # 训练过程中的损失函数列表
    loss_train = []
    # 存储交叉验证的损失和准确率列表
    cv_loss, cv_acc = [], []
    # 存储上一次迭代的验证集损失值和计数器。这些变量通常用于早停策略,在验证集损失不再下降时停止训练,以防止过拟合
    val_loss_pre, counter = 0, 0
    # 存储表现最好的模型和最佳模型对应的验证集损失值
    net_best = None
    best_loss = None
    # 用于存储验证集准确率和模型权重,常用于跟踪验证集上的性能变化和保存模型的快照
    val_acc_list, net_list = [], []
	# 如果为真,表示要对所有客户端进行聚合
    if args.all_clients: 
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]
    for iter in range(args.epochs):
        # 存储每个客户端的局部损失值
        loss_locals = []
        # 如果不是所有客户,创建列表w_locals存储每个客户端的局部权重
        if not args.all_clients:
            w_locals = []
        # 根据命令行参数args.frac和args.num_users,确定参与本轮训练的客户端数量m
        m = max(int(args.frac * args.num_users), 1)
        # 随机选择m个客户端的索引,存储在idxs_users中
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        # 遍历被选的客户端
        for idx in idxs_users:
            # 执行本地更新
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            # 传入当前的全局模型net_glob的副本,并获取更新后的权重w和局部损失loss
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            # 如果是所有客户端,将更新后的权重w赋值给w_locals的对应索引位置
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                # 否则,添加权重w到w_locals列表中
                w_locals.append(copy.deepcopy(w))
            # 将局部损失loss添加到loss_locals列表中
            loss_locals.append(copy.deepcopy(loss))
            
        # update global weights
        # 使用FedAvg函数对w_locals进行聚合,得到更新后的全局权重w_glob
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        # 将更新后的全局权重w_glob加载到net_glob中,以便在下一轮迭代中使用
        net_glob.load_state_dict(w_glob)

        # print loss
        # 计算本轮训练的平均损失,并添加到loss_train列表中
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58

完成了以下几个任务:

  1. 绘制损失函数曲线并保存为图片。
  2. 将全局模型设置为评估模式。
  3. 在训练集和测试集上对模型进行测试,计算准确率和损失。
  4. 打印训练准确率和测试准确率。

通过绘制损失函数曲线和计算准确率,可以评估模型的训练效果和泛化能力,并对模型的性能进行分析和比较。

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
    net_glob.eval()
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

model包

Fed.py

.deepcopy() 用于创建一个对象的深度副本(deep copy),对象的值和原始对象相同,但是在内存中具有不同的地址,以避免对原始对象的修改造成的影响。

#在main_fed.py中调用了该函数,用于更新global weights
# 1.如果是所有客户端
w_locals = [w_glob for i in range(args.num_users)]
w_locals[idx] = copy.deepcopy(w)
# 2.如果不是所有客户端,
w_locals = [] 
w_locals.append(copy.deepcopy(w))
# 然后调用FedAvg函数
w_glob = FedAvg(w_locals) 

#————————————————————
#传入的w参数,是所有参与训练客户端的局部权重列表
def FedAvg(w):
    # 建一个变量 w_avg,并将其初始化为参数列表 w 的第一个元素w[0]的深度副本
    w_avg = copy.deepcopy(w[0])
    # 遍历w_avg的每个key
    for k in w_avg.keys():
        # 遍历参数列表 w 的剩余元素,从第二个元素开始
        for i in range(1, len(w)):
            # 将参数列表 w 中的每个元素的键 k 对应的值加到 w_avg[k]上,在每个键上累积了所有参数的值
            w_avg[k] += w[i][k]
        # 对累积的值进行平均,将其除以参数列表 w 的长度。torch.div() 函数用于执行元素级除法。
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg

#举个例子
student1 = {
    'name': 'John',
    'age': 18,
    'grade': '12th',
    'school': 'ABC High School'
}

student2 = {
    'name': 'Jane',
    'age': 17,
    'grade': '11th',
    'school': 'XYZ High School'
}
student3 = {
    'name': 'hidisan',
    'age': 23,
    'grade': '14th',
    'school': 'NB High School'
}

w = [student1, student2, student3]
#那w_avg一开始等于student1,
w_avg = copy.deepcopy(w[0])
#经过遍历累加得到
w_avg = {'name': 'JohnJanehidisan', 
         'age': 58, 
         'grade': '12th11th14th', 
         'school': 'ABC High SchoolXYZ High SchoolNB High School'}
#然后经过torch.div,执行元素级除法
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

Nets.py

定义了三个神经网络模型:MLPCNNMnistCNNCifar

第一个简单的多层感知机模型,包含一个输入层、一个隐藏层和一个输出层。它使用线性层进行线性变换,ReLU激活函数引入非线性变换,并应用Dropout层以减少过拟合。这个模型可以通过调用 forward 方法来进行前向传播,将输入数据经过网络的层操作得到输出结果。

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

定义了一个基于卷积神经网络的模型 CNNMnist,用于处理MNIST数据集。该模型通过两个卷积层提取图像特征,然后通过线性层进行分类。ReLU激活函数和最大池化层用于非线性变换和特征降采样,Dropout层用于减少过拟合。

class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        #用于提取图像的特征。args.num_channels 表示输入图像的通道数,10表示输出通道数,kernel_size=5 表示卷积核的大小为5x5
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        
        #创建第二个二维卷积层,进一步提取特征
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        
        #创建一个二维Dropout层,用于随机失活卷积层的输出特征图。
        self.conv2_drop = nn.Dropout2d()
        
        #创建一个线性层,将卷积层输出的特征图转换为50维的向量。
        self.fc1 = nn.Linear(320, 50)
        
        #创建最后一个线性层,将50维的向量映射到类别数量
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        #通过第一个卷积层,并应用ReLU激活函数和最大池化层来提取特征
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        
        #通过第二个卷积层,并应用ReLU激活函数、Dropout和最大池化层来进一步提取特征
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        
        #对特征图进行形状变换,将其展平为一维向量
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        
        #通过线性层进行特征到隐藏层的线性变换,并应用ReLU激活函数
        x = F.relu(self.fc1(x))
        
        #应用Dropout层,随机失活一部分隐藏层神经元
        x = F.dropout(x, training=self.training)
        
        #通过线性层进行隐藏层到输出层的线性变换
        x = self.fc2(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

定义了一个卷积神经网络模型 CNNCifar,用于处理CIFAR数据集。该模型通过两个卷积层提取图像特征,然后通过线性层进行分类。ReLU激活函数和最大池化层用于非线性变换和特征降采样。

class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        #创建一个二维卷积层,用于提取图像的特征。输入通道数为3,输出通道数为6,卷积核大小为5x5。
        self.conv1 = nn.Conv2d(3, 6, 5)
        
        #创建一个最大池化层,用于特征降采样
        self.pool = nn.MaxPool2d(2, 2)
        
        #创建第二个二维卷积层,进一步提取特征
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        #创建一个线性层,将卷积层输出的特征图转换为120维的向量
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        
        #创建一个线性层,将120维的向量映射到84维的向量
        self.fc2 = nn.Linear(120, 84)
        
        #创建最后一个线性层,将84维的向量映射到类别数量
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        #通过第一个卷积层,并应用ReLU激活函数和最大池化层来提取特征
        x = self.pool(F.relu(self.conv1(x)))
        
        #通过第二个卷积层,并应用ReLU激活函数和最大池化层来进一步提取特征
        x = self.pool(F.relu(self.conv2(x)))
        
        #对特征图进行形状变换,将其展平为一维向量
        x = x.view(-1, 16 * 5 * 5)
        
        #通过线性层进行特征到隐藏层的线性变换,并应用ReLU激活函数
        x = F.relu(self.fc1(x))
        
        #通过线性层进行隐藏层到隐藏层的线性变换,并应用ReLU激活函数
        x = F.relu(self.fc2(x))
        
        #通过线性层进行隐藏层到输出层的线性变换
        x = self.fc3(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

test.py

该函数用于对给定的测试数据集进行模型评估。它通过迭代数据加载器,对每个批量的数据进行前向传播和损失计算,然后累加损失和正确分类的样本数。最后计算平均测试损失和准确率,并将其返回

#net_g表示要测试的模型,datatest表示测试数据集,args表示其他参数。在函数开头,将net_g设置为评估模式
def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    #计算测试损失和正确分类的样本数
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    # 对数据加载器进行迭代,每次迭代获取一个批量的数据和对应的目标标签
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        #调用net_g模型对数据进行前向传播
        log_probs = net_g(data)
        # sum up batch loss
        # 使用交叉熵损失函数F.cross_entropy计算损失并累加到test_loss中
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        # 利用预测的对数概率计算预测的类别,并与目标标签进行比较,统计正确分类的样本数
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    # 计算平均测试损失和准确率
    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    #是否打印详细的测试结果
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

Update.py

定义了一个名为DatasetSplit的自定义数据集类,继承自Dataset类。通过使用DatasetSplit类,可以从原始数据集中创建一个子数据集,该子数据集仅包含特定的样本。这在分割数据集用于训练和验证时非常有用,可以根据索引划分数据集并创建相应的训练集和验证集。

class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

定义了一个名为LocalUpdate的类,用于在本地进行模型的训练和更。在train方法中,通过迭代数据加载器的批次,对模型进行前向传播、计算损失、反向传播和参数更新,最终返回模型的状态字典和训练周期的平均损失

class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        #保存传入的参数args,用于配置训练过程中的超参数
        self.args = args
        
        #保存一个交叉熵损失函数的实例,用于计算训练过程中的损失
        self.loss_func = nn.CrossEntropyLoss()
        
        #用于保存选择的客户端
        self.selected_clients = []
        
        #创建一个数据加载器DataLoader,加载一个子数据集DatasetSplit,其中子数据集由参数dataset和idxs指定,设置批量大小为self.args.local_bs,并进行随机洗牌
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        #将模型设置为训练模式
        net.train()
        # train and update
        # 创建一个torch.optim.SGD的优化器,使用net.parameters()作为优化器的参数,设置学习率为self.args.lr和动量为self.args.momentum
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        #用于保存每个训练周期的损失
        epoch_loss = []
        for iter in range(self.args.local_ep):
            #用于保存每个批次的损失
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                #清零模型参数的梯度
                net.zero_grad()
                
                #通过模型进行前向传播,获取预测的对数概率
                log_probs = net(images)
                
                #使用损失函数计算损失
                loss = self.loss_func(log_probs, labels)
                
                #对损失进行反向传播和参数更新
                loss.backward()
                optimizer.step()
                
                #批次索引能被10整除,打印当前训练进度和损失
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                #计算每个训练周期的平均损失,并将其添加到epoch_loss中
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        #返回模型的状态字典和所有训练周期的平均损失
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51

utils包

sampling.py

import numpy as np

from torchvision import datasets, transforms
def mnist_iid(dataset, num_users):
def mnist_noniid(dataset, num_users):
def cifar_iid(dataset, num_users):
    
if __name__ == '__main__':
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

第一、三个函数

用于从MNIST数据集中随机抽样生成独立同分布(IID)的客户端数据

用于从CIFAR数据集中随机抽样生成独立同分布(IID)的客户端数据

def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    # 计算每个客户端应该获得的图像数量,即数据集总大小除以客户端数量
    num_items = int(len(dataset)/num_users)
    
    #用于保存生成的字典,键为客户端的标识符,值为图像的索引集合
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        # 从all_idxs中无重复地随机选择num_items个索引,将其作为当前客户端的图像索引集合,并将其添加到dict_users中
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        
        #从all_idxs中移除已分配给当前客户端的索引集合
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

第二个函数

用于从MNIST数据集中生成非独立同分布(non-IID)的客户端数据

def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    # 将数据集分成多少个shard(分片),每个shard应该包含多少个图像
    num_shards, num_imgs = 200, 300
    
    #初始化为包含0到num_shards减1的索引列表
    idx_shard = [i for i in range(num_shards)]
    
    #用于保存生成的字典,键为客户端的标识符,值为图像的索引数组
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    
    #初始化为包含0到(num_shards * num_imgs) - 1的索引数组
    idxs = np.arange(num_shards*num_imgs)
    
    #获取数据集的标签并转换为NumPy数组
    labels = dataset.train_labels.numpy()

    # sort labels,对标签进行排序
    # 将idxs和labels按列堆叠为二维数组idxs_labels
    idxs_labels = np.vstack((idxs, labels))
    
    #排序,以保证相同标签的图像在一起
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
    
    #更新idxs为排序后的索引数组
    idxs = idxs_labels[0,:]

    # divide and assign
    for i in range(num_users):
        #随机选择2个shard的索引,将其作为当前客户端的shard集合,并将其从idx_shard中移除
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        #对于每个选择的shard索引
        for rand in rand_set:
            #将对应shard中的图像索引范围添加到当前客户端的索引数组中
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

运行

电脑安装好pytroch,cuda,配置好虚拟环境
直接去pycharm上运行

运行cifar

python main_fed.py --dataset cifar --epoch 10 --num_channel 3 --gpu 0 --model_cnn --iid

  • 1
  • 2

image-20230613200949271

image-20230613201010381

跑20轮

image-20230614112113071

跑50轮

image-20230614112620954

image-20230614123323842

image-20230614125554121

运行mnist

python main_fed.py --dataset mnist --iid --num_channels 1 --model cnn --epochs 50 --gpu 0
	
  • 1
  • 2

image-20230613203734972

image-20230614111827905

image-20230613203756664

non-iid再跑50轮

image-20230614132842962

image-20230614133036272

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/190627?site
推荐阅读
相关标签
  

闽ICP备14008679号