当前位置:   article > 正文

联邦学习安全之后门攻击

联邦学习安全之后门攻击

本博客地址:https://security.blog.csdn.net/article/details/124067669

一、后门攻击定义

联邦学习中,后门攻击是意图让模型对具有某种特定特征的数据做出错误的判断,但模型不会对主任务产生影响。

举个例子,在图像识别中,攻击者意图让带有红色的小车都被识别为小鸟,那攻击者会先通过修改其挟持的客户端样本标签,将带有红色的小车标注为小鸟,让模型重新训练,这样训练得到的最终模型在推断的时候,会将带有红色的小车错误判断为小鸟,但不会影响对其他图片的判断。

在联邦学习场景下进行后门攻击会比较困难,一个原因就是在服务端进行聚合运算时,平均化之后会很大程度消除恶意客户端模型的影响,另一个原因是由于服务端的选择机制,因为并不能保证被攻击者挟持的客户端在每一轮都能被选取,从而降低了被后门攻击的风险。

二、后门攻击策略

带有后门攻击行为的联邦学习,其客户端可以分为恶意客户端和正常客户端。不同类型的客户端,其本地训练策略各不相同。

2.1、正常客户端训练

正常客户端的训练算法如下,其执行过程就是常规的梯度下降过程。

正常客户端的训练算法:

---------------------------------------------------------------------------------------
input: 客户端ID: k ;
          全局模型: θ0 ;
          学习率: η ;
          本地迭代次数: E ;
          每一轮训练的样本大小: B ;
output: 返回模型更新: θ

利用服务端下发的全局模型参数θ0,更新本地模型θθθ0
for 对每一轮的迭代 i = 1, 2, 3, ……, E,执行下面的操作 do
        将本地数据切分为 |B| 份数据 B
        for 对每一个 batch bB
                执行梯度下降:θθη\pounds(θ;b)
        end
end

---------------------------------------------------------------------------------------

2.2、恶意客户端训练

对于恶意客户端的本地训练,主要体现在两个方面:损失函数的设计和上传服务端的模型权重。

● 对于损失函数的设计,恶意客户端训练的目标,一方面是保证在正常数据集和被篡改毒化的数据集中都取得较好的性能;另一方面是保证本地训练的模型与全局模型之间的距离尽量小(距离越小,被服务端判断为异常模型的概率就越小)。 
● 对于上传服务端的模型权重,根据以下公式: Lmt+1nη(XGt)+Gt 可以看出,通过增大异常客户端m的模型权重,使其在后面的聚合过程中,对全局模型的影响和贡献尽量持久。

恶意客户端的训练算法:

---------------------------------------------------------------------------------------
input: 客户端ID: k ;
          全局模型: θ0 ;
          学习率: η ;
          本地迭代次数: E ;
          每一轮的训练样本大小: B ;
output: 返回模型更新: r(Xθ0)+θ0

利用服务端下发的全局模型参数 θ0 ,更新本地模型 X:Xθ0
损失函数:\pounds=\poundsclass_loss+\poundsdistance_loss

for 对每一轮的迭代 i = 1, 2, 3, ……, E,执行下面的操作 do
        将本地数据切分为 |B| 份数据 B
        for 对每一个 batch bB
                数据集 b={Dadvm,Dadvm} 中包含正常的数据集 Dclnm 和被篡改毒化的数据集 Dadvm
                执行梯度下降:XXη\pounds(θ;b)
        end
end

---------------------------------------------------------------------------------------

三、后门攻击具体实现

3.1、客户端

人为篡改客户端client.py的代码,已对代码做出了具体的注释说明,具体细节阅读代码即可。

