爆火後反轉?「一夜幹掉MLP」的KAN:其實我也是MLP

机器之心發表於2024-05-07

KAN 作者:我想傳達的資訊不是「KAN 很棒」,而是「嘗試批判性地思考當前的架構,並尋求從根本上不同的替代方案,這些方案可以完成有趣、有用的事情。」


多層感知器(MLP),也被稱為全連線前饋神經網路,是當今深度學習模型的基礎構建塊。MLP 的重要性無論怎樣強調都不為過,因為它們是機器學習中用於逼近非線性函式的預設方法。

但是最近,來自 MIT 等機構的研究者提出了一種非常有潛力的替代方法 ——KAN。該方法在準確性和可解釋性方面表現優於 MLP。而且,它能以非常少的引數量勝過以更大引數量執行的 MLP。比如,作者表示,他們用 KAN 重新發現了結理論中的數學規律,以更小的網路和更高的自動化程度重現了 DeepMind 的結果。具體來說,DeepMind 的 MLP 有大約 300000 個引數,而 KAN 只有大約 200 個引數

這些驚人的結果讓 KAN 迅速走紅,吸引了很多人對其展開研究。很快,有人提出了一些質疑。其中,一篇標題為《KAN is just MLP》的 Colab 文件成為了議論的焦點。

圖片

KAN 只是一個普通的 MLP?

上述文件的作者表示,你可以把 KAN 寫成一個 MLP,只要在 ReLU 之前加一些重複和移位。

在一個簡短的例子中,作者展示瞭如何將 KAN 網路改寫為具有相同數量引數的、有輕微的非典型結構的普通 MLP。

需要記住的是,KAN 在邊上有啟用函式。它們使用 B - 樣條。在展示的例子中,為了簡單起見,作者將只使用 piece-wise 線性函式。這不會改變網路的建模能力。

下面是 piece-wise 線性函式的一個例子:
def f(x):
  if x < 0:
    return -2*x
  if x < 1:
    return -0.5*x
  return 2*x - 2.5

X = torch.linspace(-2, 2, 100)
plt.plot(X, [f(x) for x in X])
plt.grid()

圖片

作者表示,我們可以使用多個 ReLU 和線性函式輕鬆重寫這個函式。請注意,有時需要移動 ReLU 的輸入。
plt.plot(X, -2*X + torch.relu(X)*1.5 + torch.relu(X-1)*2.5)
plt.grid()

圖片

真正的問題是如何將 KAN 層改寫成典型的 MLP 層。假設有 n 個輸入神經元,m 個輸出神經元,piece-wise 函式有 k 個 piece。這需要 n∗m∗k 個引數(每條邊有 k 個引數,而你有 n∗m 條邊)。

現在考慮一個 KAN 邊。為此,需要將輸入複製 k 次,每個副本移動一個常數,然後透過 ReLU 和線性層(第一層除外)執行。從圖形上看是這樣的(C 是常數,W 是權重):

圖片

現在,可以對每一條邊重複這一過程。但要注意一點,如果各處的 piece-wise 線性函式網格相同,我們就可以共享中間的 ReLU 輸出,只需在其上混合權重即可。就像這樣:

圖片

在 Pytorch 中,這可以翻譯成以下內容:
k = 3 # Grid size
inp_size = 5
out_size = 7
batch_size = 10
X = torch.randn(batch_size, inp_size) # Our input
linear = nn.Linear(inp_size*k, out_size)  # Weights
repeated = X.unsqueeze(1).repeat(1,k,1)
shifts = torch.linspace(-1, 1, k).reshape(1,k,1)
shifted = repeated + shifts
intermediate = torch.cat([shifted[:,:1,:], torch.relu(shifted[:,1:,:])], dim=1).flatten(1)
outputs = linear(intermediate)

現在我們的層看起來是這樣的:
  • Expand + shift + ReLU

  • Linear

一個接一個地考慮三個層:

  • Expand + shift + ReLU (第 1 層從這裡開始)
  • Linear
  • Expand + shift + ReLU (第 2 層從這裡開始)
  • Linear
  • Expand + shift + ReLU (第 3 層從這裡開始)
  • Linear

忽略輸入 expansion,我們可以重新排列:

  • Linear (第 1 層從這裡開始)
  • Expand + shift + ReLU
  • Linear (第 2 層從這裡開始)
  • Expand + shift + ReLU

如下的層基本上可以稱為 MLP。你也可以把線性層做大,去掉 expand 和 shift,獲得更好的建模能力(儘管需要付出更高的引數代價)。

  • Linear (第 2 層從這裡開始)
  • Expand + shift + ReLU

透過這個例子,作者表明,KAN 就是一種 MLP。這一說法引發了大家對兩類方法的重新思考。

圖片

對 KAN 思路、方法、結果的重新審視

其實,除了與 MLP 理不清的關係,KAN 還受到了其他許多方面的質疑。

總結下來,研究者們的討論主要集中在如下幾點。

第一,KAN 的主要貢獻在於可解釋性,而不在於擴充套件速度、準確性等部分。

論文作者曾經表示:

  1. KAN 的擴充套件速度比 MLP 更快。KAN 比引數較少的 MLP 具有更好的準確性。
  2. KAN 可以直觀地視覺化。KAN 提供了 MLP 無法提供的可解釋性和互動性。我們可以使用 KAN 潛在地發現新的科學定律。

