機器學習中的新數學,加速AI訓練離不開數字表示方式和基本計算的變革

機器之心發表於2022-11-27


圖片

近年來 AI 領域的發展令人震驚,但為完成這些壯舉而訓練神經網路的成本也異常巨大。以大規模語言模型 GPT-3 和藝術生成器 DALL-E 2 為例,它們需要在高效能 GPU 叢集上訓練數月時間,耗資數百萬美元,消耗百萬億計的基本計算。

同時,處理單元的訓練能力一直在快速增長,僅 2021 年就翻了一番。為了保持這一趨勢,研究人員正在深入研究最基礎的計算構建塊,即計算機表示數字的方式。

在上個月舉辦的第 29 屆 IEEE 計算機算術研討會(IEEE Symposium on Computer Arithmetic)的一場 Keynote 演講中,英偉達首席科學家、高階研究副總裁 Bill Dally 表示,「過去 10 年,單個晶片的訓練效能提升了 1000 倍,其中很大部分要歸功於數字表示。」

在朝著更高效 AI 訓練前進的過程中,首先「犧牲」的是 32-bit 浮點數表示,俗稱標準精度。為了全面追求速度、能效以及晶片面積和記憶體的更好利用,機器學習研究人員一直努力透過更少 bit 表示的數字來獲得相同的訓練水平。對於試圖取代 32-bit 格式的競爭者來說,這個領域依然很開放,無論是在數字表示本身還是完成基礎運算的方式上。

英偉達每向量縮放量化方案(VSQ)

我們知道,影像生成神器 DALL-E 在英偉達 A100 GPU 叢集上接受了標準 32-bit 數字和低精度 16-bit 數字的組合訓練。Hopper GPU 更是支援了更小的 8-bit 浮點數。最近,英偉達在一項研究中開發了一個原型晶片,透過使用 8-bit 和 4-bit 數字的組合更進一步推動了這一趨勢。

圖片

論文地址:https://ieeexplore.ieee.org/document/9830277

儘管使用了更低精確的數字,但該晶片努力保持計算準確率,至少在訓練過程中的推理部分是這樣。推理是在經過充分訓練的模型上執行以獲得輸出,但在訓練期間也會重複進行。Bill Dally 表示,「我們最終以 4-bit 精度得到 8-bit 結果。」

圖片

英偉達的每向量縮放方案比 INT4 等標準格式更好地表示機器學習中需要的數字。

得益於這種方案,英偉達能夠在沒有顯著準確率損失的情況下減少數字大小。基本理念是這樣的:一個 4-bit 數字只能精確表示 16 個值。因此,每個數字都會四捨五入到這 16 個值的其中一個。這種舍入導致的準確率損失被稱為量化誤差。

但是,你可以新增一個縮放因子在數軸上將 16 個值均勻地壓縮在一起或將它們拉得更遠,從而減少或增加量化誤差。

所以訣竅在於壓縮或擴充套件這 16 個值,這樣它們就能與你在神經網路中實際需要表示的數字範圍形成最佳匹配。這種縮放對於不同的資料集也是不同的。透過為神經網路模型中每個包含 64 個數字的集合微調這種縮放引數,英偉達的研究者能夠最大限度地減少量化誤差。他們還發現,計算縮放因子的開銷也可以忽略不計。但隨著 8-bit 表示減少至 4-bit,能效翻了一番。

實驗晶片仍在開發當中,英偉達工程師也在努力研究如何在整個訓練流程而不是僅在推理中利用這些原理。Dally 表示,如果成功,結合了 4-bit 計算、VSQ 和其他效率改進的晶片可以在每瓦特運算次數上達到 Hopper GPU 的 10 倍。

一種新的數字格式——Posits

早在 2017 年,美國電腦科學家 John Gustafson 和 J. Craig Venter 研究所助理研究員 Isaac Yonemoto 開發出了一種全新的數字表示方式—— posit。

圖片

論文地址:http://www.johngustafson.net/pdfs/BeatingFloatingPoint.pdf

