分散式機器學習中的模型聚合

orion發表於2021-12-02

論文[1]在聯邦(分散式)學習的情景下引入了多工學習,其採用的手段是使每個client/task節點的訓練資料分佈不同,從而使各任務節點學習到不同的模型,且每個任務節點以及全域性(global)的模型都由多個分量模型整合。該論文最關鍵與核心的地方在於將各任務節點學習到的模型進行聚合/通訊,依據模型聚合方式的不同,可以將模型採用的演算法分為client-server方法,和fully decentralized(完全去中心化)的方法(其實還有其他的聚合方法沒,如論文[3]提出的簇狀聚合方法,程式碼參見[4]我們這裡暫時略過),其中這兩種方法在具體實現上都可以替換為對代理損失函式的優化,不過我們這裡暫時略過。

因為有多種任務聚合器(Aggregator)要實現,論文程式碼(已開源在Github上,參見[2])採取的措施是先實現Aggregator抽象基類,實現好一些通用方法,並規定好抽象方法的介面,然後具體的任務聚合類繼承抽象基類,然後做具體的實現。

我們先來看任務聚合器(Aggregator)這一抽象基類

class Aggregator(ABC):
    r"""Aggregator的基類. `Aggregator`規定了client之間的通訊"""
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=False,
            test_clients=None,
            verbose=0,
            seed=None,
            *args,
            **kwargs
    ):

        rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
        self.rng = random.Random(rng_seed) # 隨機數生成器
        self.np_rng = np.random.default_rng(rng_seed) # numpy隨機數生成器

        if test_clients is None:
            test_clients = []

        self.clients = clients #  List[Client]
        self.test_clients = test_clients #  List[Client]

        self.global_learners_ensemble = global_learners_ensemble # List[Learner]
        self.device = self.global_learners_ensemble.device


        self.log_freq = log_freq
        self.verbose = verbose
        # verbose: 調整輸出列印的冗餘度(verbosity), 
        # `0` 表示quiet(無任何列印輸出), `1` 顯示日誌, `2` 顯示所有區域性日誌; 預設是 `0`
        self.global_train_logger = global_train_logger
        self.global_test_logger = global_test_logger

        self.model_dim = self.global_learners_ensemble.model_dim # #模型特徵維度

        self.n_clients = len(clients)
        self.n_test_clients = len(test_clients)
        self.n_learners = len(self.global_learners_ensemble)

        # 儲存為每個client分配的權重(權重為0-1之間的小數)
        self.clients_weights =\
            torch.tensor(
                [client.n_train_samples for client in self.clients],
                dtype=torch.float32
            )
        self.clients_weights = self.clients_weights / self.clients_weights.sum()

        self.sampling_rate = sampling_rate  #  clients在每一輪使用的比例,預設為`1.`
        self.sample_with_replacement = sample_with_replacement #對client進行採用是可重複還是無重複的,with_replacement=True表示可重複的,否則是不可重複的

        # 每輪迭代需要使用到的client個數
        self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))

        # 取樣得到的client列表
        self.sampled_clients = list()

        # 記載當前的迭代通訊輪數
        self.c_round = 0 
        self.write_logs()

    @abstractmethod
    def mix(self): 
        """
        該方法用於完成各client之間的權重引數與通訊操作
        """
        pass

    @abstractmethod
    def update_clients(self): 
        """
        該方法用於將所有全域性分量模型拷貝到各個client,相當於boardcast操作
        """
        pass

    def update_test_clients(self):
        """
        將全域性(gobal)的所有分量模型都拷貝到各個client上
        """

    def write_logs(self):
        """
        對全域性(global)的train和test資料集的loss和acc做記錄
        需要對所有client的所有樣本做累加,然後除以所有client的樣本總數做平均。
        """

    def save_state(self, dir_path):
        """
        儲存aggregator的模型state,。例如, `global_learners_ensemble`中每個分量模型'learner'的state字典(以`.pt`檔案格式),以及`self.clients` 中每個client的 `learners_weights` (注意,這個權重不是模型內部的引數,而是進行繼承的時候對各個分量模型賦予的權重,包含train和test兩部分,以一個大小為n_clients(n_test_clients)× n_learners的numpy陣列的格式,即`.npy` 檔案)。
        """

    def load_state(self, dir_path):
        """
        載入aggregator的模型state,即save_state方法裡儲存的那些
        """

    def sample_clients(self):
        """
        對clients進行取樣,
        如果self.sample_with_replacement為True,則為可重複取樣,
        否則,則為不可重複採用。
        最終得到一個clients子集列表並賦予self.sampled_clients
        """

