当前位置:   article > 正文

Subgraph Federated Learning with Missing Neighbor Generation——打工日记1_graphsagenodegenerator

graphsagenodegenerator

帮忙打工…
我只负责看代码…论文当时偷懒没咋看(代码在GitHub已经开源)

实验结果
（图 1：实验结果截图，原图缺失）
（图 2：实验结果截图，原图缺失）

对fedsage+的理解

from src.train_locSagePlus import LocalOwner
from src.utils import config
import torch
import torch.nn.functional as F
from torch import optim
from src.global_task import Global
from src.models import feat_loss
from src.utils import mending_graph
from stellargraph.mapper import GraphSAGENodeGenerator
import numpy as np
import time



# Local joint training of GraphSage and NeighGen (the "fedgen" stage of FedSage+)
def train_fedgen(local_owners:list,feat_shape:int):
    """Locally train each data owner's NeighGen (missing-neighbor generator).

    Per owner, three losses are optimized jointly on its own subgraph:
    missing-neighbor count (smooth L1), generated-neighbor features
    (greedy feature loss), and node classification (cross entropy).
    On top of that, a federated feature loss is added against neighborhoods
    randomly sampled from every OTHER owner's subgraph.

    Args:
        local_owners: LocalOwner instances, one per data owner;
            length must equal config.num_owners.
        feat_shape: dimensionality of a single node feature vector.

    Side effects:
        Trains each owner's fed_model in place and persists it via
        save_fed_model(); returns None.
    """
    assert len(local_owners) == config.num_owners
    for owner in local_owners:
        # NOTE(review): compares class *names* instead of using isinstance —
        # presumably to tolerate re-imported module objects; confirm.
        assert owner.__class__.__name__ == LocalOwner.__name__
    local_gen_list=[]
    optim_list=[]
    t=time.time()


    for local_i in local_owners:  # one generator + optimizer per owner (e.g. 3 owners)
        local_i.set_fed_model()
        # The Hg part of the NeighGen architecture = dgen + fgen
        # (degree generator + feature generator).
        local_gen_list.append(local_i.fed_model.gen)
        optim_list.append(optim.Adam(local_gen_list[-1].parameters(),
                                  lr=config.lr, weight_decay=config.weight_decay))


    for epoch in range(config.gen_epochs):   # e.g. 20 generator epochs
        for i in range(config.num_owners):    # e.g. 3 owners
            local_gen_list[i].train()   # put this owner's generator in train mode
            optim_list[i].zero_grad()

            local_model=local_owners[i].fed_model  # the fedsage_plus model
            input_feat = local_owners[i].all_feat  # node feature matrix
            input_edge = local_owners[i].edges
            input_adj = local_owners[i].adj

            # Forward pass: predicted missing-neighbor counts, generated
            # neighbor features, and node-classification output.
            output_missing, output_feat, output_nc = local_model(input_feat, input_edge, input_adj)
            output_missing = torch.flatten(output_missing)


            # Reshape to (nodes, num_pred, feat_dim) and
            # (nodes + nodes * num_pred, num_classes) respectively.
            output_feat = output_feat.view(len(local_owners[i].all_ids), local_owners[i].num_pred, local_owners[i].feat_shape)
            output_nc = output_nc.view(len(local_owners[i].all_ids) + len(local_owners[i].all_ids) * local_owners[i].num_pred, local_owners[i].num_classes)

            # Local joint training of GraphSage and NeighGen:
            # standalone (per-owner) losses first.
            loss_train_missing = F.smooth_l1_loss(output_missing[local_owners[i].train_ilocs].float(),
                                                  local_owners[i].all_targets_missing[local_owners[i].
                                                  train_ilocs].reshape(-1).float())

            # Greedy feature loss between generated and true hidden-neighbor features.
            loss_train_feat = feat_loss.greedy_loss(output_feat[local_owners[i].train_ilocs],
                                               local_owners[i].all_targets_feat[local_owners[i].train_ilocs],
                                               output_missing[local_owners[i].train_ilocs],
                                               local_owners[i].all_targets_missing[
                                                    local_owners[i].train_ilocs
                                                ]).unsqueeze(0).mean().float()

            true_nc_label = torch.argmax(local_owners[i].all_targets_subj[local_owners[i].train_ilocs], dim=1).view(-1)   # e.g. 541 training nodes

            if config.cuda:
                true_nc_label = true_nc_label.cuda()
            loss_train_label = F.cross_entropy(output_nc[local_owners[i].train_ilocs], true_nc_label)

            acc_train_missing = local_owners[i].accuracy_missing(output_missing[local_owners[i].train_ilocs],
                                                      local_owners[i].all_targets_missing[local_owners[i].train_ilocs])
            acc_train_nc = local_owners[i].accuracy(output_nc[local_owners[i].train_ilocs],
                                         local_owners[i].all_targets_subj[local_owners[i].train_ilocs])

            # -----------------------------------------------------------------------------------------------------------
            # Weighted sum of the three local losses (weights a, b, c from config).
            loss = (config.a * loss_train_missing + config.b * loss_train_feat + config.c * loss_train_label).float()
            print('Data owner ' + str(i),
                  ' Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.4f}'.format(loss.item()),
                  'missing_train: {:.4f}'.format(acc_train_missing),
                  'nc_train: {:.4f}'.format(acc_train_nc),
                  'loss_miss: {:.4f}'.format(loss_train_missing.item()),
                  'loss_nc: {:.4f}'.format(loss_train_label.item()),
                  'loss_feat: {:.4f}'.format(loss_train_feat.item()),
                  'time: {:.4f}s'.format(time.time() - t))
            # ----------------------------------------------------------------------------------------------------------
            # Federated (joint) part: an extra feature loss against every other owner.
            for j in range(config.num_owners):
                if j != i:
                    # Sample as many nodes from owner j's subgraph as owner i
                    # has training nodes.
                    choice = np.random.choice(len(list(local_owners[j].subG.nodes())),
                                              len(local_owners[i].train_ilocs))

                    others_ids=local_owners[j].subG.nodes()[choice]

                    global_target_feat = []

                    for c_i in others_ids:
                        neighbors_ids=local_owners[j].subG.neighbors(c_i)  # neighbors of node c_i in owner j's subgraph

                        # Resample until we land on a node that has at least one neighbor.
                        while len(neighbors_ids)==0:
                            c_i=np.random.choice(len(list(local_owners[j].subG.nodes())),1)[0]
                            id_i = local_owners[j].subG.nodes()[c_i]
                            neighbors_ids = local_owners[j].subG.neighbors(id_i)

                        # Pick num_pred neighbors (with replacement) as feature targets.
                        choice_i = np.random.choice(neighbors_ids,config.num_pred)

                        for ch_i in choice_i:
                            global_target_feat.append(local_owners[j].subG.node_features([ch_i])[0])
                    global_target_feat = np.asarray(global_target_feat).reshape(
                        (len(local_owners[i].train_ilocs), config.num_pred, feat_shape))
                    # Feature loss — same as the local one, except the targets
                    # come from owner j's subgraph instead of owner i's.
                    loss_train_feat_other = feat_loss.greedy_loss(output_feat[local_owners[i].train_ilocs],
                                                             global_target_feat,# targets from the other owner
                                                             output_missing[local_owners[i].train_ilocs],
                                                             local_owners[i].all_targets_missing[
                                                                  local_owners[i].train_ilocs]
                                                             ).unsqueeze(0).mean().float()
                    loss += config.b * loss_train_feat_other
            # Average the accumulated loss over the number of owners.
            loss = 1.0 / config.num_owners * loss


            loss.backward()
            optim_list[i].step()

    # Persist each owner's trained federated model.
    for i in range(config.num_owners):
        local_owners[i].save_fed_model()

    return



