更多幹貨內容請關注微信公眾號“AI 前線”,(ID:ai-front)
對於事件(例如死亡)發生時間的個性化概率預測在制定決策中是十分重要的,尤其在臨床情況下。受到氣象學研究的啟發,吳恩達團隊通過最大化預測分佈的銳度(sharpness)來解決這個問題。在迴歸問題中,研究顯示通過優化連續分級概率評分(continuous ranked probability score, CRPS),能夠帶來銳度更高的預測分佈,同時能夠保持校準度(calibration)。
這篇論文介紹了 Survival-CRPS 模型,該方法是 CRPS 在預測事件發生時間上的推廣,並且提出了右刪失資料和間隔刪失資料兩個變種。為了全面評價預測分佈的銳利程度,研究人員提出了 Survival-AUPRC 評價標準,計算方法類似於準確度 – 召回曲線下的面積。通過構建一個迴圈神經網路,將提出的方法應用於死亡預測。研究使用了電子醫療記錄(Electronic Health Record, EHR)資料庫,其中包含數百萬個患者的醫療資料。實驗結果顯示,通過 Survival-CRPS 目標函式訓練的模型的表現相比於最大似然法有顯著提高。
在近幾十年內,電子醫療記錄(EHR)的普及為科學研究帶來了數百萬病患的詳細醫療資料。大量的資料使利用機器學習模型對病人做出個性化的預測成為可能。
傳統方法將病人的生存預測視作一個概率分類問題,即在一定的時間跨度上訓練二值分類器來預測事件結果。但是這種方法有三個缺陷:首先,模型受到時間跨度的限制——如果模型的訓練目標是預測一年內的死亡率,那麼就很難直接獲取 6 個月內的死亡率預測;其次,不能應用所有病人的資料——如果一個病人的 EHR 只有 3 個月的記錄,那麼對於預測一年內的死亡率問題,很難決定該將這個人劃為正樣本還是負樣本。最後,在建立資料集時,對預測時間的選擇毫無疑問受限於未來的結果——研究結果顯示評價標準相比於真實情況過於樂觀。
另一種方法是生存預測,即通過評價未來時間的概率分佈來預測事件發生時間。但是常用的生存預測模型也有一些問題:第一,傳統模型通常做出很強的假設。第二,這種模型應用於有大量刪失資料的資料庫時,對於低發病人群的預測不是很準確。第三,此類生存分析方法通常是對風險的點評價,而不是對預測分佈的全域性評價。
對比之下氣象學預報通常是基於過去和目前的觀測情況,對所有的天氣情況作出全面的預測分佈。預測結果由最大化預測分佈的銳利程度來評價。一個預測分佈的有用程度體現在它的銳利程度中,即資料的聚集程度。為了提高預測分佈曲線的銳利程度,我們提出採用優於最大似然法的適當評分法則(Proper Scoring Rule)作為訓練的目標函式。我們將 CRPS 擴充套件至生存問題中,定義為 Survival-CRPS,並且分別進行了資料右刪失和間隔刪失的延伸。
圖 1
(1) 提出了適當評分法則 Survival-CRPS 作為生存預測的目標函式,並且提出了它的右刪失和間隔刪失變體。
(2) 提出了新的評價標準 Survival-AUPRC,來全面評價預測分佈的質量。
(3) 給出了死亡預測任務的實用方法:在訓練時使用對數正則引數化和間隔刪失。
(4) 我們應用上述技巧,利用 EHR 資料訓練一個深度迴圈神經網路模型,對患者的死亡進行準確的預測。
引數生存預測將事件發生的時間建模為一簇由分佈引數定義的概率分佈曲線。生存函式定義為 S(t)=[0,正無窮),定義域為 0 到 1。在正實數範圍內單調遞減。S(0)=1,t= 正無窮時 S(t)=0。生存函式代表一個個體在給定時間 t 內沒有發生事件(死亡)的概率。每一個生存函式都有一個對應的累積分佈函式(CDF):F(t)=1-S(t),和一個概率密度函式(PDF):f(t)=dF(t)/dt。
我們將病人 i 的醫療記錄記為:
t 代表病人與該健康記錄的互動次數。xt(i) 對應於第 t 次互動對應的特徵,at 是時間 t 時的年齡,d(i) 代表死亡年齡,或最後進行互動的年齡。c(i) 是刪失指標。對於每一個 xt(i),我們定義一個量 yt(i)=d(i)-a(i),代表對應的死亡時間或資料刪失時間。
評分法則是評價概率預報質量的方法。對於連續輸出的預報是所有可能結果的概率密度函式 f(PDF),以及對應的累積分佈函式 F。在現實生活中,我們可以觀測到一些實際結果 y。評分法則 S 計算預測分佈和實際結果之間的誤差,返回損失值 S(F, y)。如果對於所有的可能分佈 G,S 都是一個合適評分法則,則有:
合適評分法則鼓勵模型進行真實預測。當採用合適評分法則作為損失函式時,它自然地約束模型輸出校準後的概率。
氣象學中常用 CRPS 作為預測連續結果的適當評分法則:
CRPS 通常作為迴歸問題的目標函式,相比於最大似然法能夠產生更銳利的預測分佈,同時保持資料校準。上式的後兩項積分對應於圖 2a 中的兩個陰影區域。
圖 2
為了預測事件發生所需時間,我們提出了 Survival-CRPS,用於計算右刪失或間隔刪失資料的概率:
當 c=0 時,上面兩個方程都退化成原始的 CRPS。同樣的,上面兩個式子的積分項也分別對應於圖 2b 和圖 2c 中的陰影區域。對於刪失資料的結果,Survival-CRPS 懲罰出現在刪失時間之前的值,對於間隔刪失,則刪失時間之後的值也會被懲罰。
Survival-CRPS 的兩個變體都是適當評分法則。他們可以算是閾值加權 CRPS 的特例,權重函式即為未刪失區域的指示值。
校準度評估的是預測事件概率與觀測事件頻率的匹配程度。它對於預測模型,尤其是臨床診斷決策來說是十分重要的。我們採用下面的方法來衡量校準度:我們在預測累積密度的分位點處對比預測累積概率密度和觀測的事件頻率。右刪失的觀測值不計算刪失點之後的分位點。間隔刪失的觀測資料與此類似,但是在事件一定會出現的時間點之後的分位點再引入。
在保證校準度的情況下,我們也希望得到銳度較高的預測分佈。我們使用變異係數(Coefficient of Variation, CoV)作為銳度的度量指標。CoV 定義為標準差和均值的比值:
由於銳度僅僅是預報分佈的一個函式,因此只有當模型完全校準時,評價銳度才是有意義的。我們提出一個新的衡量標準,可以衡量預報分佈質量的聚集程度,對未校準的模型具有魯棒性。這個想法與計算精確度 – 召回率曲線(Precision-Recall curve)下方的面積類似,只是這裡僅考慮一個結果和對應的一個預測分佈。
首先考慮未刪失的情況,我們用事件發生時間附近的間隔來類比精確度,例如在時間 y 的事件周圍,預測精度為 0.9 的間隔則為 [0.9y, y/0.9]。對應於這個精度的區域,我們用預報分佈在這個區間段上分配的質量來類比召回率:F(y/0.9)-F(0.9y)。曲線下的面積衡量的是隨精度視窗擴充套件,預測的質量在真實結果附近的聚集速度。
Survival-AUPRC 的最高分為 1,此時預測分佈為一個狄拉克函式,在事件發生時間附近聚集。最低分數為 0,此時預測分佈無窮大。所有樣例的 Survival-AUPRC 分數的平均值為預測的質量提供了一個整體的評估。
上述的度量指標只適用於未刪失的情況,在刪失資料的情況下,我們使用同樣的類比方法,但是對時間間隔進行了調整:
我們通過構建一個多層迴圈神經網路,將提出的方法應用在死亡預測任務中。網路輸入為特徵序列(EHR 中的病人資訊),來預測概率分佈函式 F 的引數。該網路只依賴目前和之前的時間點資料,不依賴未來的資料。每個時間點輸出的概率分佈構成整體損失:
這種序列性的單調遞減模型,我們稱之為倒數計時迴歸(Countdown Regression)。
我們在死亡預測任務上進行實驗,評估四個不同的訓練目標:最大似然 S-MLE-RIGHT 和 S-MLE-INTVL,以及基於我們提出的評分法則的損失函式:S-CRPS-RIGHT 和 S-CRPS-INTVL。
我們採用電子醫療檔案 EHR(來自 STARR 資料倉儲)用於訓練和驗證。該資料包含超過 300 萬名病人的記錄(約 2.6% 的病人有死亡日期記錄),跨度大概為 27 年。每個病人的輸入序列的時間點對應 EHR 給定日期的所有資料。我們使用了診斷碼、實驗測試順序碼、治療型別碼以及人口統計學資料(年齡和性別)。每個程式碼都有一個隨即初始化的內嵌向量,作為需要學習的引數。300 萬個病人,對應 5100 萬的時間點,按比例隨機分成 8:1:1,分別作為訓練集、驗證集、測試集。
我們首先驗證模型的校準度(圖 3)。變異係數和 Survival-AUPRC 度量指標均顯示帶有間隔刪失的 Survival-CRPS 方法可以得到銳度最高的預測分佈(表 1)。
圖 3
每個模型的校準度圖。我們對比了預測累積密度和觀測時間頻率,分別在預測累積密度的分位點進行比較。
表 1 分別用最大似然和 Survival-CRPS 目標函式訓練的右刪失和間隔刪失模型的銳度和校準度比較。
通過分析預測模型給超過 120 歲之後分配的死亡時間可以看出,用最大似然簡單訓練的模型會將超過 75% 的質量分配給不合理的時間。我們發現這種行為主要由於未刪失資料中低發病率的樣例,但這在真實世界的 EHR 資料中廣泛存在。刪失樣例的損失函式可以通過將質量儘可能推向右側來最小化,也因此支配了少量未刪失樣例。
通過對死亡時間的整體預測,這個模型也可以應用於對不同時間點進行預測。當預測 6 個月、1 年和 5 年內的死亡概率時,我們的模型可以保持很好的校準度,以及極高的區分度。
圖 4 間隔刪失 Survival-CRPS 模型的區分度和校準度曲線。
圖 5 間隔刪失 Survival-CRPS 模型預測單個病人死亡時間的中位數。我們的模型給出了最可靠的預測結果。真實的死亡時間基本位於預測的時間段內。
我們可以通過比最大似然更好的目標函式,以及對預測分佈進行全面評價的標準來打造更好的生存預測模型。在這篇論文中,由於受到 CRPS 評分標準的啟發,我們提出了 Survival-CRPS 目標函式,可以產生銳度較高的預測分佈,同時保持校準度。我們介紹了 Survival-CRPS 評價標準,能夠捕捉到預測分佈在觀測時間附近的聚集情況。通過對數正則引數化方法,我們訓練了一個深度迴圈模型,能夠成功進行大型生存預測。通過對事件發生時間的整體分佈預測,我們解決了二值分類法不能解決的時間點預測問題,並且可以給出指定時間內的精確預測結果。能夠進行精確生存預測的意義是巨大的,尤其對於健康護理領域。我們希望我們的工作可以幫助到那些正在設計或部署這種模型的人。
論文原文連結:
https://arxiv.org/pdf/1806.08324v1.pdf