引言
在深度學習領域內的對抗樣本綜述(二)中,我們知道了幾種著名的對抗攻擊和對抗防禦的方法。下面具體來看下幾種對抗攻擊是如何工作的。這篇文章介紹FGSM(Fast Gradient Sign Method)。
預備知識
符號函式sign
泰勒展開
當函式\(f(x)\)在點\(x_0\)處可導時,在點\(x_0\)的鄰域\(U(x_0)\)內恆有:
因為\(o(x-x_0)\)是一個無窮小量,故有:
這是在對函式進行區域性線性化處理時常用的公式之一。從幾何上看,它是用切線近似代替曲線。這樣的近似是比較粗糙的,而且只在點的附近才有近似意義。
梯度
梯度是偏導陣列成的向量。若有函式\(f(x^{(1)},x^{(2)},x^{(3)})\),則\(f\)在點\(θ_0=[x_0^{(1)},x_0^{(2)},x_0^{(3)}]^T\)處的梯度為:
一元函式的導數表示函式增加最快的方向,那麼梯度表示多元函式值增加最快的方向。
FGSM公式
ϵ為hyperparameter,控制原影像和對抗樣本之間的差異程度。(字母加粗表示向量)
在梯度下降法中,我們求損失函式關於權重w、偏移b(統稱引數θ)的梯度,然後更新引數,即引數\(\textbf{θ}=\textbf{θ}-η*\nabla_θ J(\textbf{x},y,\textbf{θ})\),η為learning rate。
而在FGSM中,我們用加梯度方向的ϵ倍的方式更新輸入。
注意兩者的不同:梯度代表函式值增加最快的方向,更新引數時,我們要做的是使損失函式J減小(在輸入確定的情況下),因此減去梯度;而獲取對抗樣本時,我們要做的是使損失函式J增大(在θ確定的情況下),因此增加梯度,但又要控制擾動的大小,因此只取梯度的方向,其大小統一控制為ϵ。
為什麼FGSM中要讓損失函式增加?因為J 越大,表明預測class概率向量和真實one-hot class向量的距離越大,更有可能使預測器輸出錯誤的label。用數學來解釋下,損失函式在輸入x附近\(x_{adv}\)處的泰勒展開:
\(ϵ*sign(∇_\textbf{x}J(\textbf{x},y,\textbf{θ}))\)即泰勒展開中的\((x-x_0)\)項。
在上式中,\(\nabla_x J(\textbf{x},y,\textbf{θ})^T*ϵ*sign(∇_\textbf{x}J(\textbf{x},y,\textbf{θ}))\)為非負數,則\(J(\textbf{x}_{adv},y,\textbf{θ})>=J(\textbf{x},y,\textbf{θ})\),說明我們達到了讓損失函式增大的目的。
\(\nabla_x J(\textbf{x},y,\textbf{θ})^T*ϵ*sign(∇_\textbf{x}J(\textbf{x},y,\textbf{θ}))\)是非負數,因為:
FGSM程式碼
def fgsm(model, loss, eps, softmax=False):
"""
單次FGSM
model為目標模型
loss為傳入的損失函式計算函式
eps為限定擾動大小
"""
def attack(img, label):
output = model(img)
if softmax:
error = loss(output, label)
else:
error = loss(output, label.unsqueeze(1).float())
error.backward() # 計算損失函式對輸入x的梯度
# clamp()使perturbed_img的各分量在[0,1]區間
perturbed_img = torch.clamp(img + eps * img.grad.data.sign(), 0, 1).detach()
img.grad.zero_()
return perturbed_img
return attack
def ifgsm(model, loss, eps, iters=4, softmax=False):
# 多次FGSM
def attack(img, label):
perturbed_img = img
perturbed_img.requires_grad = True
for _ in range(iters):
output = model(perturbed_img)
if softmax:
error = loss(output, label)
else:
error = loss(output, label.unsqueeze(1).float())
error.backward()
temp = torch.clamp(perturbed_img + eps * perturbed_img.grad.data.sign(), 0, 1).detach()
perturbed_img = temp.data
perturbed_img.requires_grad = True
return perturbed_img.detach()
return attack
參考文獻
[1] Goodfellow I J , Shlens J , Szegedy C . Explaining and Harnessing Adversarial Examples[J]. Computer Science, 2014.
[2] 為什麼函式的導數大於等於零或小於等於零就可以判斷函式是增還是減? - Observer的回答 - 知乎 https://www.zhihu.com/question/377992767/answer/1104094160