PyTorch 对抗示例生成

2025-06-18 17:16 更新

对抗示例是机器学习领域中的一个重要研究方向,它揭示了模型在面对恶意攻击时的脆弱性。本教程教你如何生成对抗示例并攻击一个图像分类器。通过学习 FGSM 攻击方法,你将深入了解对抗示例的原理和实现方式。

一、对抗示例概述

对抗示例是指通过在输入数据中添加精心设计的扰动,使机器学习模型产生错误输出的样本。这些扰动通常很小,以至于人类无法察觉,但却能显著影响模型的性能。对抗示例的存在提醒我们在开发机器学习模型时,不仅要关注模型的准确性,还要重视其安全性和鲁棒性。

在实际应用中,攻击者可能对模型有不同的了解程度,这引出了白盒攻击和黑盒攻击的概念:

  • 白盒攻击 :攻击者完全了解模型的结构、参数和训练数据。
  • 黑盒攻击 :攻击者只能访问模型的输入和输出,对模型的内部结构和参数一无所知。

此外,根据攻击目标的不同,对抗示例可以分为错误分类和源 / 目标错误分类两种类型。

二、快速梯度符号攻击(FGSM)

FGSM 是一种简单而有效的对抗示例生成方法。它的核心思想是利用模型的梯度信息来构造对抗扰动。具体来说,FGSM 通过计算损失函数对输入数据的梯度,然后根据梯度的方向调整输入数据,使损失最大化,从而生成对抗示例。

FGSM 的公式可以表示为:

[ x_{\text{adv}} = x + \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)) ]

其中,(x) 是原始输入,(\epsilon) 是扰动的幅度,(\text{sign}) 是取符号函数,(\nabla_x J(\theta, x, y)) 是损失函数对输入 (x) 的梯度。

三、实验实现

1. 导入必要的库和模块

我们首先导入实现对抗示例生成所需的库和模块。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

2. 定义受攻击的模型

我们使用一个预训练的 MNIST 分类器作为受攻击的模型。

## LeNet 模型定义
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)


    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


## 加载 MNIST 测试数据集
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ])),
    batch_size=1, shuffle=True)


## 检测设备并初始化模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)


## 加载预训练模型权重并设置为评估模式
model.load_state_dict(torch.load("data/lenet_mnist_model.pth", map_location=device))
model.eval()

3. 定义 FGSM 攻击函数

def fgsm_attack(image, epsilon, data_grad):
    # 获取数据梯度的符号
    sign_data_grad = data_grad.sign()
    # 生成对抗示例
    perturbed_image = image + epsilon * sign_data_grad
    # 将对抗示例的像素值限制在 [0, 1] 范围内
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

4. 测试函数

def test(model, device, test_loader, epsilon):
    correct = 0
    adv_examples = []


    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True


        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]


        if init_pred.item() != target.item():
            continue


        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data


        perturbed_data = fgsm_attack(data, epsilon, data_grad)
        output = model(perturbed_data)


        final_pred = output.max(1, keepdim=True)[1]


        if final_pred.item() == target.item():
            correct += 1
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
        else:
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))


    final_acc = correct / float(len(test_loader))
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))
    return final_acc, adv_examples

5. 运行攻击并可视化结果

epsilons = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
accuracies = []
examples = []


for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)


## 绘制精度与 epsilon 的关系图
plt.figure(figsize=(5, 5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, 0.35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()


## 可视化对抗示例
cnt = 0
plt.figure(figsize=(8, 10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons), len(examples[0]), cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
        orig, adv, ex = examples[i][j]
        plt.title("{} -> {}".format(orig, adv))
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()

四、实验结果

通过运行上述代码,我们可以得到不同 epsilon 值下模型的测试精度以及一些成功的对抗示例。

从精度与 epsilon 的关系图中可以看到,随着 epsilon 的增加,模型的测试精度逐渐下降。这表明对抗示例的扰动对模型的性能产生了显著影响。

对抗示例的可视化结果展示了在不同 epsilon 值下,原始图像被错误分类为其他类别的示例。尽管扰动很小,但模型的预测结果发生了变化,而人类仍然能够正确识别图像中的数字。

Epsilon 测试精度
0 0.981
0.05 0.9426
0.1 0.851
0.15 0.6826
0.2 0.4301
0.25 0.2082
0.3 0.0869

五、总结

本教程介绍了对抗示例的概念和 FGSM 攻击方法,并通过实验展示了如何生成对抗示例并攻击一个 MNIST 分类器。通过学习本教程,你了解了对抗示例的原理和实现方式,以及它们对模型性能的影响。在编程狮(W3Cschool)网站上,你可以找到更多关于 PyTorch 的详细教程和实战案例,帮助你进一步提升深度学习技能,成为人工智能领域的编程大神。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号