萬事通,專精部分領域的多功能 Transformer 智慧體

HuggingFace發表於2024-05-13

介紹

我們很高興分享“萬事通”(Jack of All Trades,簡稱 JAT) 專案,該專案旨在朝著通用智慧體的方向發展。該專案最初是作為對 Gato (Reed 等,2022 年) 工作的公開復現啟動的,Gato 提出訓練一種能夠執行視覺與語言以及決策任務的 Transformer。於是我們首先構建了 Gato 資料集的開放版本。隨後,我們在此基礎上訓練了多模態 Transformer 模型,並針對處理順序資料和連續值引入了若干改進。

總體而言,該專案取得了以下成果:

  • 釋出了大量在各種任務上表現優異的 專家 RL 智慧體
  • 釋出了 JAT 資料集,這是第一個用於通用智慧體訓練的資料集。它包含了由專家智慧體收集的數十萬條專家軌跡。
  • 釋出了 JAT 模型,這是一種基於 Transformer 的智慧體,能夠玩電子遊戲、控制機器人執行各種任務、理解並在簡單的導航環境中執行命令等!

資料集和專家策略

專家策略

傳統的強化學習 (RL) 涉及在單一環境中訓練策略。利用這些專家策略是構建多功能智慧體的有效方法。我們選擇了各種性質和難度不同的環境,包括 Atari、BabyAI、Meta-World 和 MuJoCo。在每個環境中,我們訓練一個智慧體,直到它達到最先進的效能水平。(對於 BabyAI,我們使用的是 BabyAI bot)。這些訓練結果被稱為專家智慧體,並已在🤗 Hub 上釋出。您可以在 JAT 資料集卡 中找到所有智慧體的列表。

JAT 資料集

我們釋出了 JAT 資料集,這是第一個用於通用智慧體訓練的資料集。JAT 資料集包含由上述專家智慧體收集的數十萬條專家軌跡。要使用此資料集,只需像從🤗 Hub 載入任何其他資料集一樣載入它:

>>> from datasets import load_dataset
>>> dataset = load_dataset("jat-project/jat-dataset", "metaworld-assembly")
>>> first_episode = dataset["train"][0]
>>> first_episode.keys()
dict_keys(['continuous_observations', 'continuous_actions', 'rewards'])
>>> len(first_episode["rewards"])
500
>>> first_episode["continuous_actions"][0]
[6.459120273590088, 2.2422609329223633, -5.914587020874023, -19.799840927124023]

除了強化學習 (RL) 資料,我們還包含了文字資料集,以為使用者提供獨特的介面。因此,您還會發現 WikipediaOscarOK-VQAConceptual-Captions 的子集。

JAT 智慧體架構

JAT 的架構基於 Transformer,使用了 EleutherAI 的 GPT-Neo 實現。JAT 的特別之處在於其嵌入機制,該機制專門用於內在地處理順序決策任務。我們將觀測嵌入與動作嵌入交錯排列,並結合相應的獎勵。

JAT 網路的架構。在順序決策任務中,一方面將觀測和獎勵編碼,另一方面將動作編碼並交錯排列。模型使用因果掩碼自迴歸地生成下一個嵌入,並根據預期的模態進行解碼。

每個嵌入因此對應於一個觀測 (與獎勵相關聯) 或一個動作。那麼 JAT 是如何編碼這些資訊的呢?這取決於資料的型別。如果資料 (觀測或動作) 是影像 (如在 Atari 中的情況),那麼 JAT 使用 CNN。如果是連續向量,則 JAT 使用線性層。最後,如果是離散值,JAT 使用線性投影層。同樣的原理也用於模型輸出,具體取決於要預測的資料型別。預測是因果的,將觀測值移位一個時間步長。透過這種方式,智慧體必須根據所有先前的觀測和動作來預測下一個動作。

此外,我們認為讓我們的智慧體執行 NLP 和 CV 任務會很有趣。為此,我們還讓編碼器可以選擇將文字和影像資料作為輸入。對於文字資料,我們使用 GPT-2 的標記化策略,對於影像,我們使用 ViT 型別的編碼器。

考慮到資料的模態可能因環境而異,JAT 如何計算損失呢?它分別計算每種模態的損失。對於影像和連續值,它使用 MSE 損失。對於離散值,它使用交叉熵損失。最終損失是序列中每個元素損失的平均值。 等等,這是否意味著我們對預測動作和觀測賦予了相等的權重?實際上並不是這樣,但我們將在 下文 中詳細討論。

實驗與結果

我們在所有 157 個訓練任務上評估 JAT。我們收集了 10 個回合的資料並記錄總獎勵。為了便於閱讀,我們按領域彙總結果。

每個 RL 領域的彙總專家標準化得分及其 95%置信區間 (CI),作為學習步數的函式。

