YOLOv11-ultralytics-8.3.67部分代码阅读笔记-dist.py
dist.py
ultralytics\utils\dist.py
目录
dist.py
1.所需的库和模块
2.def find_free_network_port() -> int:
3.def generate_ddp_file(trainer):
4.def generate_ddp_command(world_size, trainer):
5.def ddp_cleanup(trainer, file):
1.所需的库和模块
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import os
import shutil
import socket
import sys
import tempfile
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
2.def find_free_network_port() -> int:
# 这段代码定义了一个函数 find_free_network_port ,用于查找并返回一个可用的网络端口号。
# 定义了一个名为 find_free_network_port 的函数,该函数不接受任何参数,且返回值类型为 int ,表示它将返回一个整数,即可用的端口号。
def find_free_network_port() -> int:
# 在本地主机上查找一个空闲端口。
# 当我们不想连接到真正的主节点但必须设置“MASTER_PORT”环境变量时,它在单节点训练中很有用。
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
"""
# 使用 with 语句创建了一个上下文管理器,确保在代码块执行完毕后,资源(即套接字)能够被正确地释放。 socket.socket 创建了一个新的套接字对象, socket.AF_INET 表示使用 IPv4 地址族, socket.SOCK_STREAM 表示使用 TCP 协议。 as s 将创建的套接字对象赋值给变量 s 。
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# 调用套接字的 bind 方法,将套接字绑定到本地地址 127.0.0.1 (即本机的环回地址)和端口号 0 。端口号为 0 表示让操作系统自动分配一个未被占用的端口号。
s.bind(("127.0.0.1", 0))
# 调用套接字的 getsockname 方法,获取套接字绑定的地址和端口号。 getsockname() 返回一个元组,其中第一个元素是地址( 127.0.0.1 ),第二个元素是端口号。通过 [1] 获取端口号并返回。
return s.getsockname()[1] # port
# 这段代码通过创建一个临时的 TCP 套接字并绑定到本地地址 127.0.0.1 和端口号 0 ,让操作系统自动分配一个未被占用的端口号,然后通过 getsockname 获取该端口号并返回。这种方法是一种简单且有效的动态获取可用端口号的方式,常用于需要动态分配端口的网络应用开发中。
3.def generate_ddp_file(trainer):
# 这段代码定义了一个函数 generate_ddp_file ,用于生成一个分布式数据并行(DDP)训练的临时 Python 文件,并返回该文件的名称。
# 定义了一个函数 generate_ddp_file ,它接受一个参数。
# 1.trainer :通常是一个训练器对象,包含训练相关的配置和方法。
def generate_ddp_file(trainer):
# 生成 DDP 文件并返回其文件名。
"""Generates a DDP file and returns its file name."""
# trainer.__class__.__module__ 获取 trainer 类所在的模块路径。
# trainer.__class__.__name__ 获取 trainer 类的名称。
# 将 模块路径 和 类名 拼接成一个字符串,然后使用 rsplit(".", 1) 从右侧分割一次,得到 模块路径 ( module )和 类名 ( name )。这用于后续动态导入类。
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
# 定义了一个多行字符串 content ,表示生成的临时 DDP 文件的内容。
content = f"""
# 这是文件的注释,说明这是一个用于多 GPU 训练的临时文件,使用后应自动删除。
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use) Ultralytics 多 GPU 训练临时文件(使用后应自动删除)。
# 将 trainer.args 的属性(通过 vars() 转换为字典)赋值给 overrides 。 trainer.args 包含训练器的配置参数。
overrides = {vars(trainer.args)}
# 这是 Python 中的入口点检查,确保以下代码仅在直接运行该脚本时执行,而不是在导入时执行。
if __name__ == "__main__":
# 动态导入 trainer 类所在的模块和类。 module 和 name 是前面通过 rsplit 获取的模块路径和类名。
from {module} import {name}
# 从 ultralytics.utils 模块导入默认配置字典 DEFAULT_CFG_DICT 。
from ultralytics.utils import DEFAULT_CFG_DICT
# 复制默认配置字典到变量 cfg 。
cfg = DEFAULT_CFG_DICT.copy()
# 更新配置字典 cfg ,将 save_dir 设置为空字符串。这可能是为了处理某些特定的配置需求。
cfg.update(save_dir='') # handle the extra key 'save_dir'
# 创建一个新的 trainer 对象,传入更新后的配置字典 cfg 和覆盖参数 overrides 。
trainer = {name}(cfg=cfg, overrides=overrides)
# 尝试从 trainer.hub_session 中获取 model_url 属性,如果不存在,则使用 trainer.args.model 的值。这可能是为了 动态设置模型路径 。
trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}"
# 调用 trainer 的 train 方法开始训练,并将结果存储在变量 results 中。
results = trainer.train()
"""
# 创建一个目录 USER_CONFIG_DIR/DDP ,用于存储生成的临时文件。 exist_ok=True 表示如果目录已存在,则不会报错。
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
# tempfile.NamedTemporaryFile(mode='w+b', buffering=-1, encoding=None, newline=None, suffix=None, prefix=None, dir=None, delete=True)
# tempfile.NamedTemporaryFile 是 Python 标准库 tempfile 模块中的一个类,用于创建一个临时文件,并返回一个文件对象。这个文件在关闭时可以自动删除,也可以保留下来供其他程序使用。
# 参数 :
# mode :文件模式,默认为 'w+b' (二进制读写模式)。也可以是 'w' (文本写模式)、 'r+' (读写模式)等。
# buffering :缓冲区大小。默认为 -1 ,表示使用系统默认缓冲区大小。
# encoding :文件编码,仅在文本模式下有效。
# newline :换行符处理,仅在文本模式下有效。
# suffix :临时文件的后缀名(不包括点)。例如, suffix='.txt' 。
# prefix :临时文件的前缀名。默认为系统生成的随机前缀。
# dir :临时文件的存储目录。默认为系统默认的临时目录(如 /tmp )。
# delete :文件关闭时是否自动删除。默认为 True ,表示文件在关闭后自动删除。
# 返回值 :
# 返回一个文件对象,支持文件操作(如读写)。文件路径可以通过文件对象的 .name 属性访问。
# 主要用途 :
# 创建临时文件 :用于存储临时数据,文件在关闭时可以自动删除。
# 跨进程共享 :临时文件路径可以通过 .name 属性获取,供其他进程或程序访问。
# 灵活的文件操作 :支持多种文件模式(如文本模式、二进制模式)。
# 注意事项 :
# 文件路径 :文件路径可以通过 .name 属性访问。 文件路径是临时的,通常存储在系统的临时目录中(如 /tmp )。
# 文件删除 :如果 delete=True ,文件在关闭后自动删除。 如果 delete=False ,文件在关闭后不会自动删除,需要手动删除。
# 跨平台兼容性 : tempfile.NamedTemporaryFile 在 Windows 和 Unix 系统上都能正常工作。文件路径可能因操作系统而异。
# 安全性 :临时文件路径是随机生成的,避免了文件名冲突。 如果需要更高的安全性,可以使用 tempfile.mkstemp 或 tempfile.TemporaryDirectory 。
# tempfile.NamedTemporaryFile 是一个非常实用的工具,用于创建临时文件。它支持多种文件模式和灵活的配置选项,可以用于存储临时数据或跨进程共享文件。通过 delete 参数,可以控制文件是否在关闭后自动删除。
# 使用 tempfile.NamedTemporaryFile 创建一个临时文件,设置以下参数。
with tempfile.NamedTemporaryFile(
# 文件名前缀为 _temp_ 。
prefix="_temp_",
# id(object)
# id() 是 Python 的内置函数,用于获取一个对象的唯一标识符(内存地址)。
# 参数 :
# object :任何 Python 对象,包括变量、列表、字典、类实例等。
# 返回值 :
# 返回一个整数,表示对象的唯一标识符(通常是对象在内存中的地址)。
# 作用 :
# id() 函数的主要作用是获取一个对象的唯一标识符。在 CPython 实现中,这个标识符通常是对象在内存中的地址。由于每个对象在内存中都有唯一的地址, id() 可以用来判断两个变量是否指向同一个对象。
# 注意事项 :
# 唯一性 :在程序运行期间, id() 返回的值是唯一的。但程序结束后,该内存地址可能会被释放并重新分配给其他对象。
# 不可变性 :对象的 id 在其生命周期内不会改变。即使对象的内容发生变化(如列表或字典的内容), id 仍然保持不变。
# 用途 :
# id() 通常用于调试,帮助开发者理解变量的引用关系。
# 它也可以用于判断两个变量是否指向同一个对象,但通常更推荐使用 is 关键字 : print(x is y) # 等价于 id(x) == id(y) 。
# id() 函数是一个非常实用的内置函数,用于获取对象的唯一标识符(内存地址)。它在调试和理解变量引用关系时非常有用,但通常不建议在生产代码中直接使用 id() 来比较对象,而是使用 is 关键字。
# 文件名后缀为 id(trainer) 的值,确保文件名唯一。
suffix=f"{id(trainer)}.py",
# 以读写模式打开文件。
mode="w+",
# 指定文件编码为 UTF-8。
encoding="utf-8",
# 指定文件存储目录。
dir=USER_CONFIG_DIR / "DDP",
# 文件不会在关闭后自动删除。
delete=False,
) as file:
# 将前面定义的 content 写入临时文件中。
file.write(content)
# 返回生成的临时文件的路径。
return file.name
# 这段代码的作用是动态生成一个用于分布式数据并行(DDP)训练的临时 Python 文件。文件内容包括训练器的配置和训练逻辑,通过动态导入和参数覆盖的方式,确保训练器能够在多 GPU 环境中正确运行。生成的文件存储在指定目录中,并返回文件路径,以便后续使用。这种方法可以灵活地支持不同训练器的配置和逻辑,适用于需要动态生成训练脚本的场景。
4.def generate_ddp_command(world_size, trainer):
# 这段代码定义了一个函数 generate_ddp_command ,用于生成分布式训练所需的命令,并返回该命令及其对应的临时文件路径。
# 定义了一个函数 generate_ddp_command ,它接受两个参数。
# 1.world_size :表示参与分布式训练的 GPU 数量。
# 2.trainer :一个训练器对象,包含训练相关的配置和方法。
def generate_ddp_command(world_size, trainer):
# 生成并返回分布式训练的命令。
"""Generates and returns command for distributed training."""
# __main__ 模块在 Python 中有以下用途 :
# 用于判断脚本是否被直接运行。
# 用于动态访问主模块的属性,例如文件路径、模块名称等。
# 在某些复杂的应用场景中,可以通过 __main__ 模块实现跨模块的动态访问和操作。
# 在实际开发中, __main__ 模块的使用相对较少,但在某些特定场景下非常有用。
# 导入 Python 的内置模块 __main__ 。虽然在代码中没有直接使用该模块,但可能是为了确保主模块的上下文可用。
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
# 检查 trainer 对象的 resume 属性是否为 False 。如果 resume 为 False ,表示不需要从上次保存的状态恢复训练。
if not trainer.resume:
# 如果 trainer.resume 为 False ,则调用 shutil.rmtree 删除 trainer.save_dir 指定的目录。 trainer.save_dir 是保存训练中间结果的目录,删除该目录是为了避免与新的训练任务冲突。
shutil.rmtree(trainer.save_dir) # remove the save_dir
# 调用之前定义的 generate_ddp_file 函数,生成一个临时的 DDP 文件,并将 文件路径存 储在变量 file 中。 generate_ddp_file 函数会根据 trainer 的配置动态生成一个包含训练逻辑的 Python 文件。
file = generate_ddp_file(trainer)
# 根据 PyTorch 的版本选择合适的分布式训练命令。
# 如果 TORCH_1_9 为 True (表示 PyTorch 版本 >= 1.9),则使用 torch.distributed.run 。
# 否则,使用 torch.distributed.launch 。
# TORCH_1_9 是一个变量,在代码的其他地方定义,用于判断 PyTorch 的版本。
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
# 调用之前定义的 find_free_network_port 函数,获取一个未被占用的网络端口号。这个端口号将用于分布式训练中的通信。
port = find_free_network_port()
# 构造用于启动分布式训练的命令列表。
# sys.executable :获取当前 Python 解释器的路径。
# "-m" :表示以模块方式运行。
# dist_cmd :分布式训练的命令( torch.distributed.run 或 torch.distributed.launch )。
# "--nproc_per_node" :指定每个节点(GPU)上运行的进程数,值为 world_size 。
# "--master_port" :指定主节点的通信端口号,值为 port 。
# file :生成的临时 DDP 文件路径。
cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
# 返回生成的命令列表 cmd 和临时文件路径 file 。
return cmd, file
# 这段代码的作用是生成用于分布式数据并行(DDP)训练的命令。它通过以下步骤实现。检查是否需要删除保存目录(如果不需要恢复训练)。调用 generate_ddp_file 生成一个包含训练逻辑的临时 Python 文件。根据 PyTorch 版本选择合适的分布式训练命令。获取一个未被占用的网络端口号。构造完整的分布式训练命令,并返回命令及其对应的临时文件路径。这种方法可以灵活地支持不同配置和环境下的分布式训练,确保训练任务能够正确启动和运行。
5.def ddp_cleanup(trainer, file):
# 这段代码定义了一个函数 ddp_cleanup ,用于在分布式数据并行(DDP)训练完成后清理临时文件。
# 定义了一个函数 ddp_cleanup ,它接受两个参数。
# 1.trainer :一个训练器对象,通常包含训练相关的配置和方法。
# 2.file :一个字符串,表示临时文件的路径。
def ddp_cleanup(trainer, file):
# 如果创建了临时文件则删除。
"""Delete temp file if created."""
# 使用 id(trainer) 获取 trainer 对象的唯一标识符(内存地址)。
# 将 id(trainer) 转换为字符串,并拼接 .py ,形成一个特定的后缀。
# 检查 file 字符串中是否包含这个后缀。如果包含,说明 file 是与当前 trainer 对象关联的临时文件。
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
# os.remove(path)
# os.remove() 是 Python 标准库 os 模块中的一个函数,用于删除指定路径的文件。
# 参数 :
# path :一个字符串,表示要删除的文件的路径。可以是相对路径或绝对路径。
# 功能 :
# 删除指定路径的文件。
# 如果文件不存在,会抛出 FileNotFoundError 异常。
# 如果路径指向一个目录,而不是文件,会抛出 IsADirectoryError 异常。
# 如果文件被其他进程占用或权限不足,可能会抛出 PermissionError 异常。
# 返回值 :
# os.remove() 不返回任何值(即返回 None )。
# 注意事项 :
# 文件存在性检查 :在调用 os.remove() 之前,可以使用 os.path.exists() 检查文件是否存在,以避免 FileNotFoundError 异常。
# 路径类型检查 :如果不确定路径是否指向文件,可以使用 os.path.isfile() 进行检查。
# 异常处理 :在实际应用中,建议使用 try-except 块捕获可能的异常,以便更好地处理错误情况。
# 总结 : os.remove() 是一个简单而强大的函数,用于删除指定路径的文件。它在文件操作中非常常用,但在使用时需要注意文件存在性、路径类型和权限问题,以避免运行时错误。
# 如果条件满足(即 file 是临时文件),调用 os.remove(file) 删除该文件。 os.remove 是 Python 标准库 os 模块中的函数,用于删除指定路径的文件。
os.remove(file)
# 这段代码的作用是在分布式训练完成后清理临时文件。它通过以下步骤实现。检查传入的 file 是否是与 trainer 对象关联的临时文件(通过检查文件名是否包含 id(trainer).py )。如果是临时文件,则调用 os.remove 删除该文件。这种方法可以确保在分布式训练完成后,动态生成的临时文件不会被遗留,从而保持系统的整洁和安全。