Causal Inference理論學習篇-Tree Based-Causal Tree

real-zhouyc發表於2024-04-14

Tree-Based Algorithms

Tree-based這類方法,和之前meta-learning 類的方法最明顯的區別是: 這類方法把causal effect 的計算顯示的加入了到了樹模型節點分裂的標準中 從 response時代過渡到了effect時代。

大量的這類演算法基本圍繞著樹節點分裂方式做文章,普遍採用的是相容性比較高的[[萬字長文講述樹模型的歷史|cart樹]]

Causal Tree & Honest Tree

causal tree[4] 這篇文章算是較早透過改變樹模型node分裂方式來預估[[因果推斷及其重要相關概念#heterogeneous causal effects|異質因果效應]](heterogeneous causal effects)的演算法。
所以重點還是如何去構建 split criterion,前置可能要說一下相關的符號含義:
在特徵空間 \(\mathbb X\) 下存在節點分裂方式的集合:

\[\prod(\ell_1,...,\ell_{\#(T)}) \]

其中以 \(\ell(x;\prod)\) 表示葉子節點\(\ell\) 屬於劃分方式 \(\prod\), 此時該劃分方式下的,node的條件期望定義為:

\[\mu(x;\prod)=E[Y_i|X_i \in \ell(x;\prod)] \]

那麼,自然如果給定樣本\(S\) , 其對應節點無偏統計量為:

\[\hat \mu(x;S,\prod)=\frac{1}{\#(i\in S: X_i \in \ell(x;\prod))}\sum_{i\in S:X_i\in\ell_i(x;\prod)} Y_i \]

Causal Tree 學習的目標 or loss func

學習目標使用修改後的MSE, 在標準mse的基礎上多減去了一項和模型引數估計無關的\(E[Y^2]\),此外
訓練即build tree階段,train set被切為兩部分,一部分訓練樣本train set:\(S^{tr}\) , 一部分是估計樣本 est set \(S^{est}\),還有測試樣本test set \(S^{te}\)

這裡有點繞:和經典的樹模型不一樣的是:葉子節點上儲存的值不是根據train set來的, 而是劃分好之後透過est set進行估計。(顯然, 這種方式有點費樣本...)。所以,這也是文中為啥把這種方法叫做“Honest”的原因。

假設已經根據訓練樣本得到劃分方式,那麼評估這種劃分方式好壞被定義為:

\[MSE(S^{te}, S^{est},\prod)=\frac{1}{\#(S^{te})} \sum_{i\in S^{te}} \{(Y_i-\hat\mu(X_i;S^{est},\prod))^2-Y_i^2\} \]

整體求期望變成:

\[EMSE(\prod)=E_{S^{te},S^{est}}[MSE(S^{te}, S^{est},\prod)] \]

演算法的整體目標為:

\[Q^{H}(\pi)=-E_{S^{est}, S^{est}, S^{tr}}[MSE(S^{te}, S^{est},\pi(S^{tr}))] \]

其中,\(\pi(S)\) 定義為:

\[\pi(\mathcal{S})= \begin{cases}\{\{L, R\}\} & \text { if } \bar{Y}_L-\bar{Y}_R \leq c \\ \{\{L\},\{R\}\} & \text { if } \bar{Y}_L-\bar{Y}_R>c .\end{cases} \]

其實就是比較節點在劃分後,左右子節點的輸出差異是否滿足閾值c,\(\bar Y_L=\mu(L)\)

節點劃分方式

作者直接給出了節點劃分時的loss計算標準:

我們來推導一下:

\[\begin{aligned} EMSE(\small\prod)&=E_{S^{te},S^{est}}[\frac{1}{\#(S^{te})} \sum_{i\in S^{te}} \{(Y_i-\hat\mu(X_i;S^{est},\small\prod))^2-Y_i^2\}] \\ &=E_{S^{te},S^{est}}[(Y_i-\hat\mu(X_i;S^{est},\small\prod))^2-Y_i^2] \\ &=E_{S^{te},S^{est}}[(Y_i-\mu(X_i;\small\prod)+\mu(X_i;\small\prod)-\hat\mu(X_i;S^{est},\small\prod))^2 - Y_i^2] \\ &=E_{S^{te},S^{est}}[\{Y_i-\mu(X_i;\small\prod)\}^2-Y_i^2] \\ &+2E_{S^{te},S^{est}}[\{Y_i-\mu(X_i;\small\prod)\}\{(\mu(X_i;\small\prod)-\hat\mu(X_i;S^{est},\small\prod)\}] \\ &+E_{S^{te},S^{est}}[\{\mu(X_i;\small\prod)-\hat\mu(X_i;S^{est},\small\prod)\}^2] \end{aligned} \]

因為中間展開項期望為0, 所以公式變成:

\[\begin{aligned} EMSE(\small\prod)&=E_{S^{te},S^{est}}[\{Y_i-\mu(X_i;\small\prod)\}^2-Y_i^2]+E_{S^{te},S^{est}}[\{\mu(X_i;\small\prod)-\hat\mu(X_i;S^{est},\small\prod)\}^2] \\ &=E_{S^{te},S^{est}}[(\mu(X_i;\small\prod))^2-2Y_i\mu(X_i;\small\prod)]+E_{S^{te},S^{est}}[\{\mu(X_i;\small\prod)-\hat\mu(X_i;S^{est},\small\prod)\}^2] \end{aligned} \]

同樣的,展開項的項期望為0,由於無偏估計=> \(\mu(X_i;\small \prod)=E_{S^{est}}[\hat \mu(X_i;S^{est};\small\prod)]\) ,最終公式變成:

\[-EMSE(\small\prod)=E_{X_i}[\mu^2(X_i;\small \prod)]-E_{S^{est},X_i}[\mathbb V(\hat\mu(X_i;S^{est}, \small\prod))] \]

其中,\(E_{S^{est},X_i}[\mathbb V(\hat\mu(X_i;S^{est}, \small\prod))]=E_{S^{te},S^{est}}[\{\hat\mu(X_i;S^{est},\small\prod)-\mu(X_i;\small\prod\}^2]\)

公式中第一項可以理解為偏差的平方,第二項理解為方差。為什麼MSE可以被理解成偏差和方差的組合,以及展開項為0
我們來證明一下:(開個玩笑:),其實我是抄的Wikipedia,可以看證明1證明2

偏差項

接著分析偏差項:

\[\begin{aligned} E_{X_i}\left[\mu^2\left(X_i ; \Pi\right)\right] & =E_{X_i}\left\{\left[E_S\left(\hat{\mu}\left(X_i ; S, \Pi\right)\right)\right]^2\right\} \\ & =E_{X_i}\left\{E_S\left[\hat{\mu}^2\left(X_i ; S, \Pi\right)\right]-\mathbb V_S\left[\hat{\mu}\left(X_i ; S, \Pi\right)\right]\right\} \\ & =E_{X_i}\left\{E_S\left[\hat{\mu}^2\left(X_i ; S, \Pi\right)\right]\right\}-E_{X_i}\left\{\mathbb V_S\left[\hat{\mu}\left(X_i ; S, \Pi\right)\right]\right\} \end{aligned} \]

第一項總體估計值的期望使用訓練集的樣本,即:

\[\hat \mu^2(X_i;S^{tr},\small \prod)=E_S[\hat\mu^2(X_i;S;\small\prod)] \]

第二項方差項,葉子節點方差求均值

\[\mathbb V_S[\hat\mu^2(X_i;S;\small\prod)]=\frac{S_{S^{tr}}^2}{N^{tr}} \]

對於最外層的期望:

\[\begin{aligned} \hat{E}_{X_i}\left[\mu^2\left(X_i ; \Pi\right)\right] & =\sum_{l \in \Pi} \frac{N_l^{t r}}{N^{t r}} \hat{\mu}^2\left(X_i ; S^{t r}, \Pi\right)-\sum_{l \in \Pi} \frac{N_l^{t r}}{N^{t r}} \frac{S_{t r}^2[\ell(x, \Pi)]}{N_\ell^{t r}} \\ & =\frac{1}{N^{t r}} \sum_{i \in S^{t r}} \hat{\mu}^2\left(X_i ; S^{t r}, \Pi\right)-\frac{1}{N^{t r}} \sum_{\ell \in \Pi} S_{t r}^2[\ell(x, \Pi)] \end{aligned} \]

方差項

\[\mathbb V(\hat\mu(X_i;S^{est}, \small\prod)=\frac{S_{S^{tr}}^{2}(\ell(x;\small \prod)}{N^{est}(\ell(x;\small \prod))} \]

\[E_{S^{est},X_i}[\mathbb{V}\left(\hat{\mu}^2\left(X_i ; \mathcal{S}^{\text {est }}, \Pi\right) \mid i \in \mathcal{S}^{\text {te }}\right] \equiv \frac{1}{N^{\text {est }}} \cdot \sum_{\ell \in \Pi} S_{\mathcal{S}^{\text {tr }}}^2(\ell) \]

整合

最終估計量為:

\[-\hat{EMSE(S^{tr},\small \prod)}=\frac{1}{N^{\operatorname{tr}}} \sum_{i \in \mathcal{S}^{\mathrm{tr}}} \hat{\mu}^2\left(X_i ; \mathcal{S}^{\operatorname{tr}}, \Pi\right)-\left(\frac{1}{N^{\operatorname{tr}}}+\frac{1}{N^{\mathrm{est}}}\right) \cdot \sum_{\ell \in \Pi} S_{\mathcal{S}^{\operatorname{tr}}}^2(\ell) \]

\[=\frac{1}{N^{\operatorname{tr}}} \sum_{i \in \mathcal{S}^{\operatorname{tr}}} \hat{\mu}^2\left(X_i ; \mathcal{S}^{\operatorname{tr}}, \Pi\right)-\frac{2}{N^{\operatorname{tr}}} \cdot \sum_{\ell \in \Pi} S_{\mathcal{S}^{\operatorname{tr}}}^2(\ell) \]

其中, 偏差和方差不過的est估計量應該用est set,但是此處假設了train set和est set 同分布。

treatment effect 介入劃分:處理異質效應

前面定義了MSE的正規化,當需要考慮到異質效應時,定義異質效應:

\[\tau = \mu(1, x;\small\prod)-\mu(0;x;\small \prod) \]

很顯然,我們永遠觀測不到異質性處理效應,因為我們無法觀測到反事實,我們只能夠估計處理效應,給出異質性處理效應的估計量:

\[\hat \tau(w,x;S,\small \prod)=\hat \tau(1,s;S,\small \prod)-\hat \tau(0,s;S,\small \prod) \]

因果效應下的EMSE為:

\[MSE_{\tau}=\frac{1}{\#(S^{te})}\sum_{i\in S^{te}} \{(\tau_i-\hat \tau(Xi;S^{est},\small \prod))^2-\tau_i^2\} \]

\[-\operatorname{EMSE}_\tau(\Pi)=\mathbb{E}_{X_i}\left[\tau^2\left(X_i ; \Pi\right)\right]-\mathbb{E}_{\mathcal{S}^{\text {est }}, X_i}\left[\mathbb{V}\left(\hat{\tau}^2\left(X_i ; \mathcal{S}^{\text {est }}, \Pi\right)\right]\right. \]

使用\(\tau\)替代了\(\mu\) , 偏差項, 帶入整合公式:

\[-\hat {EMSE_{\tau}(S^{tr, \small\prod})}=\frac{1}{N^{tr}}\sum_{i\in S^{tr}} \hat \tau^2(Xi;S^{tr},\small \prod)-\frac{2}{N^{tr}}\sum_{\ell \in \small \prod}(\frac{S_{S_{treat}^{tr}}^2(\ell)}{p}+\frac{S_{S_{control}^{tr}}^2(\ell)}{1-p}) \]

其中,\(p\)表示相應treatment組的樣本佔比,該子式也是最終的計算節點分類標準的公式

有了節點劃分方式之後,build tree的過程和CART樹是一樣的

推理過程

推理過程和決策樹基本一樣,樹建好之後,只用根據每個node儲存的特徵和threshold進行path 遍歷,走到葉子節點返回值即可。
一般來說,causal tree的葉子節點儲存的

結構體之外還單獨儲存了一個 叫 value 的陣列,主要是儲存每個節點的預測值。對於兩個Treatment來說,儲存的大小就是1x2的list,第一個element儲存了control的 正樣本比例,第二個element儲存了treatment的正樣本比例。
一般來說,這個比例會做配平或者說懲罰:

所以,最終推理得到的是一個輸入一個樣本X,得到T-C的treatment effect,我們不用像meta-learning類的模型一樣,自己手動減得到ITE。

Causal Tree總結

  1. 作者改進了MSE,主動減去了一項模型引數無關的\(E[Y_i^2]\)。改進方法的MSE包含了組內方差,這個方差越大,MSE就會越低,所以它能夠在一定程度上限制模型的複雜性
  2. 把改進的 mse loss apply 到CATE中來指導節點分割 和 建立決策樹
  3. 構建樹的過程中,train set切割為了 \(S^{tr}\)\(S^{est}\) 兩部分,node的預測值由\(S^{est}\) 進行無偏估計,雖然最後實際上\(S^{est}\) 用train set替代了。
  4. 理論上causal tree 僅支援 兩個Treatment

    如果使用causalml的package,如果存在多個T,非C組都會被置為一個T

REF

  1. https://hwcoder.top/Uplift-1
  2. 工具: scikit-uplift
  3. Meta-learners for Estimating Heterogeneous Treatment Effects using Machine Learning
  4. Athey, Susan, and Guido Imbens. "Recursive partitioning for heterogeneous causal effects." Proceedings of the National Academy of Sciences 113.27 (2016): 7353-7360.
  5. https://zhuanlan.zhihu.com/p/115223013

相關文章