def train_fedSagePC(classifier_list:list,local_owner_list:list,global_task:Global,acc_path):
    """Federated training of the final FedSage+ node classifier.

    Pipeline:
      1. Build one GraphSAGE classifier per owner and average their initial
         weights (simple FedAvg of the starting point).
      2. Mend each owner's local graph with synthetic neighbors produced by
         its trained NeighGen model, and build a node generator on the
         mended graph.
      3. Run FedAvg communication rounds: each owner fine-tunes locally from
         the shared weights, then the resulting weights are averaged.
      4. Save the final model and evaluate it on the global task.

    Args:
        classifier_list: per-owner classifier wrappers; length must equal
            config.num_owners.
        local_owner_list: per-owner LocalOwner instances (trained fed_model).
        global_task: Global task holder used for final evaluation.
        acc_path: path where test accuracy results are written.
    """
    assert len(classifier_list) == config.num_owners
    fed_gen_classifier_list=[]
    fill_train_gen_list=[]

    # Build one GraphSAGE classifier per owner (config.num_owners total).
    for classifier in classifier_list:
        classifier.fedSagePC=classifier.build_classifier(classifier.hasG_node_gen)   # GraphSAGE
        fed_gen_classifier_list.append(classifier.fedSagePC)


    # Average the initial weights across owners (FedAvg of the starting point).
    weights=fed_gen_classifier_list[0].get_weights()
    weights_len=len(weights)  # number of weight arrays (e.g. 8)
    for fed_gen_classifier in fed_gen_classifier_list[1:]:
        weights_cur=fed_gen_classifier.get_weights()
        for i in range(weights_len):
            weights[i]+=weights_cur[i]


    for i in range(weights_len):
        weights[i]=1.0/config.num_owners*weights[i]



    for owner_i in range(config.num_owners):
        local_owner = local_owner_list[owner_i]
        classifier = classifier_list[owner_i]
        input_feat = local_owner.all_feat
        input_edge = local_owner.edges
        input_adj = local_owner.adj

        # Predict missing-neighbor counts and features with the trained NeighGen.
        pred_missing, pred_feats, _ = local_owner.fed_model(input_feat, input_edge, input_adj)
        # Mend the graph: patch each local graph with the synthetic
        # neighborhoods generated by the NeighGen model.
        fill_nodes, fill_G = mending_graph.fill_graph(local_owner.hasG_hide,
                                                      local_owner.subG,
                                                      pred_missing, pred_feats, local_owner.feat_shape)

        fillG_node_gen = GraphSAGENodeGenerator(fill_G, config.batch_size, config.num_samples)

        fill_train_gen = fillG_node_gen.flow(classifier.train_subjects.index, classifier.train_targets, shuffle=True)
        fill_train_gen_list.append(fill_train_gen)

        # Rebuild the classifier on the mended graph and start every owner
        # from the averaged weights (FedSage -> generalized node classifier).
        classifier.fedSagePC = classifier.build_classifier(fillG_node_gen)  # GraphSAGE
        classifier.fedSagePC.set_weights(weights)


    grad_list=[]
    # All owners share classifier_list[0].fedSagePC's weights during rounds.
    classifier=classifier_list[0]

    for epoch in range(config.epoch_classifier):
        weight_cur = classifier.fedSagePC.get_weights()
        for owner_i in range(config.num_owners):
            # Local fine-tuning from the current shared weights.
            # NOTE(review): `history` is assigned but never used — confirm
            # whether the training history was meant to be logged.
            history = classifier.fedSagePC.fit(fill_train_gen_list[owner_i],
                                                         epochs=config.epochs_local,
                                                         verbose=2, shuffle=False)
            weight_send = classifier.fedSagePC.get_weights()
            grad_list.append([weight_send])
            # Reset to the round's starting weights so each owner trains
            # from the same shared model.
            classifier.fedSagePC.set_weights(weight_cur)
            print("local do = " + str(owner_i) + " communication round = " + str(epoch))

        # FedAvg: sum all owners' resulting weights into grad_list[0], then scale.
        for grad in grad_list[1:]:
            for i in range(len(grad[0])):
                grad_list[0][0][i] += grad[0][i]
        for i in range(len(grad_list[0][0])):
            grad_list[0][0][i] *= 1.0 / config.num_owners
        classifier.fedSagePC.set_weights(grad_list[0][0])
        print("epoch " + str(epoch))
        grad_list = []

    print("FedSage+ end!")
    classifier.save_fedSagePC()
    classifier.load_fedSagePC(global_task.org_gen)
    # Evaluate the federated classifier on the global (union) graph task.
    classifier.test_global(global_task,classifier.fedSagePC,acc_path,
                               name='FedSage+',prefix='')

    return


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/249819
推荐阅读
相关标签
  

闽ICP备14008679号