其中,網路的可解釋性對於模型解決現實問題的重要性不言而喻:

圖片

但問題在於:「我認為他們的主張只是它學得更快並且具有可解釋性,而不是其他東西。如果 KAN 的引數比等效的 NN 少得多,則前者是有意義的。我仍然感覺訓練 KAN 非常不穩定。」

圖片

那麼 KAN 究竟能不能做到引數比等效的 NN 少很多呢?

這種說法目前還存在疑問。在論文中,KAN 的作者表示,他們僅用 200 個引數的 KAN,就能復現 DeepMind 用 30 萬引數的 MLP 發現數學定理研究。在看到該結果後,佐治亞理工副教授 Humphrey Shi 的兩位學生重新審視了 DeepMind 的實驗,發現只需 122 個引數DeepMind 的 MLP 就能媲美 KAN 81.6% 的準確率。而且,他們沒有對 DeepMind 程式碼進行任何重大修改。為了實現這個結果,他們只減小了網路大小,使用隨機種子,並增加了訓練時間。

圖片

圖片

對此,論文作者也給出了積極的回應:

圖片

第二,KAN 和 MLP 從方法上沒有本質不同。

圖片

「是的,這顯然是一回事。他們在 KAN 中先做啟用,然後再做線性組合,而在 MLP 中先做線性組合,然後再做啟用。將其放大,基本上就是一回事。據我所知,使用 KAN 的主要原因是可解釋性和符號迴歸。」

圖片

除了對方法的質疑之外,研究者還呼籲對這篇論文的評價迴歸理性:

「我認為人們需要停止將 KAN 論文視為深度學習基本單元的巨大轉變,而只是將其視為一篇關於深度學習可解釋性的好論文。在每條邊上學習到的非線性函式的可解釋性是這篇論文的主要貢獻。」

第三,有研究者表示,KAN 的思路並不新奇。

圖片

「人們在 20 世紀 80 年代對此進行了研究。Hacker News 的討論中提到了一篇義大利論文討論過這個問題。所以這根本不是什麼新鮮事。40 年過去了,這只是一些要麼回來了,要麼被拒絕的東西被重新審視的東西。」

但可以看到的是,KAN 論文的作者也沒有掩蓋這一問題。

「這些想法並不新鮮,但我不認為作者回避了這一點。他只是把所有東西都很好地打包起來,並對 toy 資料進行了一些很好的實驗。但這也是一種貢獻。」

與此同時,Ian Goodfellow、Yoshua Bengio 十多年前的論文 MaxOut(https://arxiv.org/pdf/1302.4389)也被提到,一些研究者認為二者「雖然略有不同,但想法有點相似」。

作者:最初研究目標確實是可解釋性

熱烈討論的結果就是,作者之一 Sachin Vaidya 站出來了。

圖片

作為該論文的作者之一,我想說幾句。KAN 受到的關注令人驚歎,而這種討論正是將新技術推向極限、找出哪些可行或不可行所需要的。

我想我應該分享一些關於動機的背景資料。我們實現 KAN 的主要想法源於我們正在尋找可解釋的人工智慧模型,這種模型可以「學習」物理學家發現自然規律的洞察力。因此,正如其他人所意識到的那樣,我們完全專注於這一目標,因為傳統的黑箱模型無法提供對科學基礎發現至關重要的見解。然後,我們透過與物理學和數學相關的例子表明,KAN 在可解釋性方面大大優於傳統方法。我們當然希望,KAN 的實用性將遠遠超出我們最初的動機。


在 GitHub 主頁中,論文作者之一劉子鳴也對這項研究受到的評價進行了回應:

最近我被問到的最常見的問題是 KAN 是否會成為下一代 LLM。我對此沒有很清楚的判斷。

KAN 專為關心高精度和可解釋性的應用程式而設計。我們確實關心 LLM 的可解釋性,但可解釋性對於 LLM 和科學來說可能意味著截然不同的事情。我們關心 LLM 的高精度嗎?縮放定律似乎意味著如此,但可能精度不太高。此外,對於 LLM 和科學來說,準確性也可能意味著不同的事情。

我歡迎人們批評 KAN,實踐是檢驗真理的唯一標準。很多事情我們事先並不知道,直到它們經過真正的嘗試並被證明是成功還是失敗。儘管我願意看到 KAN 的成功,但我同樣對 KAN 的失敗感到好奇。

KAN 和 MLP 不能相互替代,它們在某些情況下各有優勢,在某些情況下各有侷限性。我會對包含兩者的理論框架感興趣,甚至可以提出新的替代方案(物理學家喜歡統一理論,抱歉)。

圖片

KAN 論文一作劉子鳴。他是一名物理學家和機器學習研究員,目前是麻省理工學院和 IAIFI 的三年級博士生,導師是 Max Tegmark。他的研究興趣主要集中在人工智慧 AI 和物理的交叉領域。

參考連結:https://colab.research.google.com/drive/1v3AHz5J3gk-vu4biESubJdOsUheycJNz#scrollTo=WVDbcpBqAFop
https://github.com/KindXiaoming/pykan?tab=readme-ov-file#authors-note

相關文章