現在,馬德里康普頓斯大學的一組研究人員開發了首個在硬體中實現 posit 標準的處理器核心,並表示與使用標準浮點數的計算相比,基本計算任務的準確率最高可以提升四個量級。

圖片

論文地址:https://ieeexplore.ieee.org/document/9817027/references#references

posits 的優勢在於它們沿著數軸分佈來精確表示數字。在數軸的中間,大約在 1 和–1 附近,存在比浮點更多的 posit 表示。在兩端,對於大的負數和正數,posit 準確率比浮點下降得更優雅。

Gustafson 表示,「posits 更適合計算中數字的自然分佈。其實,浮點運算中有大量的 bit 模式,不過沒有人使用過。這是一種浪費。」

Posits 在 1 和 -1 附近提升了準確率,這要得益於它們的表示中存在一個額外的元件。浮點數由三部分組成:一個符號位(0 表示正,1 表示負)、幾個尾數(分數)位以表示二進位制小數點之後的內容以及定義指數(2^exp)的其餘位。

Posits 保留了浮點數的所有元件,但新增了一個額外的「regime」部分,即指數的指數。這個 regime 的奇妙之處在於它的 bit 長度可以變化。對於小數,regime 可能只需要 2 個 bit,為尾數提供更高精度。這允許在 1 和 - 1 附近的最佳位置實現更高的準確率。

圖片

透過新增一個額外的可變長度機制,在零附近的數字將具有更好的準確率,神經網路中使用的大多數數字都在該位置。

藉助在 FPGA 上合成的新硬體實現,Complutense 團隊能夠比較使用 32 位浮點數和 32 位 posits 完成計算的效果。他們透過將其與使用更準確但計算成本更高的 64 位浮點格式的結果進行比較來評估準確率。posits 在矩陣乘法的準確率方面驚人地提升了四個數量級。他們還發現,提高精度並沒有以計算時間為代價,只是稍微增加了晶片的面積和功耗。

降低 RISC-V 的數學風險

一個來自瑞士和義大利的研究團隊曾開發一種減少 bit 的方案,適用於使用開源 RISC-V 指令集架構的處理器, 並推進了新處理器的開發。該團隊對 RISC-V 指令集的擴充套件包括一個有效的計算版本,它混合了較低和較高精度的數字表示。憑藉改進的混合精度數學,他們在訓練神經網路所涉及的基本計算中獲得了兩倍的加速。

降低精度在基本操作期間不僅會因 bit 減少導致精度損失,還會產生連鎖反應。將兩個低精度數字相乘可能會導致數字太小或太大而無法表示給定的 bit 長度——分別稱為下溢和上溢;另外,將一個大的低精度數和一個小的低精度數相加時,會發生 swamping 現象,導致較小的數字完全丟失。

混合精度對於改善上溢、下溢和 swamping 問題具有重要作用,其中使用低精度輸入執行計算併產生更高精度的輸出,在舍入到較低精度之前完成一批數學運算。

點積是人工智慧計算的一個基本組成部分,它通常透過一系列稱為融合乘加單元 (FMA) 的元件在硬體中實現。它們一次性執行操作 d = a*b + c,最後只進行四捨五入。為了獲得混合精度的好處,輸入 a 和 b 是低精度(例如 8 bits),而 c 和輸出 d 是高精度(例如 16 bits)。

IEEE Fellow Luca Benini 等人認為:與其一次只做一個 FMA 操作,不如同時做兩個並在最後將它們加在一起。這不僅可以防止由於兩個 FMA 之間的舍入而造成的損失,而且還可以更好地利用記憶體,因為這樣就不需要有記憶體暫存器等待前一個 FMA 完成。

Luca Benini 領導的小組設計並模擬了並行混合精度點積單元,發現向量的點積計算時間幾乎減少了一半,並且輸出精度提高了。他們目前正在構建新的硬體架構,以證明模擬的預測。

更多詳細內容請參閱原文連結:https://spectrum.ieee.org/number-representation

相關文章