当前位置: 首页 > article >正文

分布式训练类的定义以及创建分布式模型

一 、分布式训练类的定义

from ..modules import Module
from typing import Any, Optional
from .common_types import _devices_t, _device_t


class DistributedDataParallel(Module):
    process_group: Any = ...
    dim: int = ...
    module: Module = ...
    device_ids: _devices_t = ...
    output_device: _device_t = ...
    broadcast_buffers: bool = ...
    check_reduction: bool = ...
    broadcast_bucket_size: float = ...
    bucket_bytes_cap: float = ...

    # TODO type process_group once `distributed` module is stubbed
    def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
                 output_device: Optional[_device_t] = ..., dim: int = ...,
                 broadcast_buffers: bool = ..., process_group: Optional[Any] = ..., bucket_cap_mb: float = ...,
                 find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...

Python 类的定义,该类名为 DistributedDataParallel,是 PyTorch 中用于分布式数据并行训练的模块

from ..modules import Module :导入 PyTorch 中的 Module 类,表示神经网络模块的基类

from typing import Any, Optional : 导入 Any 和 Optional 类型,用于类型注解

from .common_types import _devices_t, _device_t:导入 _devices_t 和 _device_t 类型

class DistributedDataParallel(Module): 定义了一个类 DistributedDataParallel,它继承自Module类

类属性:

    process_group: Any = ... : 代表分布式训练的进程组
    dim: int = ...: 代表分布式的维度
    module: Module = ...:代表要进行并行处理的神经网络模块
    device_ids: _devices_t = ...:代表设备的 ID 列表
    output_device: _device_t = ...:代表输出设备
    broadcast_buffers: bool = ...:是否广播缓冲区
    check_reduction: bool = ...:是否检查减少操作
    broadcast_bucket_size: float = ...: 广播桶大小
    bucket_bytes_cap: float = ...:桶的字节容量上限
# TODO type process_group once `distributed` module is stubbed
def __init__(self, module: Module, device_ids: Optional[_devices_t] = ...,
             output_device: Optional[_device_t] = ..., dim: int = ...,
             broadcast_buffers: bool = ..., process_group: Optional[Any] = ...,   bucket_cap_mb: float = ...,find_unused_parameters: bool = ..., check_reduction: bool = ...) -> None: ...:

def __init__(self, ...):类的初始化方法,用于初始化对象的属性。参数包括神经网络模块 module、设备 ID 列表 device_ids、输出设备 output_device 等。这些参数都有默认值,可以在初始化对象时提供或使用默认值

-> None: 表示初始化方法没有返回值

总体而言,这段代码定义了一个分布式数据并行训练的模块 DistributedDataParallel,该模块可以在多个设备上并行处理神经网络模块,实现分布式训练

二、创建分布式模型

这段代码创建了一个分布式数据并行(DDP)模型,并在必要时进行版本检查

根据 PyTorch 版本的不同,采取不同的配置参数来创建 DDP 模型

def smart_DDP(model):  # 定义了一个名为 smart_DPP 的函数,该函数接受一个参数model,表示神经网络模型
    # Model DDP creation with checks  版本检查,其目的是确保不使用不受支持的PyTorch版本进行DDP训练
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    #
    '''
    DDP 模型的创建:
    if check_version(torch.__version__, '1.11.0'): ...: 如果 PyTorch 版本为 1.11.0,则使用 DDP 类创建 DDP 模型,并设置 static_graph=True
    else: ...: 如果 PyTorch 版本不为 1.11.0,则使用 DDP 类创建 DDP 模型
    '''
    if check_version(torch.__version__, '1.11.0'):
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True): 返回创建的 DDP 模型对象
        device_ids 表示设备 ID 列表,这里设置为 [LOCAL_RANK],而 output_device 表示输出设备,也设置为 LOCAL_RANK
        如果 PyTorch 版本为 1.11.0,则 static_graph 被设置为 True
        '''
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

其中,static_graph=True 是在创建分布式数据并行模型时传递给 DDP 类的一个参数

这个参数的作用是告诉 PyTorch 是否使用静态图,如果将 static_graph 设置为 True,则表示希望使用静态图,这在某些情况下可以提高分布式训练的效率,尤其是在一些特定的 PyTorch 版本中可能需要使用静态图以避免问题

在 PyTorch 中,static_graph 参数是用于控制动态图(Dynamic Computational Graph)和静态图(Static Computational Graph)的一个设置。动态图和静态图是两种不同的计算图构建方式:

  1. 动态图(Dynamic Computational Graph)

    • 在动态图中,计算图是在运行时动态构建的,每次迭代都可以改变图的结构
    • PyTorch 的默认行为是使用动态图,这使得在模型训练过程中可以更灵活地调整模型结构
  2. 静态图(Static Computational Graph)

    • 在静态图中,计算图在模型定义阶段就被固定,不再改变。这意味着一旦定义了计算图,就无法在运行时修改
    • 静态图的优点之一是可以进行一些优化,例如静态图可以被预先分析以进行优化,从而提高计算效率

http://www.kler.cn/a/161779.html

相关文章:

  • vue3 pdf base64转成文件流打开
  • 【2024最新】math-expression-evaluator 动态计算表达式的使用
  • Python用CEEMDAN-LSTM-VMD金融股价数据预测及SVR、AR、HAR对比可视化
  • 深度学习服务器租赁AutoDL
  • 2024年11月10日系统架构设计师考试题目回顾
  • 什么是 Real-Time Factor (RTF)
  • QT 重定向qdebug输出到自绘界面
  • 区分node,npm,nvm
  • uni-app实现安卓原生态调用身份证阅读器读卡库读身份证和社保卡、银行卡、IC卡等功能
  • 匹配不包含同时出现两次 “ago“ 的行
  • Redis server启动源码
  • vue 商品列表案例
  • JavaSE基础50题:11. 输出一个整数的每一位
  • CentOS 7.9 安装 k8s(详细教程)
  • Vue.js实现可编辑表格并高亮修改的单元格
  • 基于remix+metamask+ganache的智能合约部署调用
  • 注解 @Autowired 和 @Resource
  • OpenGL ES 帧缓冲对象介绍和使用示例
  • AI烟火识别智能视频分析系统解决方案
  • Dockerfile详解#如何编写自己的Dockerfile
  • Matlab 用矩阵画图
  • JAVA 多线程并发(一)
  • Jmeter接口测试
  • STL(一)(pair篇)
  • 【Docker二】docker网络模式、网络通信、数据管理
  • Oracle11g RAC无法使用VIP或SCAN IP连接数据库的解决方案