赞
踩
联邦学习中,联邦平均算法获得了很大的使用空间,因此常常被用于进行同步训练操作
不多废话了,以下为Fedavg代码
由于使用场景为NonIID场景,因此我使用了别人的一个MNIST数据集自定义的代码(见附录)
FedAvg代码如下,功能具体看注释
工作环境:python3.8.5 + pytorch(无cuda)
divergence模块可直接删除
# coding: utf-8 # In[1]: import argparse import torch import os import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms from torch.autograd import Variable from PIL import Image import torch import copy import pandas as pd import random import time import sys import re import matplotlib.pyplot as plt #import divergence name = str(sys.argv[0]) # In[2]: home_path = "./" class MyDataset(torch.utils.data.Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset def __init__(self,root,data,label,transform=None, target_transform=None): #初始化一些需要传入的参数 super(MyDataset,self).__init__() imgs = [] #创建一个名为img的空列表,一会儿用来装东西 self.img_route = root for i in range(len(data)): imgs.append((data[i],int(label[i]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform def __getitem__(self, index): #这个方法是必须要有的,用于按照索引读取每个元素的具体内容 fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息 route = self.img_route + str(label) + "/" + fn img = Image.open(route) #按照path读入图片from PIL import Image # 按照路径读取图片 if self.transform is not None: img = self.transform(img) #是否进行transform return img,label #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容 def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分 return len(self.imgs) # In[3]: filePath = home_path + 'data/MNIST/image_turn/' train_data = [] train_label = [] for i in range(10): train_data.append(os.listdir(filePath+str(i))) train_label.append([i]*len(train_data[i])) filePath = home_path + 'data/MNIST/image_test_turn/' test_data = [] test_label = [] for i in range(10): test_data.append(os.listdir(filePath+str(i))) test_label.append([i]*len(test_data[i])) test_ori = [] test_label_ori = [] for x in range(10): test_ori += test_data[x] test_label_ori += test_label[x] test_data=MyDataset(home_path + "data/MNIST/image_test_turn/",test_ori,test_label_ori, transform=transforms.ToTensor()) test_loader = DataLoader(dataset=test_data, batch_size=64) # In[4]: class
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。