找到懂機器學習系統開發的人太難了。現今的學生都一心學習機器學習演算法,很多學生對於底層的運作原理理解得很淺。而當他們在實際中應用機器學習技術時才意識到系統的重要性,那時想去學習,卻沒有了充沛的學習時間。
導言
為了支援在不同應用中高效開發機器學習演算法,人們設計和實現了機器學習框架(如TensorFlow、PyTorch、MindSpore等)。廣義來說,這些框架實現了以下共性的設計目標:
- 神經網路程式設計: 深度學習的巨大成功使得神經網路成為了許多機器學習應用的核心。根據應用的需求,人們需要定製不同的神經網路,如卷積神經網路(Convolutional Neural Networks)和自注意力神經網路(Self-Attention Neural Networks)等。這些神經網路需要一個共同的系統軟體進行開發、訓練和部署。
- 自動微分: 訓練神經網路會具有模型引數。這些引數需要透過持續計算梯度(Gradients)迭代改進。梯度的計算往往需要結合訓練資料、資料標註和損失函式(Loss Function)。考慮到大多數開發人員並不具備手工計算梯度的知識,機器學習框架需要根據開發人員給出的神經網路程式,全自動地計算梯度。這一過程被稱之為自動微分。
- 資料管理和處理: 機器學習的核心是資料。這些資料包括訓練、驗證、測試資料集和模型引數。因此,需要系統本身支援資料讀取、儲存和預處理(例如資料增強和資料清洗)。
- 模型訓練和部署: 為了讓機器學習模型達到最佳的效能,需要使用最佳化方法(例如Mini-Batch SGD)來透過多步迭代反覆計算梯度,這一過程稱之為訓練。訓練完成後,需要將訓練好的模型部署到推理裝置。
- 硬體加速器: 神經網路的相關計算往往透過矩陣計算實現。這一類計算可以被硬體加速器(例如,通用圖形處理器-GPU)加速。因此,機器學習系統需要高效利用多種硬體加速器。
- 分散式執行: 隨著訓練資料量和神經網路引數量的上升,機器學習系統的記憶體用量遠遠超過了單個機器可以提供的記憶體。因此,機器學習框架需要天然具備分散式執行的能力。
一個完整的機器學習框架一般具有如圖所示的基本架構
- 程式設計介面: 考慮到機器學習開發人員背景的多樣性,機器學習框架首先需要提供以高層次程式語言(如Python)為主的程式設計介面。同時,機器學習框架為了最佳化執行效能,需要支援以低層次程式語言(如C和C++)為主的系統實現,從而實現作業系統(如執行緒管理和網路通訊等)和各型別硬體加速器的高效使用。
- 計算圖: 利用不同程式設計介面實現的機器學習程式需要共享一個執行後端。實現這一後端的關鍵技術是計算圖技術。計算圖定義了使用者的機器學習程式,其包含大量表達計算操作的運算元節點(Operator Node),以及表達運算元之間計算依賴的邊(Edge)。
- 編譯器前端: 機器學習框架往往具有AI編譯器來構建計算圖,並將計算圖轉換為硬體可以執行的程式。這個編譯器首先會利用一系列編譯器前端技術實現對程式的分析和最佳化。編譯器前端的關鍵功能包括實現中間表示、自動微分、型別推導和靜態分析等。
- 編譯器後端和執行時: 完成計算圖的分析和最佳化後,機器學習框架進一步利用編譯器後端和執行時實現針對不同底層硬體的最佳化。常見的最佳化技術包括分析硬體的L2/L3快取大小和指令流水線長度,最佳化運算元的選擇或者排程順序。
- 異構處理器: 機器學習應用的執行由中央處理器(Central Processing Unit,CPU)和硬體加速器(如英偉達GPU、華為Ascend和谷歌TPU)共同完成。其中,非矩陣操作(如複雜的資料預處理和計算圖的排程執行)由中央處理器完成。矩陣操作和部分頻繁使用的機器學習運算元(如Transformer運算元和Convolution運算元)由硬體加速器完成。
- 資料處理: 機器學習應用需要對原始資料進行復雜預處理,同時也需要管理大量的訓練資料集、驗證資料集和測試資料集。這一系列以資料為核心的操作由資料處理模組(例如TensorFlow的tf.data和PyTorch的DataLoader)完成。
- 模型部署: 在完成模型訓練後,機器學習框架下一個需要支援的關鍵功能是模型部署。為了確保模型可以在記憶體有限的硬體上執行,會使用模型轉換、量化、蒸餾等模型壓縮技術。同時,也需要實現針對推理硬體平臺(例如英偉達Orin)的模型運算元最佳化。最後,為了保證模型的安全(如拒絕未經授權的使用者讀取),還會對模型進行混淆設計。
- 分散式訓練: 機器學習模型的訓練往往需要分散式的計算節點並行完成。其中,常見的並行訓練方法包括資料並行、模型並行、混合並行和流水線並行。這些並行訓練方法通常由遠端程式呼叫(Remote Procedure Call, RPC)、集合通訊(Collective Communication)或者引數伺服器(Parameter Server)實現。