Just helping out with some work here…
I was only in charge of reading the code; I was lazy about the paper at the time and didn't really study it (the code is already open-sourced on GitHub).
Experimental results
My understanding of FedSage+
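The training script below repeatedly calls `feat_loss.greedy_loss` to supervise the features that NeighGen generates for the hidden (missing) neighbors, but the repo's implementation of that loss is not shown here. As a rough mental model only, here is a minimal stand-in under my own assumptions (the name `greedy_feature_loss` and the exact shapes are mine, inferred from how the script calls it; this is not the repo's actual code): each predicted neighbor feature is greedily matched to the closest ground-truth neighbor feature, and the squared distances are averaged over the neighbors the node is actually missing.

```python
import torch

def greedy_feature_loss(pred_feats, true_feats, pred_missing, true_missing):
    """Simplified stand-in for feat_loss.greedy_loss (not the repo's exact code).

    pred_feats:   (num_nodes, num_pred, feat_dim) generated neighbor features
    true_feats:   (num_nodes, num_pred, feat_dim) sampled real neighbor features
    pred_missing: (num_nodes,) predicted number of missing neighbors (unused here)
    true_missing: (num_nodes,) ground-truth number of missing neighbors
    """
    num_nodes, num_pred, feat_dim = pred_feats.shape
    true_feats = torch.as_tensor(true_feats, dtype=pred_feats.dtype)
    per_node = []
    for v in range(num_nodes):
        # only supervise as many predictions as the node is actually missing
        k = int(min(max(round(float(true_missing[v])), 0), num_pred))
        if k == 0:
            continue  # nothing hidden for this node, nothing to reconstruct
        # squared L2 distance between every (predicted, real) neighbor pair
        dist = torch.cdist(pred_feats[v], true_feats[v].reshape(-1, feat_dim)) ** 2
        # greedy nearest match: each prediction is scored against its closest real neighbor
        per_node.append(dist.min(dim=1).values[:k].mean())
    if not per_node:
        return pred_feats.sum() * 0.0  # keep the result differentiable when nothing is missing
    return torch.stack(per_node).mean()
```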
```python
from src.train_locSagePlus import LocalOwner
from src.utils import config
import torch
import torch.nn.functional as F
from torch import optim
from src.global_task import Global
from src.models import feat_loss
from src.utils import mending_graph
from stellargraph.mapper import GraphSAGENodeGenerator
import numpy as np
import time


# Local joint training of GraphSage and NeighGen
def train_fedgen(local_owners: list, feat_shape: int):
    assert len(local_owners) == config.num_owners
    for owner in local_owners:
        assert owner.__class__.__name__ == LocalOwner.__name__
    local_gen_list = []
    optim_list = []
    t = time.time()
    for local_i in local_owners:  # 3 owners
        local_i.set_fed_model()  # the Hg part of the NeighGen architecture = dGen + fGen
        local_gen_list.append(local_i.fed_model.gen)
        optim_list.append(optim.Adam(local_gen_list[-1].parameters(),
                                     lr=config.lr, weight_decay=config.weight_decay))
    for epoch in range(config.gen_epochs):  # 20
        for i in range(config.num_owners):  # 3
            local_gen_list[i].train()
            optim_list[i].zero_grad()
            local_model = local_owners[i].fed_model  # fedsage_plus
            input_feat = local_owners[i].all_feat  # node features
            input_edge = local_owners[i].edges
            input_adj = local_owners[i].adj
            output_missing, output_feat, output_nc = local_model(input_feat, input_edge, input_adj)
            output_missing = torch.flatten(output_missing)
            output_feat = output_feat.view(len(local_owners[i].all_ids),
                                           local_owners[i].num_pred,
                                           local_owners[i].feat_shape)
            output_nc = output_nc.view(len(local_owners[i].all_ids) +
                                       len(local_owners[i].all_ids) * local_owners[i].num_pred,
                                       local_owners[i].num_classes)

            # Local joint training of GraphSage and NeighGen
            # (1) losses computed purely on the owner's own subgraph
            loss_train_missing = F.smooth_l1_loss(
                output_missing[local_owners[i].train_ilocs].float(),
                local_owners[i].all_targets_missing[local_owners[i].train_ilocs].reshape(-1).float())
            loss_train_feat = feat_loss.greedy_loss(
                output_feat[local_owners[i].train_ilocs],
                local_owners[i].all_targets_feat[local_owners[i].train_ilocs],
                output_missing[local_owners[i].train_ilocs],
                local_owners[i].all_targets_missing[local_owners[i].train_ilocs]
            ).unsqueeze(0).mean().float()
            true_nc_label = torch.argmax(
                local_owners[i].all_targets_subj[local_owners[i].train_ilocs], dim=1).view(-1)  # 541
            if config.cuda:
                true_nc_label = true_nc_label.cuda()
            loss_train_label = F.cross_entropy(output_nc[local_owners[i].train_ilocs], true_nc_label)
            acc_train_missing = local_owners[i].accuracy_missing(
                output_missing[local_owners[i].train_ilocs],
                local_owners[i].all_targets_missing[local_owners[i].train_ilocs])
            acc_train_nc = local_owners[i].accuracy(
                output_nc[local_owners[i].train_ilocs],
                local_owners[i].all_targets_subj[local_owners[i].train_ilocs])
            # ----------------------------------------------------------------
            loss = (config.a * loss_train_missing +
                    config.b * loss_train_feat +
                    config.c * loss_train_label).float()
            print('Data owner ' + str(i),
                  ' Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.4f}'.format(loss.item()),
                  'missing_train: {:.4f}'.format(acc_train_missing),
                  'nc_train: {:.4f}'.format(acc_train_nc),
                  'loss_miss: {:.4f}'.format(loss_train_missing.item()),
                  'loss_nc: {:.4f}'.format(loss_train_label.item()),
                  'loss_feat: {:.4f}'.format(loss_train_feat.item()),
                  'time: {:.4f}s'.format(time.time() - t))
            # ----------------------------------------------------------------
            # (2) federated part: feature loss against nodes sampled from the other owners
            for j in range(config.num_owners):
                if j != i:
                    choice = np.random.choice(len(list(local_owners[j].subG.nodes())),
                                              len(local_owners[i].train_ilocs))
                    others_ids = local_owners[j].subG.nodes()[choice]
                    global_target_feat = []
                    for c_i in others_ids:
                        # neighbors of node c_i in owner j's subgraph
                        neighbors_ids = local_owners[j].subG.neighbors(c_i)
                        while len(neighbors_ids) == 0:
                            # resample until a node with at least one neighbor is found
                            c_i = np.random.choice(len(list(local_owners[j].subG.nodes())), 1)[0]
                            id_i = local_owners[j].subG.nodes()[c_i]
                            neighbors_ids = local_owners[j].subG.neighbors(id_i)
                        choice_i = np.random.choice(neighbors_ids, config.num_pred)
                        for ch_i in choice_i:
                            global_target_feat.append(local_owners[j].subG.node_features([ch_i])[0])
                    global_target_feat = np.asarray(global_target_feat).reshape(
                        (len(local_owners[i].train_ilocs), config.num_pred, feat_shape))
                    # feature loss; the only difference from loss_train_feat is that the
                    # targets are real neighbor features drawn from another owner's subgraph
                    loss_train_feat_other = feat_loss.greedy_loss(
                        output_feat[local_owners[i].train_ilocs],
                        global_target_feat,
                        output_missing[local_owners[i].train_ilocs],
                        local_owners[i].all_targets_missing[local_owners[i].train_ilocs]
                    ).unsqueeze(0).mean().float()
                    loss += config.b * loss_train_feat_other
            loss = 1.0 / config.num_owners * loss
            loss.backward()
            optim_list[i].step()
    for i in range(config.num_owners):
        local_owners[i].save_fed_model()
    return


def train_fedSagePC(classifier_list: list, local_owner_list: list, global_task: Global, acc_path):
    assert len(classifier_list) == config.num_owners
    fed_gen_classifier_list = []
    fill_train_gen_list = []
    # classifier_list holds one classifier per owner (3 here)
    for classifier in classifier_list:
        classifier.fedSagePC = classifier.build_classifier(classifier.hasG_node_gen)  # GraphSAGE
        fed_gen_classifier_list.append(classifier.fedSagePC)
    weights = fed_gen_classifier_list[0].get_weights()
    weights_len = len(weights)  # 8
    for fed_gen_classifier in fed_gen_classifier_list[1:]:
        weights_cur = fed_gen_classifier.get_weights()
        for i in range(weights_len):
            weights[i] += weights_cur[i]
    for i in range(weights_len):
        weights[i] = 1.0 / config.num_owners * weights[i]
    for owner_i in range(config.num_owners):
        local_owner = local_owner_list[owner_i]
        classifier = classifier_list[owner_i]
        input_feat = local_owner.all_feat
        input_edge = local_owner.edges
        input_adj = local_owner.adj
        # predict missing neighbors plus their features, then mend the local graph
        pred_missing, pred_feats, _ = local_owner.fed_model(input_feat, input_edge, input_adj)
        # each local graph is mended with the synthetic neighborhoods generated by NeighGen
        fill_nodes, fill_G = mending_graph.fill_graph(local_owner.hasG_hide, local_owner.subG,
                                                      pred_missing, pred_feats, local_owner.feat_shape)
        fillG_node_gen = GraphSAGENodeGenerator(fill_G, config.batch_size, config.num_samples)
        fill_train_gen = fillG_node_gen.flow(classifier.train_subjects.index,
                                             classifier.train_targets, shuffle=True)
        fill_train_gen_list.append(fill_train_gen)
        # train FedSage on the mended graphs to obtain a generalized node classification model
        classifier.fedSagePC = classifier.build_classifier(fillG_node_gen)  # GraphSAGE
        classifier.fedSagePC.set_weights(weights)
    grad_list = []
    classifier = classifier_list[0]
    for epoch in range(config.epoch_classifier):
        weight_cur = classifier.fedSagePC.get_weights()
        for owner_i in range(config.num_owners):
            history = classifier.fedSagePC.fit(fill_train_gen_list[owner_i],
                                               epochs=config.epochs_local, verbose=2, shuffle=False)
            weight_send = classifier.fedSagePC.get_weights()
            grad_list.append([weight_send])
            classifier.fedSagePC.set_weights(weight_cur)
            print("local do = " + str(owner_i) + " communication round = " + str(epoch))
        # FedAvg: sum every owner's weights, then divide by the number of owners
        for grad in grad_list[1:]:
            for i in range(len(grad[0])):
                grad_list[0][0][i] += grad[0][i]
        for i in range(len(grad_list[0][0])):
            grad_list[0][0][i] *= 1.0 / config.num_owners
        classifier.fedSagePC.set_weights(grad_list[0][0])
        print("epoch " + str(epoch))
        grad_list = []
    print("FedSage+ end!")
    classifier.save_fedSagePC()
    classifier.load_fedSagePC(global_task.org_gen)
    classifier.test_global(global_task, classifier.fedSagePC, acc_path, name='FedSage+', prefix='')
    return
```
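The second half of `train_fedSagePC` is plain FedAvg over Keras weights: each owner fine-tunes a copy of the shared GraphSAGE classifier on its mended graph, the resulting weight lists are summed and divided by `config.num_owners`, and the average becomes the starting point of the next communication round. Here is a minimal sketch of that averaging step (the helper name `average_keras_weights` is mine, assuming only that the classifiers are ordinary Keras models exposing `get_weights`/`set_weights`):

```python
import numpy as np

def average_keras_weights(models):
    """Uniform FedAvg of Keras model weights, mirroring the inline averaging
    done in train_fedSagePC above (hypothetical helper, not part of the repo)."""
    weight_lists = [m.get_weights() for m in models]        # one list of numpy arrays per owner
    return [np.mean([w[i] for w in weight_lists], axis=0)   # element-wise average per layer
            for i in range(len(weight_lists[0]))]

# Usage sketch: start every owner's classifier from the same averaged weights,
# then repeat fit -> collect -> average once per communication round.
# avg = average_keras_weights(fed_gen_classifier_list)
# for clf in fed_gen_classifier_list:
#     clf.set_weights(avg)
```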