SIGIR2024| RAREMed: 不放棄任何一個患者——提高對罕見病患者的藥物推薦準確性

子豪君發表於2024-10-11

SIGIR2024| RAREMed: 不放棄任何一個患者——提高對罕見病患者的藥物推薦準確性

TLDR:在本文中,我們針對藥物推薦模型對罕見病患者推薦精度低的問題,提出了一種新的基於預訓練-微調的藥物推薦模型框架RAREMed,並提出了兩個針對性的預訓練任務,來提高模型對患者病情,尤其是罕見病患者病情的表示學習能力,幫助藥物推薦模型提高對罕見病患者的推薦準確度,從而提升藥物推薦模型的公平性。我們在兩個常用的公開資料集上的實驗結果顯示出我們的方法在藥物推薦模型的準確度和公平性方面有著顯著優勢。

論文地址:https://arxiv.org/abs/2403.17745

程式碼地址:https://github.com/zzhUSTC2016/RAREMed

image-20240517093338433

引言

隨著人工智慧的快速發展,其在醫療健康領域的應用得到了越來越多的關注,其中,藥物推薦是一個重要的任務,近年來取得了快速的發展。藥物推薦任務旨在利用病人的疾病、手術、病史等臨床資訊,為患者或醫生推薦有效、安全的藥物組合,減輕醫生的工作負擔,並減少潛在的醫療失誤風險,例如藥物誤用、不良的藥物-藥物相互作用等。

現有的藥物推薦模型主要關注提高整體的藥物推薦準確度,但它們面臨的一個普遍問題是公平性問題——對患有罕見病病人的推薦準確率顯著低於其他病人,見圖1(b)。這主要是由於罕見病在訓練資料中出現的次數很少,模型難以學到準確的表達。如圖1(a)所示,一少部分疾病出現次數很多,而大多數長尾疾病出現次數很少。另一方面,現有的藥物推薦模型將兩個關鍵的輸入——疾病(diseases)和手術(procedures)分開建模,導致模型難以捕捉兩種輸入之間的關聯。

image-20240517110208414

圖1:(a) 兩個常見公開資料集上的疾病編碼的長尾分佈 (b) 罕見病病人和常見病病人的推薦準確度對比

在這篇工作中,為了解決藥物推薦面臨的公平性問題,我們利用transformer架構,並提出了兩種針對性的預訓練任務來提高模型的學習和表達能力。具體來說,模型包含兩個預訓練任務:序列匹配預測(Sequence Matching Prediction, SMP)和自重構(Self Reconstruction, SR)。序列匹配預測任務使模型可以分辨疾病和手術序列是否屬於同一個病人,從而更好地理解病人病情資訊中的關聯關係。自重構任務使模型可以利用學到的病人表示重建出病人的輸入編碼,從而更全面地捕捉病人的輸入資訊。

問題定義

在藥物推薦任務中,輸入資訊通常是EHR(Electronic Health Record, 醫療健康記錄),其中包含病人的醫療記錄資訊\(\mathcal{V}^{(j)} = \{\mathbf{d}^{(j)}, \mathbf{p}^{(j)}, \mathbf{m}^{(j)}\}\),其中\(\mathbf{d}^{(j)} = [d_1, d_2, \cdots, d_x] \in \mathcal{D}\)表示患者所患的疾病,\(\mathbf{p}^{(j)} = [p_1, p_2, \cdots, p_y] \in \mathcal{P}\)表示患者所做的手術,\(\mathbf{m}^{(j)} \in \{0, 1\}^{|\mathcal{M}|}\)表示醫生開的藥物。除此之外,輸入資訊還包括藥物-藥物相互作用關係(DDI Graph, Drug-Drug Interaction Graph),用於約束和評價藥物推薦結果中的不良藥物相互作用。

藥物推薦任務被定義為:給定病人的疾病序列\(\mathbf{d}\)和手術序列\(\mathbf{p}\),以及藥物藥物相互作用關係圖\(\mathbf{A}\),目標是推薦一個藥物集合\(\hat{\mathcal{Y}}\),以最大化預測準確率,並儘可能減少不良藥物相互作用。