client.py

  1. import models, torch, copy
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. class Client(object):
  5. def __init__(self, conf, model, train_dataset, id = -1):
  6. self.conf = conf
  7. self.local_model = models.get_model(self.conf["model_name"])
  8. self.client_id = id
  9. self.train_dataset = train_dataset
  10. all_range = list(range(len(self.train_dataset)))
  11. data_len = int(len(self.train_dataset) / self.conf['no_models'])
  12. train_indices = all_range[id * data_len: (id + 1) * data_len]
  13. self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"],
  14. sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
  15. def local_train(self, model):
  16. for name, param in model.state_dict().items():
  17. self.local_model.state_dict()[name].copy_(param.clone())
  18. optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
  19. momentum=self.conf['momentum'])
  20. self.local_model.train()
  21. for e in range(self.conf["local_epochs"]):
  22. for batch_id, batch in enumerate(self.train_loader):
  23. data, target = batch
  24. if torch.cuda.is_available():
  25. data = data.cuda()
  26. target = target.cuda()
  27. optimizer.zero_grad()
  28. output = self.local_model(data)
  29. loss = torch.nn.functional.cross_entropy(output, target)
  30. loss.backward()
  31. optimizer.step()
  32. print("Epoch %d done." % e)
  33. diff = dict()
  34. for name, data in self.local_model.state_dict().items():
  35. diff[name] = (data - model.state_dict()[name])
  36. return diff
  37. def local_train_malicious(self, model):
  38. for name, param in model.state_dict().items():
  39. self.local_model.state_dict()[name].copy_(param.clone())
  40. # 设置优化数据
  41. optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
  42. momentum=self.conf['momentum'])
  43. pos = []
  44. # 手动篡改数据,设置毒化数据的样式
  45. for i in range(2, 28):
  46. pos.append([i, 3])
  47. pos.append([i, 4])
  48. pos.append([i, 5])
  49. self.local_model.train()
  50. for e in range(self.conf["local_epochs"]):
  51. for batch_id, batch in enumerate(self.train_loader):
  52. data, target = batch
  53. # 在线修改数据,模拟被攻击场景
  54. for k in range(self.conf["poisoning_per_batch"]):
  55. img = data[k].numpy()
  56. for i in range(0,len(pos)):
  57. img[0][pos[i][0]][pos[i][1]] = 1.0
  58. img[1][pos[i][0]][pos[i][1]] = 0
  59. img[2][pos[i][0]][pos[i][1]] = 0
  60. target[k] = self.conf['poison_label']
  61. if torch.cuda.is_available():
  62. data = data.cuda()
  63. target = target.cuda()
  64. optimizer.zero_grad()
  65. output = self.local_model(data)
  66. # 类别损失
  67. class_loss = torch.nn.functional.cross_entropy(output, target)
  68. # 距离损失
  69. dist_loss = models.model_norm(self.local_model, model)
  70. # 总的损失函数为类别损失与距离损失之和
  71. loss = self.conf["alpha"]*class_loss + (1-self.conf["alpha"])*dist_loss
  72. loss.backward()
  73. optimizer.step()
  74. print("Epoch %d done." % e)
  75. diff = dict()
  76. # 计算返回值
  77. for name, data in self.local_model.state_dict().items():
  78. # 恶意客户端返回值
  79. diff[name] = self.conf["eta"]*(data - model.state_dict()[name])+model.state_dict()[name]
  80. return diff

3.2、服务端

由于服务端一般是难以攻破的,所以服务端代码不做改动。

server.py

  1. import models, torch
  2. class Server(object):
  3. def __init__(self, conf, eval_dataset):
  4. self.conf = conf
  5. self.global_model = models.get_model(self.conf["model_name"])
  6. self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)
  7. def model_aggregate(self, weight_accumulator):
  8. for name, data in self.global_model.state_dict().items():
  9. update_per_layer = weight_accumulator[name] * self.conf["lambda"]
  10. if data.type() != update_per_layer.type():
  11. data.add_(update_per_layer.to(torch.int64))
  12. else:
  13. data.add_(update_per_layer)
  14. def model_eval(self):
  15. self.global_model.eval()
  16. total_loss = 0.0
  17. correct = 0
  18. dataset_size = 0
  19. for batch_id, batch in enumerate(self.eval_loader):
  20. data, target = batch
  21. dataset_size += data.size()[0]
  22. if torch.cuda.is_available():
  23. data = data.cuda()
  24. target = target.cuda()
  25. output = self.global_model(data)
  26. total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
  27. pred = output.data.max(1)[1]
  28. correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
  29. acc = 100.0 * (float(correct) / float(dataset_size))
  30. total_l = total_loss / dataset_size
  31. return acc, total_l

3.3、配置文件

已对代码做出了具体的注释说明,具体细节阅读代码即可。

conf.json

  1. {
  2. "model_name" : "resnet18",
  3. "no_models" : 10,
  4. "type" : "cifar",
  5. "global_epochs" : 20,
  6. "local_epochs" : 3,
  7. "k" : 3,
  8. "batch_size" : 32,
  9. "lr" : 0.001,
  10. "momentum" : 0.0001,
  11. "lambda" : 0.3,
  12. "eta" : 2, // 恶意客户端的权重参数
  13. "alpha" : 1.0, // class_loss和dist_loss之间的权重比例
  14. "poison_label" : 2, // 约定将被毒化的数据归类为哪一类
  15. "poisoning_per_batch" : 4 // 当恶意客户端在本地训练时,在每一轮迭代过程中被篡改的数据量
  16. }

3.4、模型文件

models.json

  1. import torch
  2. from torchvision import models
  3. import math
  4. def get_model(name="vgg16", pretrained=True):
  5. if name == "resnet18":
  6. model = models.resnet18(pretrained=pretrained)
  7. elif name == "resnet50":
  8. model = models.resnet50(pretrained=pretrained)
  9. elif name == "densenet121":
  10. model = models.densenet121(pretrained=pretrained)
  11. elif name == "alexnet":
  12. model = models.alexnet(pretrained=pretrained)
  13. elif name == "vgg16":
  14. model = models.vgg16(pretrained=pretrained)
  15. elif name == "vgg19":
  16. model = models.vgg19(pretrained=pretrained)
  17. elif name == "inception_v3":
  18. model = models.inception_v3(pretrained=pretrained)
  19. elif name == "googlenet":
  20. model = models.googlenet(pretrained=pretrained)
  21. if torch.cuda.is_available():
  22. return model.cuda()
  23. else:
  24. return model
  25. # 定义两个模型的距离函数
  26. def model_norm(model_1, model_2):
  27. squared_sum = 0
  28. for name, layer in model_1.named_parameters():
  29. squared_sum += torch.sum(torch.pow(layer.data - model_2.state_dict()[name].data, 2))
  30. return math.sqrt(squared_sum)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/592506
推荐阅读
相关标签
  

闽ICP备14008679号