当前位置: 首页 > article >正文

Dirichlet分布生成联邦学生non-iid数据

1. 根据数据集训练数据标签和客户端数量,生成每个客户端上具有dirichlet分布的数据索引。
def dirichlet_split_noniid(train_labels, alpha, n_clients):
    '''
    按照参数为alpha的Dirichlet分布将样本索引集合划分为n_clients个子集
    '''
    n_classes = train_labels.max()+1
    # (K, N) 类别标签分布矩阵X,记录每个类别划分到每个client去的比例
    label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
    # (K, ...) 记录K个类别对应的样本索引集合
    class_idcs = [np.argwhere(train_labels == y).flatten()
                  for y in range(n_classes)]
    # 记录N个client分别对应的样本索引集合
    client_idcs = [[] for _ in range(n_clients)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # np.split按照比例fracs将类别为k的样本索引k_idcs划分为了N个子集
        # i表示第i个client,idcs表示其对应的样本索引集合idcs
        for i, idcs in enumerate(np.split(k_idcs,
                                          (np.cumsum(fracs)[:-1]*len(k_idcs)).
                                          astype(int))):
            client_idcs[i] += [idcs]
    client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
    return client_idcs
2. 调用函数根据索引产生训练数据集
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

def load_dataset_fmnist():
    mnist_mean, mnist_std = 0.1307, 0.3081
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((mnist_mean,), (mnist_std,))])

    # transform = transforms.Compose([transforms.ToTensor()])

    fmnist_train_dataset = datasets.FashionMNIST(root="./data/FMNIST/data", train=True, transform=transform, download=True)
    fmnist_test_dataset = datasets.FashionMNIST(root="./data/FMNIST/data", train=False, transform=transform, download=True)

    return fmnist_train_dataset, fmnist_test_dataset

trainset, testset = load_dataset_fmnist()
labels = trainset.targets[:]
classes = trainset.classes
n_classes = len(classes)
dirichlet_alpha = 0.1
n_clients = 4
client_idcs = dirichlet_split_noniid(trainset.targets, alpha = dirichlet_alpha, n_clients= n_clients)
3. 绘制不同客户端上的数据标签分布情况
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
    for idx in idc:
        label_distribution[labels[idx]].append(c_id)

plt.hist(label_distribution, stacked=True,
            bins=np.arange(-0.5, n_clients + 1.5, 1),
            label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
                                    c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend()
plt.title("Display Label Distribution on Different Clients")
plt.show()

参考资料:
联邦学习:按Dirichlet分布划分Non-IID样本
病态非独立同分布


http://www.kler.cn/a/371319.html

相关文章:

  • 在Excel中如何快速筛选非特定颜色
  • iOS调试真机出现的 “__llvm_profile_initialize“ 错误
  • 堆的基本概念和插入删除方法的介绍
  • 玩转Docker | 使用Docker部署捕鱼网页小游戏
  • Flutter加载本地HTML的优雅解决方案:轻松实现富文本展示
  • 20241027_北京郊游香山公园
  • css实现背景色的斑马条效果
  • 如何用李萨如图形测正弦信号的频率?若不使用李萨如图形,如何用示波器测交流信号频率?
  • PHP内存马:不死马
  • 微信小程序如何实现地图轨迹回放?
  • 地球上的中国:世界地图概览
  • Go中的泛型
  • NFS服务器作业
  • Linux云计算 |【第五阶段】CLOUD-DAY1
  • 字母象形与hand的不同解构
  • 【机器学习】揭秘XGboost:高效梯度提升算法的实践与应用
  • 「C/C++」C++ 设计模式 之 单例模式(Singleton)
  • 怎么实现电脑控制100台手机,苹果手机群控系统不用越狱实现新突破
  • GitHub Actions的 CI/CD
  • 鸿蒙开发培训要多久
  • 【计算机网络教程】课程 章节测试1 计算机网络概述
  • 启明云端乐鑫一级代理商,家电设备Matter交互方案,乐鑫ESP32-S3无线技术
  • JVM 调优深度剖析:优化 Java 应用的全方位攻略(一)
  • CentOS下安装ElasticSearch7.9.2(无坑版)
  • uniapp开发【选择地址-省市区功能】,直接套用即可
  • 2024-10-25 问AI: [AI面试题] 强化学习是如何工作