公平的藥物推薦任務是指,除了上述藥物推薦最佳化目標之外,還需要模型對患有常見病和罕見病的病人都有相似的推薦準確率,減少模型推薦結果對罕見病病人的不公平性。

罕見病病人特徵分析

1715922721103

圖2:橫座標為患者最罕見疾病出現次數,縱座標為患有不同流行度的患者組對應的(a) 疾病數量,(b) 手術數量,(c) 藥物數量,(d) 藥物流行度,均為平均值

我們將MIMIC-IV資料集中的患者根據所患最罕見疾病的流行度劃分為陣列,並統計各組內患者的特徵,如圖2所示,我們有如下兩個觀察:

Observation 1: 患有罕見病的病人病情更復雜。如圖2(a)(b)所示,各患者組內患者所患疾病數量和手術數量隨患者疾病流行度增加而下降。

Observation 2: 患有罕見病的病人治療方案更復雜、更個性化。如圖2(c)(d)所示,罕見病患者通常需要更多、更罕見的藥物進行治療。

這些都為罕見病病人的藥物推薦提出了更大的挑戰。為這些病人提供更準確的藥物推薦結果需要更全面地捕捉他們的病情,並提供更全面、更個性化的藥物推薦結果。

方法

為了提供更準確、更公平的藥物推薦,我們提出了Robust and Accurate REcommendations for Medication (RAREMed)模型,如圖3所示。模型分為患者編碼、預訓練和微調三大部分。

1715924266623

圖3:RAREMed模型圖

患者表示(Encoder)

為了更全面地捕捉患者的病情資訊,給定患者的疾病\(\mathbf{d}\)、手術\(\mathbf{p}\)等輸入,我們將這些編碼連線起來作為一個統一的序列輸入給transformer模型:

\[\begin{align} \label{eq:1} input = [\text{CLS}] \oplus \mathbf{d} \oplus [\text{SEP}] \oplus \mathbf{p}, \end{align} \]

其中[CLS]和[SEP]是特殊編碼,分別表示序列開端符和分隔符。\(\oplus\)表示序列連線符號。患者的疾病和手術編碼按照與患者此次住院的重要性排序,重要性在資料集中已經由專業醫生完成標註。

在此基礎上,我們設計了三個嵌入層,其中,符號嵌入層(Token Embedding)表示每個token的語義,相關性嵌入層(Relevance Embedding)用於編碼每個token的重要性,只與位置相關,分類嵌入層(Segment Embedding)用於編碼輸入token的類別,標識每個token屬於疾病還是手術。

最後,這些embedding經過一個transformer編碼器,生成病人表示,採用[CLS]符號的輸出層編碼:

\[\mathbf{r} = \text{Encoder}(E_{\text{tok}}(input) + E_{\text{seg}}(input) + E_{\text{rel}}(input))[0]。 \]

預訓練(Pre-training)

為了增強模型的表示學習能力,我們針對藥物推薦任務設計了兩種預訓練任務:

Task #1: 序列匹配預測 Sequence Matching Prediction (SMP) :SMP任務的目標是使模型能夠預測輸入的疾病和手術兩個序列是否屬於同一個病人。我們為每個真實病人的(\(\mathbf{d_i}\), \(\mathbf{p_i}\))正樣本對匹配一個負樣本對(\(\mathbf{d_i}\), \(\mathbf{p_j}\)),其中\(\mathbf{p_j}\)來自隨機取樣的另一個患者的手術序列。然後,我們使用Binary Cross-Entropy (BCE) loss來最佳化模型引數:

\[L_{SMP} = - \log(\hat{y}_i) + \log(1-\hat{y}_j), \]

其中\(\hat{y}_i=\sigma(W_1\mathbf{r}_i + b_1)\in\mathcal{R}\) 表示正樣本對的預測機率,\(\hat{y}_j\)表示負樣本對的預測機率。這裡,\(\sigma\)表示sigmoid函式。\(W_1\in\mathbb{R}^{dim}\)\(b_1\in\mathbb{R}\)是可訓練的引數。

**Task #2: 自重構 Self Reconstruction (SR): ** SR任務的目標是使模型可以從病人表示中重構出輸入序列。這個任務會鼓勵RAREMed捕捉和儲存輸入編碼中儘可能多的資訊。損失函式定義如下:

