当前位置:   article > 正文

基于 Pytorch 实现 Federated Learning 中的安全聚合(基于模型参数)_安全聚合算法krum

安全聚合算法krum

基于 Pytorch 实现 Federated Learning 中的安全聚合(基于模型参数)

最近看了一些关于 FL 的安全聚合的文章,也找了一些代码,但是发现他们都有一些共同点——全是基于 FedSGD 的(原版基于FedSGD 的 github :https://github.com/shanxuanchen/attacking_federate_learning)。但是现在用 FedSGD 的太少了,收敛速度还慢。因此我修改了两个比较经典的安全聚合算法:krum 和 trimmed_median 去适应 FedAVG。
话不多说,直接上代码:

Krum:

def krum(w, args):# csdn 第二姿态,
    distances = defaultdict(dict)
    non_malicious_count = int((args.num_users - args.atk_num) * args.frac)
    num = 0
    for k in w[0].keys():
        if num == 0:
            for i in range(len(w)):
                for j in range(i):
                    distances[i][j] = distances[j][i] = np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy())
            num = 1
        else:
            for i in range(len(w)):
                for j in range(i):
                    distances[j][i] += np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy())
                    distances[i][j] += distances[j][i]
    minimal_error = 1e20
    for user in distances.keys():
        errors = sorted(distances[user].values())
        current_error = sum(errors[:non_malicious_count])
        if current_error < minimal_error:
            minimal_error = current_error
            minimal_error_index = user
    return w[minimal_error_index]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

Trimmed_median:

def trimmed_mean(w, args): # csdn 第二姿态,
    number_to_consider = int((args.num_users - args.atk_num) * args.frac) - 1
    print(number_to_consider)
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        tmp = []
        for i in range(len(w)):
            tmp.append(w[i][k].cpu().numpy()) # get the weight of k-layer which in each client
        tmp = np.array(tmp)
        med = np.median(tmp,axis=0)
        new_tmp = []
        for i in range(len(tmp)):# cal each client weights - median
            new_tmp.append(tmp[i]-med)
        new_tmp = np.array(new_tmp)
        good_vals = np.argsort(abs(new_tmp),axis=0)[:number_to_consider]
        good_vals = np.take_along_axis(new_tmp, good_vals, axis=0)
        k_weight = np.array(np.mean(good_vals) + med)
        w_avg[k] = torch.from_numpy(k_weight).to(args.device)
    return w_avg
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

如果有不明白的参数可以继续在评论区交流!!!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/568839
推荐阅读
相关标签
  

闽ICP备14008679号