使用Dirichlet分布生成联邦学习的Non-IID数据
1. 根据训练集的样本标签和客户端数量,按Dirichlet分布为每个客户端生成样本索引。
def dirichlet_split_noniid(train_labels, alpha, n_clients):
    """Split sample indices into ``n_clients`` subsets using a Dirichlet prior.

    For every class, a proportion vector over the clients is drawn from a
    symmetric Dirichlet distribution with concentration ``alpha``. That class's
    sample indices are then cut into ``n_clients`` consecutive chunks according
    to those proportions.

    Args:
        train_labels: 1-D array-like of integer class labels (NumPy array,
            Python list, or a CPU torch tensor). Labels are assumed to lie in
            ``0 .. K-1``.
        alpha: Dirichlet concentration parameter.
        n_clients: number of clients to split the samples across.

    Returns:
        A list of ``n_clients`` 1-D integer arrays. Each array holds the sample
        indices assigned to one client. Every index appears in exactly one array.
    """
    # Coerce once so lists / torch tensors behave like NumPy arrays below.
    labels = np.asarray(train_labels)
    if labels.size == 0:
        # Nothing to split; .max() would raise on an empty array.
        return [np.array([], dtype=np.intp) for _ in range(n_clients)]

    # Cast to a plain int so range() / the size argument never see a 0-d array/tensor.
    n_classes = int(labels.max()) + 1
    # (K, N) matrix: row k gives the fraction of class k sent to each client.
    label_distribution = np.random.dirichlet([alpha] * n_clients, n_classes)
    # class_idcs[k] holds the sample indices whose label is k.
    class_idcs = [np.flatnonzero(labels == y) for y in range(n_classes)]

    # client_idcs[i] collects one index chunk per class for client i.
    client_idcs = [[] for _ in range(n_clients)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # Cumulative fractions (without the final 1.0) scaled to positions in k_idcs
        # give the N-1 cut points that np.split needs for N chunks.
        split_points = (np.cumsum(fracs)[:-1] * len(k_idcs)).astype(int)
        for i, idcs in enumerate(np.split(k_idcs, split_points)):
            client_idcs[i].append(idcs)
    return [np.concatenate(idcs) for idcs in client_idcs]
2. 加载数据集,并调用上述函数得到每个客户端的训练样本索引
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
def load_dataset_fmnist(root="./data/FMNIST/data"):
    """Load the Fashion-MNIST train and test splits with tensor normalization.

    Args:
        root: directory torchvision uses to store (and download into) the
            dataset files. Defaults to the previously hard-coded path.

    Returns:
        A ``(train_dataset, test_dataset)`` pair of torchvision
        ``FashionMNIST`` datasets, each applying ToTensor + Normalize.
    """
    # Fashion-MNIST per-pixel mean / std, not the MNIST values (0.1307 / 0.3081).
    fmnist_mean, fmnist_std = 0.2860, 0.3530
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((fmnist_mean,), (fmnist_std,)),
    ])
    fmnist_train_dataset = datasets.FashionMNIST(
        root=root, train=True, transform=transform, download=True)
    fmnist_test_dataset = datasets.FashionMNIST(
        root=root, train=False, transform=transform, download=True)
    return fmnist_train_dataset, fmnist_test_dataset
trainset, testset = load_dataset_fmnist()
# Convert the targets to a NumPy array once. Indexing a torch tensor element by
# element inside Python loops is slow, and NumPy labels work directly with the
# boolean masks used when splitting and plotting.
labels = np.asarray(trainset.targets)
classes = trainset.classes
n_classes = len(classes)

# Dirichlet concentration: smaller values concentrate each class on fewer clients.
dirichlet_alpha = 0.1
n_clients = 4
client_idcs = dirichlet_split_noniid(labels, alpha=dirichlet_alpha, n_clients=n_clients)
3. 绘制不同客户端上的数据标签分布情况
plt.figure(figsize=(12, 8))
# For each class, gather the client id of every sample of that class, so that
# plt.hist stacks one colour per class on top of each client's bar.
label_arr = np.asarray(labels)
sample_idx = np.concatenate(client_idcs)
sample_client = np.concatenate(
    [np.full(len(idc), c_id) for c_id, idc in enumerate(client_idcs)]
)
sample_label = label_arr[sample_idx]
label_distribution = [sample_client[sample_label == k] for k in range(n_classes)]
# Bin edges -0.5, 0.5, ..., n_clients - 0.5 give exactly one bin per client,
# centred on its id (no trailing empty bin).
plt.hist(label_distribution, stacked=True,
         bins=np.arange(-0.5, n_clients + 0.5, 1),
         label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), [f"Client {c_id}" for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend()
plt.title("Display Label Distribution on Different Clients")
plt.show()
参考资料:
联邦学习:按Dirichlet分布划分Non-IID样本
病态非独立同分布