如果要用一個數字來總結這些結果,那就是 65.8%,這是在 4 個領域中相對於 JAT 專家的平均表現。這表明 JAT 能夠在各種任務中模仿專家的表現。讓我們更詳細地看看:

  • 對於 Atari 57,智慧體達到了專家得分的 14.1%,相當於人類表現的 37.6%。在 21 個遊戲中超過了人類表現。
  • 對於 BabyAI,智慧體達到了專家得分的 99.0%,僅在 1 個任務上未能超過專家得分的 50%。
  • 對於 Meta-World,智慧體達到了專家得分的 65.5%。
  • 對於 MuJoCo,智慧體達到了專家得分的 84.8%。

JAT 智慧體在 Atari 57 基準測試中的人類標準化得分。

最令人印象深刻的是,JAT 在所有領域中使用 單一網路 實現了這一效能。為了衡量這一效能,讓我們來看看 JAT 在一些任務中的渲染效果:

想試試嗎?你可以的!JAT 模型 已在 🤗 Hub 上提供!

我們的模型顯示了初步的文字任務處理能力,詳情請參閱 論文

預測觀測值的驚人好處

在訓練 RL 智慧體時,主要目標是最大化未來獎勵。但是,如果我們還要求智慧體預測它將來會觀測到的內容,這個額外的任務會幫助還是妨礙學習過程呢?

對於這個問題有兩種對立的觀點。一方面,學習預測觀測值可以提供對環境更深入的理解,從而導致更好更快的學習。另一方面,這可能會使智慧體偏離其主要目標,導致在觀測和動作預測方面的表現平平。

為了解決這一爭論,我們進行了一個實驗,使用了一個結合觀測損失和動作損失的損失函式,並用一個加權引數 ( \kappa ) 來平衡這兩個目標。

對於所選任務的觀測預測學習影響研究的彙總度量及 95%置信區間 (CI)。結果覆蓋所選的 \( k \) 值範圍,並基於每個任務 100 次評估。選擇最佳的 \( k \) 值可以顯著提高智慧體的效能。

結果非常顯著。當 \( k \) 值過高 (0.5) 時,預測觀測的額外目標似乎阻礙了學習過程。但是,當 \( k \) 值較低時,對學習的影響可以忽略不計,智慧體的表現與不將預測觀測作為目標時相似。

然而,我們發現 \( k = 0.005 \) 左右是一個最佳點,此時學習預測觀測實際上提高了智慧體的學習效率。 我們的研究表明,只要平衡得當,將預測觀測新增到學習過程中是有益的。這一發現對這類智慧體的設計有重要意義,強調了輔助目標在提高學習效率方面的潛在價值。

所以,下次訓練 RL 智慧體時,可以考慮讓它預測將來會觀測到的內容。這可能會帶來更好的表現和更快的學習速度!

結論

在這項工作中,我們介紹了 JAT,一個能夠掌握各種順序決策任務並在 NLP 和 CV 任務中表現出初步能力的多用途 Transformer 智慧體。對於所有這些任務,JAT 都使用單一網路。我們的貢獻包括髮布專家級 RL 智慧體、JAT 資料集和 JAT 模型。我們希望這項工作能夠激發未來在通用智慧體領域的研究,並有助於開發更多功能和更強大的 AI 系統。

下一步是什麼?研究請求

我們相信,JAT 專案為通用智慧體領域的研究開闢了新的方向,而我們只是剛剛開始。以下是一些未來工作的想法:

  • 改進資料: 儘管具有開創性,JAT 資料集仍處於初期階段。專家軌跡僅來自每個環境中的一個專家智慧體,這可能會導致一些偏差。儘管我們盡力達到了最先進的效能,但有些環境仍然具有挑戰性。我們相信,收集更多的資料和訓練更多的專家智慧體將會 大有幫助
  • 使用離線 RL: JAT 智慧體是使用基本的行為克隆訓練的。這意味著兩件事: (1) 我們無法利用次優軌跡,(2) JAT 智慧體不能超過專家的表現。我們選擇這種方法是為了簡單,但我們相信使用離線 RL 可以 大大提高 智慧體的效能,同時實現起來也不會太複雜。
  • 釋放更聰明的多工取樣策略的全部潛力: 目前,JAT 智慧體從所有任務中均勻取樣資料,但這種方法可能會限制其表現。透過動態調整取樣率以集中於最具挑戰性的任務,我們可以加速智慧體的學習過程並釋放 顯著的效能提升

相關連結

  • 📄 論文
  • 💻 原始碼
  • 🗂️ JAT 資料集
  • 🤖 JAT 模型

引文

@article{gallouedec2024jack,
    title = {{Jack of All Trades, Master of Some, a Multi-Purpose Transformer Agent}},
    author = {Gallouédec, Quentin and Beeching, Edward and Romac, Clément and Dellandréa, Emmanuel},
    journal = {arXiv preprint arXiv:2402.09844},
    year = {2024},
    url = {https://arxiv.org/abs/2402.09844}
}

英文原文: https://hf.co/blog/jat

原文作者: Quentin Gallouédec, Edward Beeching, Clément ROMAC, Thomas Wolf

譯者: xiaodouzi

相關文章