当前位置:   article > 正文

PGD攻击生成对抗样本（PGD Attack）

pgd attack
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from advertorch.attacks import PGDAttack
from PIL import Image
import matplotlib.pyplot as plt
import requests

# Load a pretrained ResNet-50 classifier and put it in inference mode
# (disables dropout/batch-norm updates).
model = resnet50(pretrained=True)
model.eval()

# Preprocessing: resize to the network's 224x224 input and convert to a
# float tensor in [0, 1]. NOTE(review): ImageNet mean/std normalization is
# omitted here — presumably so the attack budget (eps) lives in pixel space
# and ToPILImage renders directly; confirm this is intentional.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

image_path = 'C:/Users/Administrator/Desktop/flower.jpg'
image = preprocess(Image.open(image_path)).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

# Build the PGD adversary: 40 iterations with per-step size 0.01,
# constrained to an L-infinity ball of radius eps=0.01 around the input.
adversary = PGDAttack(
    model,
    loss_fn=torch.nn.CrossEntropyLoss(),
    eps=0.01,
    nb_iter=40,
    eps_iter=0.01,
)

# True label steering the attack: 985 is "daisy" in ImageNet. For another
# picture, run the model once first and use its predicted class here.
label = torch.tensor([985])
adv_image = adversary.perturb(image, label)

# Show the clean image and its adversarial counterpart side by side.
to_pil = transforms.ToPILImage()
panels = [
    (to_pil(image.squeeze(0)), 'Original Image'),
    (to_pil(adv_image.squeeze(0)), 'Adversarial Image'),
]
for position, (img, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(img)
    plt.title(caption)
plt.show()

# Classify both images and report the top-1 ImageNet class for each,
# showing whether the adversarial perturbation changed the prediction.
with torch.no_grad():
    output_original = model(image)
    output_adversarial = model(adv_image)

    # torch.max over dim=1 yields (values, indices); we only need the
    # predicted class indices.
    _, predicted_original = torch.max(output_original, 1)
    _, predicted_adversarial = torch.max(output_adversarial, 1)

# Fetch human-readable ImageNet class names.
# timeout prevents the script from hanging forever on a dead connection;
# raise_for_status surfaces HTTP errors (e.g. 404) as a clear exception
# instead of a confusing JSON decode failure downstream.
labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
response = requests.get(labels_url, timeout=10)
response.raise_for_status()
labels = response.json()

print(f"Original Prediction: {labels[predicted_original.item()]}")
print(f"Adversarial Prediction: {labels[predicted_adversarial.item()]}")
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/365782
推荐阅读
相关标签
  

闽ICP备14008679号