聯邦學習:按Dirichlet分佈劃分Non-IID樣本

orion發表於2022-02-15

我們在《Python中的隨機取樣和概率分佈(二)》介紹瞭如何用Python現有的庫對一個概率分佈進行取樣,其中的Dirichlet分佈大家一定不會感到陌生。該分佈的概率密度函式為

\[P(\bm{x}; \bm{\alpha}) \propto \prod_{i=1}^{k} x_{i}^{\alpha_{i}-1} \\ \bm{x}=(x_1,x_2,...,x_k),\quad x_i > 0 , \quad \sum_{i=1}^k x_i = 1\\ \bm{\alpha} = (\alpha_1,\alpha_2,..., \alpha_k). \quad \alpha_i > 0 \]

其中\(\bm{\alpha}\)為引數。

我們在聯邦學習中,經常會假設不同client間的資料集不滿足獨立同分布(Non-IID)。那麼我們如何將一個現有的資料集按照Non-IID劃分呢?我們知道帶標籤樣本的生成分佈看可以表示為\(p(\bm{x}, y)\),我們進一步將其寫作\(p(\bm{x}, y)=p(\bm{x}|y)p(y)\)。其中如果要估計\(p(\bm{x}|y)\)的計算開銷非常大,但估計\(p(y)\)的計算開銷就很小。所有我們按照樣本的標籤分佈來對樣本進行Non-IID劃分是一個非常高效、簡便的做法。

總而言之,我們採取的演算法思路是儘量讓每個client上的樣本標籤分佈不同。我們設有\(K\)個類別標籤,\(N\)個client,每個類別標籤的樣本需要按照不同的比例劃分在不同的client上。我們設矩陣\(\bm{X}\in \mathbb{R}^{K*N}\)為類別標籤分佈矩陣,其行向量\(\bm{x}_k\in \mathbb{R}^N\)表示類別\(k\)在不同client上的概率分佈向量(每一維表示\(k\)類別的樣本劃分到不同client上的比例),該隨機向量就取樣自Dirichlet分佈。

據此,我們可以寫出以下的劃分演算法:

import numpy as np
np.random.seed(42)
def split_noniid(train_labels, alpha, n_clients):
    '''
    引數為alpha的Dirichlet分佈將資料索引劃分為n_clients個子集
    '''
    n_classes = train_labels.max()+1
    label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
    # (K, N)的類別標籤分佈矩陣X,記錄每個client佔有每個類別的多少

    class_idcs = [np.argwhere(train_labels==y).flatten() 
           for y in range(n_classes)]
    # 記錄每個K個類別對應的樣本下標
 
    client_idcs = [[] for _ in range(n_clients)]
    # 記錄N個client分別對應樣本集合的索引
    for c, fracs in zip(class_idcs, label_distribution):
        # np.split按照比例將類別為k的樣本劃分為了N個子集
        # for i, idcs 為遍歷第i個client對應樣本集合的索引
        for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
            client_idcs[i] += [idcs]

    client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
  
    return client_idcs

加下來我們在EMNIST資料集上呼叫該函式進行測試,並進行視覺化呈現。我們設client數量\(N=10\),Dirichlet概率分佈的引數向量\(\bm{\alpha}\)滿足\(\alpha_i=1.0,\space i=1,2,...N\)

import torch
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)

if __name__ == "__main__":

    N_CLIENTS = 10 
    DIRICHLET_ALPHA = 1.0

    train_data = datasets.EMNIST(root=".", split="byclass", download=True, train=True)
    test_data = datasets.EMNIST(root=".", split="byclass", download=True, train=False)
    n_channels = 1


    input_sz, num_cls = train_data.data[0].shape[0],  len(train_data.classes)


    train_labels = np.array(train_data.targets)

    # 我們讓每個client不同label的樣本數量不同,以此做到Non-IID劃分
    client_idcs = split_noniid(train_labels, alpha=DIRICHLET_ALPHA, n_clients=N_CLIENTS)


    # 展示不同client的不同label的資料分佈
    plt.figure(figsize=(20,3))
    plt.hist([train_labels[idc]for idc in client_idcs], stacked=True, 
            bins=np.arange(min(train_labels)-0.5, max(train_labels) + 1.5, 1),
            label=["Client {}".format(i) for i in range(N_CLIENTS)], rwidth=0.5)
    plt.xticks(np.arange(num_cls), train_data.classes)
    plt.legend()
    plt.show()

最終的視覺化結果如下:
深度多工學習例項1
可以看到,62個類別標籤在不同client上的分佈確實不同,證明我們的樣本劃分演算法是有效的。

相關文章