我們在《Python中的隨機取樣和概率分佈(二)》介紹瞭如何用Python現有的庫對一個概率分佈進行取樣,其中的Dirichlet分佈大家一定不會感到陌生。該分佈的概率密度函式為
其中\(\bm{\alpha}\)為引數。
我們在聯邦學習中,經常會假設不同client間的資料集不滿足獨立同分布(Non-IID)。那麼我們如何將一個現有的資料集按照Non-IID劃分呢?我們知道帶標籤樣本的生成分佈看可以表示為\(p(\bm{x}, y)\),我們進一步將其寫作\(p(\bm{x}, y)=p(\bm{x}|y)p(y)\)。其中如果要估計\(p(\bm{x}|y)\)的計算開銷非常大,但估計\(p(y)\)的計算開銷就很小。所有我們按照樣本的標籤分佈來對樣本進行Non-IID劃分是一個非常高效、簡便的做法。
總而言之,我們採取的演算法思路是儘量讓每個client上的樣本標籤分佈不同。我們設有\(K\)個類別標籤,\(N\)個client,每個類別標籤的樣本需要按照不同的比例劃分在不同的client上。我們設矩陣\(\bm{X}\in \mathbb{R}^{K*N}\)為類別標籤分佈矩陣,其行向量\(\bm{x}_k\in \mathbb{R}^N\)表示類別\(k\)在不同client上的概率分佈向量(每一維表示\(k\)類別的樣本劃分到不同client上的比例),該隨機向量就取樣自Dirichlet分佈。
據此,我們可以寫出以下的劃分演算法:
import numpy as np
np.random.seed(42)
def dirichlet_split_noniid(train_labels, alpha, n_clients):
'''
引數為alpha的Dirichlet分佈將資料索引劃分為n_clients個子集
'''
n_classes = train_labels.max()+1
label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
# (K, N)的類別標籤分佈矩陣X,記錄每個client佔有每個類別的多少
class_idcs = [np.argwhere(train_labels==y).flatten()
for y in range(n_classes)]
# 記錄每個K個類別對應的樣本下標
client_idcs = [[] for _ in range(n_clients)]
# 記錄N個client分別對應樣本集合的索引
for c, fracs in zip(class_idcs, label_distribution):
# np.split按照比例將類別為k的樣本劃分為了N個子集
# for i, idcs 為遍歷第i個client對應樣本集合的索引
for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
client_idcs[i] += [idcs]
client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
return client_idcs
接下來我們在EMNIST資料集上呼叫該函式進行測試,並進行視覺化呈現。我們設client數量\(N=10\),Dirichlet概率分佈的引數向量\(\bm{\alpha}\)滿足\(\alpha_i=1.0,\space i=1,2,...N\):
import torch
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(42)
if __name__ == "__main__":
N_CLIENTS = 10
DIRICHLET_ALPHA = 1.0
train_data = datasets.EMNIST(root=".", split="byclass", download=True, train=True)
test_data = datasets.EMNIST(root=".", split="byclass", download=True, train=False)
n_channels = 1
input_sz, num_cls = train_data.data[0].shape[0], len(train_data.classes)
train_labels = np.array(train_data.targets)
# 我們讓每個client不同label的樣本數量不同,以此做到Non-IID劃分
client_idcs = dirichlet_split_noniid(train_labels, alpha=DIRICHLET_ALPHA, n_clients=N_CLIENTS)
# 展示不同client的不同label的資料分佈
plt.figure(figsize=(20,3))
plt.hist([train_labels[idc]for idc in client_idcs], stacked=True,
bins=np.arange(min(train_labels)-0.5, max(train_labels) + 1.5, 1),
label=["Client {}".format(i) for i in range(N_CLIENTS)], rwidth=0.5)
plt.xticks(np.arange(num_cls), train_data.classes)
plt.legend()
plt.show()
最終的視覺化結果如下:
可以看到,62個類別標籤在不同client上的分佈確實不同,證明我們的樣本劃分演算法是有效的。