讓數百萬臺手機訓練同一個模型?Google把這套框架開源了

AI科技大本營發表於2019-03-09

640?wx_fmt=jpeg


作者 | 琥珀

出品 | AI科技大本營(公眾號id:rgznai100)


【導語】據瞭解,全球有 30 億臺智慧手機和 70 億臺邊緣裝置。每天,這些電話與裝置之間的互動不斷產生新的資料。傳統的資料分析和機器學習模式,都需要在處理資料之前集中收集資料至伺服器,然後進行機器學習訓練並得到模型引數,最終獲得更好的產品。


但如果這些需要聚合的資料敏感且昂貴的話,那麼這種中心化的資料收集手段可能就不太適用了。


去掉這一步驟,直接在生成資料的邊緣裝置上進行資料分析和機器學習訓練呢?


近日,Google 開源了一款名為 TensorFlow Federated (TFF)的框架,可用於去中心化(decentralized)資料的機器學習及運算實驗。它實現了一種稱為聯邦學習(Federated Learning,FL)的方法,將為開發者提供分散式機器學習,以便在沒有資料離開裝置的情況下,便可在多種裝置上訓練共享的 ML 模型。其中,通過加密方式提供多一層的隱私保護,並且裝置上模型訓練的權重與用於連續學習的中心模型共享。


傳送門:https://www.tensorflow.org/federated/


實際上,早在 2017 年 4 月,Google AI 團隊就推出了聯邦學習的概念。這種被稱為聯邦學習的框架目前已應用在 Google 內部用於訓練神經網路模型,例如智慧手機中虛擬鍵盤的下一詞預測和音樂識別搜尋功能。


640?wx_fmt=png

640?wx_fmt=png

圖注:每臺手機都在本地訓練模型(A);將使用者更新資訊聚合(B);然後形成改進的共享模型(C)。


DeepMind 研究員 Andrew Trask 隨後發推稱讚:“Google 已經開源了 Federated Learning……可在數以百萬計的智慧手機上共享模型訓練!”


640?wx_fmt=png


讓我們一起來看看使用教程:


從一個著名的影像資料集 MNIST 開始。MNIST 的原始資料集為 NIST,其中包含 81 萬張手寫的數字,由 3600 個志願者提供,目標是建立一個識別數字的 ML 模型。


傳統手段是立即將 ML 演算法應用於整個資料集。但實際上,如果資料提供者不願意將原始資料上傳到中央伺服器,就無法將所有資料聚合在一起。


TFF 的優勢就在於,可以先選擇一個 ML 模型架構,然後輸入資料進行訓練,同時保持每個資料提供者的資料是獨立且儲存在本地。


下面顯示的是通過呼叫 TFF 的 FL API,使用已由 GitHub 上的“Leaf”專案處理的 NIST 資料集版本來分隔每個資料提供者所寫的數字:


GitHub 傳送連結:https://github.com/TalwalkarLab/leaf



# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
  dataset = source.create_tf_dataset_for_client(source.client_ids[n])
  return mnist.keras_dataset_from_emnist(dataset).repeat(10).batch(20)

# Wrap a Keras model for use with TFF.
def model_fn():
  return tff.learning.from_compiled_keras_model(
      mnist.create_simple_keras_model(), sample_batch)

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()
for _ in range(5):
  state, metrics = trainer.next(state, train_data)
  print (metrics.loss)


除了可呼叫 FL API 外,TFF 還帶有一組較低階的原語(primitive),稱之為 Federated Core (FC) API。這個 API 支援在去中心化的資料集上表達各種計算。


使用 FL 進行機器學習模型訓練僅是第一步;其次,我們還需要對這些資料進行評估,這時就需要 FC API 了。


假設我們有一系列感測器可用於捕獲溫度讀數,並希望無需上傳資料便可計算除這些感測器上的平均溫度。呼叫 FC 的 API,就可以表達一種新的資料型別,例如指出 tf.float32,該資料位於分散式的客戶端上。


READINGS_TYPE = tff.FederatedType(tf.float32, tff.CLIENTS)


然後在該型別的資料上定義聯邦平均數。


@tff.federated_computation(READINGS_TYPE)
def get_average_temperature(sensor_readings):
  return tff.federated_average(sensor_readings)


之後,TFF 就可以在去中心化的資料環境中執行。從開發者的角度來講,FL 演算法可以看做是一個普通的函式,它恰好具有駐留在不同位置(分別在各個客戶端和協調服務中的)輸入和輸出。


640?wx_fmt=png


例如,使用了 TFF 之後,聯邦平均演算法的一種變體:


參考連結:https://arxiv.org/abs/1602.05629


@tff.federated_computation(
  tff.FederatedType(DATASET_TYPE, tff.CLIENTS),
  tff.FederatedType(MODEL_TYPE, tff.SERVER, all_equal=True),
  tff.FederatedType(tf.float32, tff.SERVER, all_equal=True))
def federated_train(client_data, server_model, learning_rate):
  return tff.federated_average(
      tff.federated_map(local_train, [
          client_data,
          tff.federated_broadcast(server_model),
          tff.federated_broadcast(learning_rate)]))


目前已開放教程,可以先在模型上試驗現有的 FL 演算法,也可以為 TFF 庫提供新的聯邦資料集和模型,還可以新增新的 FL 演算法實現,或者擴充套件現有 FL 演算法的新功能。


據瞭解,在 FL 推出之前,Google 還推出了 TensorFlow Privacy,一個機器學習框架庫,旨在讓開發者更容易訓練具有強大隱私保障的 AI 模型。目前二者可以整合,在差異性保護使用者隱私的基礎上,還能通過聯邦學習(FL)技術快速訓練模型。


最後附上 TF Dev Summit’19 上,TensorFlow Federated (TFF)的釋出會現場視訊:



參考連結:https://medium.com/tensorflow/introducing-tensorflow-federated-a4147aa20041


(本文為 AI科技大本營原創文章,轉載請微信聯絡 1092722531


4 月13日-4 月14日,CSDN 將在北京主辦“Python 開發者日( 2019 )”,匯聚十餘位來自阿里巴巴IBM英偉達等國內外一線科技公司的Python技術專家,還有數百位來自各行業領域的Python開發者。目前購票通道已開啟,早鳥票限量發售中,3 月15日之前可享受優惠價 299 元(售完即止)。


640?wx_fmt=jpeg

推薦閱讀:

                         640?wx_fmt=png

點選“閱讀原文”,檢視歷史精彩文章。

相關文章