\[L_{SR} = -\sum_{j=1}^{|\mathcal{D}|+|\mathcal{P}|} \left[ \mathbf{c}_j \log(\hat{\mathbf{c}}_j) + (1-\mathbf{c}_j) \log(1-\hat{\mathbf{c}}_j) \right], \]

其中\(\hat{\mathbf{c}} = \sigma(W_2\mathbf{r}+b_2)\in[0,1]^{|\mathcal{D}|+|\mathcal{P}|}\)表示由RAREMed重構的所有疾病和手術的機率,\(W_2\in\mathbb{R}^{(|\mathcal{D}|+|\mathcal{P}|)\times dim}\)\(b_2\in\mathbb{R}^{|\mathcal{D}|+|\mathcal{P}|}\)是可學習的引數。在這裡,\(\mathbf{c}\in\{0,1\}^{|\mathcal{D}|+|\mathcal{P}|}\)表示真實標籤。僅當輸入序列中出現相應的標籤時,才將\(\mathbf{c}_j\)設定為1。

微調和推理(Fine-tune and Inference)

經過預訓練之後,我們在藥物推薦任務上對RAREMed進行微調,以使其適應下游的藥物推薦任務。為了預測藥物,我們整合了一個多標籤分類層,並利用患者表示作為輸入:

\[\hat{\mathbf{o}} = \sigma(W_3 \mathbf{r} + b_3), \]

where \(\hat{\mathbf{o}}\in[0,1]^{|\mathcal{M}|}\)為藥物被推薦的機率。\(W_3\in\mathbb{R}^{|\mathcal{M}|\times dim}\)\(b_3\in\mathbb{R}^{|\mathcal{M}|}\)是可學習引數。

我們使用如下損失函式對模型引數進行最佳化:

\[L_{bce} =-\sum_{i=1}^{|\mathcal{M}|} \left[ \mathbf{m}_i \log(\hat{\mathbf{o}}_i) + (1-\mathbf{m}_i) \log(1-\hat{\mathbf{o}}_i) \right].\\ L_{multi} = \sum_{i,j: \mathbf{m}_i=1, \mathbf{m}_j=0} \frac{\text{max}(0, 1-(\hat{\mathbf{o}}_i-\hat{\mathbf{o}}_j))}{|\mathcal{M}|}.\\ L_{ddi} = \sum_{i=1}^{|\mathcal{M}|} \sum_{j=1}^{|\mathcal{M}|} \mathbf{A}_{ij} \cdot \hat{\mathbf{o}}_i \cdot \hat{\mathbf{o}}_j. \\ L = (1-\beta)((1-\alpha) L_{bce} + \alpha L_{multi}) + \beta L_{ddi}, \]

其中𝛼和𝛽是平衡不同損失貢獻的超引數。

在推理過程中,我們向患者推薦機率大於閾值\(\delta = 0.5\)的藥物。因此,最終的推薦藥物集\(\hat{\mathcal{Y}}\)可以定義為:

\[\hat{\mathcal{Y}} = \{i | \hat{\mathbf{o}}_i > 0.5, 1\leq i \leq |\mathcal{M}|\}. \]

實驗

實驗設定

我們採用MIMIC-III和MIMIC-IV兩個EHR資料集,並仿照前人的工作對資料進行的篩選和處理。

image-20240517144236921

我們選取Jaccard、PRAUC、F1、DDI rate和#Med作為評測指標,評價各模型的推薦精度和安全性。其中Jaccard、PRAUC、F1越高,表示藥物推薦越準確;DDI rate越低,表示藥物相互作用越少,推薦結果越安全;而#Med越接近醫生開藥的平均藥物數量,說明模型預測越合理。

實驗結果

  • RAREMed比現有藥物推薦方法準確性更高、安全性更好。

image-20240517144419600

  • RAREMed可以產生更公平的藥物推薦結果,對罕見病病人的藥物推薦精度顯著高於現有模型

    image-20240517144752974

  • 預訓練任務、統一的序列編碼、設計的兩個額外的嵌入層都對推薦精確度和公平性有正向的影響

image-20240517144901147

  • 超引數對RAREMed推薦結果的影響

image-20240517144953995

相關文章