当前位置: 首页 > article >正文

【FATE联邦学习】FATE是否支持batch分批训练?

思路梳理

想要数据上传到FATE,首先需要reader读入数据,才能后续进行训练,首先要保证reader能读入数据,不知道是否能分批次读入?

上传数据后,FATE需要trainer进行训练,不止是否存在批次训练这种模式?

检查Reader类

值得注意的是,Reader类并不在federatedml库里面,而是一个单独的pipeline库里面的组件。翻阅后发现Reader类继承了Output类。而Output类带有一个关键字data type:

class Output(object):
    def __init__(self, name, data_type='single', has_data=True, has_model=True, has_cache=False, output_unit=1):
        if has_model:
            self.model = Model(name).model
            self.model_output = Model(name).get_all_output()

        if has_data:
            if data_type == "single":
                self.data = SingleOutputData(name).data
                self.data_output = SingleOutputData(name).get_all_output()
            elif data_type == "multi":
                self.data = TraditionalMultiOutputData(name)
                self.data_output = TraditionalMultiOutputData(name).get_all_output()
            else:
                self.data = NoLimitOutputData(name, output_unit)
                self.data_output = NoLimitOutputData(name, output_unit).get_all_output()

        if has_cache:
            self.cache = Cache(name).cache
            self.cache_output = Cache(name).get_all_output()

对应的三个data type类也只不过是划分了data,并没有跟分批次相关的步骤

class SingleOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def data(self):
        return ".".join([self.prefix, IODataType.SINGLE])

    @staticmethod
    def get_all_output():
        return ["data"]


class TraditionalMultiOutputData(object):
    def __init__(self, prefix):
        self.prefix = prefix

    @property
    def train_data(self):
        return ".".join([self.prefix, IODataType.TRAIN])

    @property
    def test_data(self):
        return ".".join([self.prefix, IODataType.TEST])

    @property
    def validate_data(self):
        return ".".join([self.prefix, IODataType.VALIDATE])

    @staticmethod
    def get_all_output():
        return [IODataType.TRAIN,
                IODataType.VALIDATE,
                IODataType.TEST]


class NoLimitOutputData(object):
    def __init__(self, prefix, output_unit=1):
        self.prefix = prefix
        self.output_unit = output_unit

    @property
    def data(self):
        return [self.prefix + "." + "data_" + str(i) for i in range(self.output_unit)]

    def get_all_output(self):
        return ["data_" + str(i) for i in range(self.output_unit)]

所以Reader应该是只能单次吞入整个数据集,不能够分批次读入。

检查Trainer

跟train相关的参数都在TrainerParam里面。可是TrainerParam本身只是个存储参数的包装类,里面没有东西。
最终找到了一个job submitter的东西,也是通过传参,调用服务这种形式去做的Task。这些都是包皮,没有实际的代码。

最后在federatedml.nn.homo.trainer.fedavg_trainer里找到FedAvgTrainer,他里面给了参数,里面有batch size:

class FedAVGTrainer(TrainerBase):
    """

    Parameters
    ----------
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
                            mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
                         if True, According to origin paper, weight of a client is: n_local / n_global, where n_local
                         is the sample number locally and n_global is the sample number of all clients.
                         if False, simply averaging these models.

    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
                two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,
                stop training
    tol: float, tol value for early stop

    aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate
                             every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
               this option decides the return format of this trainer, and the evaluation type when running validation.
               if auto, will automatically infer your task type from labels and predict results.
    """

我自己在FATE那里提的issue:https://github.com/FederatedAI/FATE/issues/4832

最后结论

在homo训练,自定义神经网络的场景下使用FedAvg训练器能够实现batch训练。但是Reader是否能加载进来,要看机器,因为Reader应该是一次性全部读取的。


http://www.kler.cn/news/16856.html

相关文章:

  • 现代CMake高级教程 - 第 1 章:添加源文件
  • PowerJob基本概念
  • PHP学习笔记第一天
  • PHP+vue大学生心理健康评价和分析系统8w3ff
  • 每天一点C++——杂记
  • QT文本编辑与排版包含字体相关设置、段落对齐与排序方式
  • 树的刷题,嗝
  • 如果用上以下几种.NET EF Core性能调优,那么查询的性能会飙升
  • bash的进程与欢迎讯息自定义
  • C++命名空间的定义以及使用
  • C++煞笔笔记
  • 功能齐全的 DIY ESP32 智能手表设计之原理图讲解一
  • python实战应用讲解-【numpy数组篇】实用小技巧(九)(附python示例代码)
  • 这一篇LiveData掉不掉价(使用+粘性事件解决)
  • 07 Kubernetes 网络与服务管理
  • 项目沟通管理和干系人管理
  • 如何学习数据结构和算法
  • 《智能手机心率和呼吸率测量算法的前瞻性验证》阅读笔记
  • 23年5月高项备考学习笔记 —— 信息系统治理
  • NLP实战:基于Pytorch的文本分类入门实战
  • PS磨皮插件portraiture最新版磨皮工具
  • 【Python习题集3】常用数据结构习题
  • vcruntime140_1.dll丢失的解决方法
  • 3个经典线程同步问题
  • 用ChatGPT通过WebSocket开发一个交互性的五子棋微信小程序(二)
  • ArduPilot之开源代码基础知识Threading概念
  • Vue3通透教程【十四】TS复杂类型详解(一)
  • MATLAB函数封装2:QT调用封装函数
  • 至少要吃掉多少糖果
  • HPDA的资料