瞭解積分梯度如何幫助識別哪些輸入特徵對模型的預測貢獻
在金融和醫療保健等受到高度監管的行業中使用 AI 模型肩負著關鍵責任:可解釋性。您的模型預測準確是不夠的。您應該能夠解釋您的模型做出特定預測的原因。例如,如果我們正在開發一個基於腦部 MRI 掃描的腫瘤檢測模型,我們應該能夠解釋我們的模型使用哪些資訊以及它如何處理這些資訊並導致腫瘤識別。在這種情況下,監管機構或醫生需要了解這些細節,以確保結果公正和準確。那麼,您如何解釋您的模型決策呢?手動解釋它們並不容易,因為沒有簡單的 “if-else” 邏輯 — 深度學習模型通常有數百萬個引數以非線性方式互動,因此無法追蹤從輸入到輸出的路徑。
當我們要了解某個特徵(比如水果的顏色)對模型預測結果(比如水果的價格)的影響時,我們會從一個基線開始——這個基線是一個我們認為對模型輸出沒有影響的點,比如將所有特徵值設為零或者平均值。然後,我們會逐步增加這個特徵的重要性,直到達到實際輸入的值。在這個過程中,我們會記錄每次小幅度增加特徵值時模型輸出的變化。最後,我們將這些小變化加起來,就能得到這個特徵對模型輸出總貢獻的一個估計。Integrated Gradients是一種幫助我們理解複雜模型工作原理的技術。它透過計算特徵值從小到大的過程中模型輸出的變化量,來量化每個特徵對模型預測結果的影響。這種方法不僅適用於影像識別任務,還可以應用於文字分析、聲音識別等多個領域。
在滿足這一需求方面,我有實踐經驗的技術之一是積分梯度。它是由 Google 的研究人員於 2017 年推出的,這是一種強大的方法,透過整合從基線到實際輸入的梯度來計算歸因。在本文中,我將引導您完成一個影像分類用例,並向您展示整合梯度如何幫助我們瞭解哪些影像畫素在決策中最重要。我將使用 Captum 庫來計算歸因,並使用預先訓練的 ResNet 模型來預測影像。
環境設定
已安裝 Python 3.10 或更高版本
安裝下面提到的必要軟體包
pip install captum torch torchvision Matplotlib NumPy PIL
下載示例影像並將其命名為 image.jpg(您可以下載所需的任何影像)。
現在,我們將載入瞭解物件的預訓練 ResNet 模型,使用 ResNet 模型對影像進行分類,使用整合梯度技術計算屬性,並視覺化結果,顯示影像的哪些畫素對模型的預測最重要。
以下是帶有詳細註釋的完整實現。
import torch
import torchvision
import torchvision.transforms as transforms
from captum.attr import IntegratedGradients
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np# 載入我們剛剛下載的影像。如果你下載的影像名字不同,請更改影像名稱。
image_to_predicted = Image.open('image.jpg')# 將影像調整為標準格式並轉換為數字。
transformed_image = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])(image_to_predicted).unsqueeze(0)# 下載預訓練的ResNet模型,使其準備好進行預測。
model = torchvision.models.resnet18(pretrained=True)
model.eval()# 進行預測並使用argmax函式找到機率最高的物件。
predicted_image_class = torch.nn.functional.softmax(model(transformed_image)[0], dim=0).argmax().item()# 建立IntegratedGradients物件並計算屬性
integrated_gradients = IntegratedGradients(model)# 建立基線參考,IG計算從這裡開始
baseline_image = torch.zeros_like(transformed_image)# 使用預測的影像類別、基線和轉換後的影像計算歸因
computed_attributions, delta = integrated_gradients.attribute(transformed_image, baseline_image, target=predicted_image_class, return_convergence_delta=True)# 將歸因轉換為numpy陣列以進行視覺化
attributions_numpy = np.abs(np.transpose(computed_attributions.squeeze().cpu().detach().numpy(), (1, 2, 0))) * 255 * 10
attributions_numpy = attributions_numpy.astype(np.uint8)# 視覺化下載的影像和帶有歸因的影像
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img)
axes[0].axis('off')
axes[0].set_title('下載的影像')
axes[1].imshow(attributions_numpy, cmap='magma')
axes[1].axis('off')
axes[1].set_title('帶有歸因的影像')
plt.show()
這是我下載的示例影像的輸出。您可以看到突出顯示的區域顯示每個畫素對於模型預測的重要性。
結論
我們探討了如何使用積分梯度來解釋深度學習模型的預測。使用積分梯度,我們深入瞭解了影像的哪些畫素對模型的預測最重要。這不是唯一可用於模型可解釋性的技術。其他技術,如特徵重要性、Shapley 加法解釋 (SHAP),也可用於深入瞭解模型行為。
今天先到這兒,希望對雲原生,技術領導力, 企業管理,系統架構設計與評估,團隊管理, 專案管理, 產品管理,資訊保安,團隊建設 有參考作用 , 您可能感興趣的文章:
構建創業公司突擊小團隊
國際化環境下系統架構演化
微服務架構設計
影片直播平臺的系統架構演化
微服務與Docker介紹
Docker與CI持續整合/CD
網際網路電商購物車架構演變案例
網際網路業務場景下訊息佇列架構
網際網路高效研發團隊管理演進之一
訊息系統架構設計演進
網際網路電商搜尋架構演化之一
企業資訊化與軟體工程的迷思
企業專案化管理介紹
軟體專案成功之要素
人際溝通風格介紹一
精益IT組織與分享式領導
學習型組織與企業
企業創新文化與等級觀念
組織目標與個人目標
初創公司人才招聘與管理
人才公司環境與企業文化
企業文化、團隊文化與知識共享
高效能的團隊建設
專案管理溝通計劃
構建高效的研發與自動化運維
某大型電商雲平臺實踐
網際網路資料庫架構設計思路
IT基礎架構規劃方案一(網路系統規劃)
餐飲行業解決方案之客戶分析流程
餐飲行業解決方案之採購戰略制定與實施流程
餐飲行業解決方案之業務設計流程
供應鏈需求調研CheckList
企業應用之效能實時度量系統演變
如有想了解更多軟體設計與架構, 系統IT,企業資訊化, 團隊管理 資訊,請關注我的微信訂閱號:
作者:Petter Liu
出處:http://www.cnblogs.com/wintersun/
本文版權歸作者和部落格園共有,歡迎轉載,但未經作者同意必須保留此段宣告,且在文章頁面明顯位置給出原文連線,否則保留追究法律責任的權利。
該文章也同時釋出在我的獨立部落格中-Petter Liu Blog。