1.client-server 演算法

這種方式的通訊/聚合方法也稱中心化(centralized)方法,因為該方法在每一輪迭代最後將所有client的權重資料彙集到server節點。這種方法的優化迭代部分的虛擬碼示意如下:
CV多工學習

落實到具體程式碼實現上,這種方法的Aggregator設計如下:

class CentralizedAggregator(Aggregator):
    r""" 標準的中心化Aggreagator
    所有clients在每一輪迭代末和average client完全同步.
    """
    def mix(self):
        self.sample_clients()

        # 對self.sampled_clients中每個client的引數進行優化
        for client in self.sampled_clients:
            # 相當於虛擬碼第11行呼叫的LocalSolver函式
            client.step()

        # 遍歷global模型(self.global_learners_ensemble) 中每一個分量模型(learner)
        # 相當於虛擬碼第13行
        for learner_id, learner in enumerate(self.global_learners_ensemble):
            # 獲取所有client中對應learner_id的分量模型
            learners = [client.learners_ensemble[learner_id] for client in self.clients]
            # global模型的分量模型為所有client對應分量模型取平均,相當於虛擬碼第14行
            average_learners(learners, learner, weights=self.clients_weights)

        # 將更新後的模型賦予所有clients,相當於虛擬碼第5行的boardcast操作
        self.update_clients()

        # 通訊輪數+1
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

    def update_clients(self):
        """
        此函式負責將所有全域性分量模型拷貝到各個client,相當於虛擬碼中第5行的boardcast操作
        """
        for client in self.clients:
            for learner_id, learner in enumerate(client.learners_ensemble):
                copy_model(learner.model, self.global_learners_ensemble[learner_id].model)

                if callable(getattr(learner.optimizer, "set_initial_params", None)):
                    learner.optimizer.set_initial_params(
                        self.global_learners_ensemble[learner_id].model.parameters()
                    )

2. fully decentralized(完全去中心化)演算法

這種方法之所以被稱為去中心化的,因為該方法在每一輪迭代不需要所有client的權重資料彙集到一個特定的server節點,而只需要完成每個節點和其鄰居進行通訊(引數共享)即可。這種方法的優化迭代部分的虛擬碼示意如下:
CV多工學習
落實到具體程式碼實現上,這種方法的Aggregator設計如下:

class DecentralizedAggregator(Aggregator):
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            mixing_matrix,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=True,
            test_clients=None,
            verbose=0,
            seed=None):

        super(DecentralizedAggregator, self).__init__(
            clients=clients,
            global_learners_ensemble=global_learners_ensemble,
            log_freq=log_freq,
            global_train_logger=global_train_logger,
            global_test_logger=global_test_logger,
            sampling_rate=sampling_rate,
            sample_with_replacement=sample_with_replacement,
            test_clients=test_clients,
            verbose=verbose,
            seed=seed
        )

        self.mixing_matrix = mixing_matrix
        assert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"

    def update_clients(self):
        pass

    def mix(self):
        
        # 對各clients的模型引數進行優化
        for client in self.clients:
            client.step()

        # 儲存每個模型各引數混合的權重
        # 行對應不同的client,列對應單個模型中不同的引數
        # (注意:每個分量有獨立的mixing_matrix)
        mixing_matrix = torch.tensor(
            self.mixing_matrix.copy(),
            dtype=torch.float32,
            device=self.device
        )

        # 遍歷global模型(self.global_learners_ensemble) 中每一個分量模型(learner)
        # 相當於虛擬碼第14行
        for learner_id, global_learner in enumerate(self.global_learners_ensemble):
            # 用於將指定learner_id的各client的模型state讀出暫存
            state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]

            # 遍歷global模型中的各引數
            for key, param in global_learner.model.state_dict().items():
                shape_ = param.shape
                models_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)

                for ii, sd in enumerate(state_dicts):
                    # models_params的第ii個下標儲存的是第ii個client的引數
                    models_params[ii] = sd[key].view(1, -1) 

                # models_params的每一行是一個client的引數
                # @符號表示矩陣乘/矩陣向量乘
                # 故這裡表示每個client引數是其他所有client引數的混合
                models_params = mixing_matrix @ models_params

                for ii, sd in enumerate(state_dicts):
                    # 將第ii個client的引數存入state_dicts中對應位置
                    sd[key] = models_params[ii].view(shape_)

            # 將更新好的引數從state_dicts存入各client節點的模型中
            for client_id, client in enumerate(self.clients):
                client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])

        # 通訊輪數+1
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

參考文獻

相關文章