当前位置:   article > 正文

联邦学习FedAvg自编写代码_联邦学习代码

联邦学习代码

联邦学习中,联邦平均算法获得了很大的使用空间,因此常常被用于进行同步训练操作
不多废话了,以下为Fedavg代码
由于使用场景为NonIID场景,因此我使用了别人的一个MNIST数据集自定义的代码(见附录)
FedAvg代码如下,功能具体看注释
工作环境:python3.8.5 + pytorch(无cuda)
divergence模块可直接删除

# coding: utf-8

# In[1]:


import argparse
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
from PIL import Image
import torch
import copy
import pandas as pd
import random
import time
import sys
import re
import matplotlib.pyplot as plt
#import divergence

name = str(sys.argv[0])

# In[2]:

home_path = "./"
class MyDataset(torch.utils.data.Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
    def __init__(self,root,data,label,transform=None, target_transform=None): #初始化一些需要传入的参数
        super(MyDataset,self).__init__()
        imgs = []                      #创建一个名为img的空列表,一会儿用来装东西
        self.img_route = root
        for i in range(len(data)):
            imgs.append((data[i],int(label[i])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
 
    def __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容
        fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
        route = self.img_route + str(label) + "/" + fn
        img = Image.open(route) #按照path读入图片from PIL import Image # 按照路径读取图片
        if self.transform is not None:
            img = self.transform(img) #是否进行transform
        return img,label  #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
 
    def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.imgs)


# In[3]:


filePath = home_path + 'data/MNIST/image_turn/'
train_data = []
train_label = []
for i in range(10):
    train_data.append(os.listdir(filePath+str(i)))
    train_label.append([i]*len(train_data[i]))
filePath = home_path + 'data/MNIST/image_test_turn/'
test_data = []
test_label = []
for i in range(10):
    test_data.append(os.listdir(filePath+str(i)))
    test_label.append([i]*len(test_data[i]))
test_ori = []
test_label_ori = []
for x in range(10):
    test_ori += test_data[x]
    test_label_ori += test_label[x]
test_data=MyDataset(home_path + "data/MNIST/image_test_turn/",test_ori,test_label_ori, transform=transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64)


# In[4]:


class
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号