当前位置: 首页 > article >正文

YOLOv11-ultralytics-8.3.67部分代码阅读笔记-build.py

build.py

ultralytics\data\build.py

目录

build.py

1.所需的库和模块

2.class InfiniteDataLoader(dataloader.DataLoader): 

3.class _RepeatSampler: 

4.def seed_worker(worker_id): 

5.def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): 

6.def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): 

7.def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): 

8.def check_source(source): 

9.def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): 


1.所需的库和模块

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import dataloader, distributed

from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
from ultralytics.data.loaders import (
    LOADERS,
    LoadImagesAndVideos,
    LoadPilAndNumpy,
    LoadScreenshots,
    LoadStreams,
    LoadTensor,
    SourceTypes,
    autocast_list,
)
from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file

2.class InfiniteDataLoader(dataloader.DataLoader): 

# 这段代码定义了一个名为 InfiniteDataLoader 的类,它继承自 PyTorch 的 DataLoader 类,并通过修改其行为来实现无限循环的数据加载功能。
# 定义了一个名为 InfiniteDataLoader 的类,继承自 PyTorch 的 DataLoader 类。这表明它将复用 DataLoader 的功能,同时添加一些自定义行为。
class InfiniteDataLoader(dataloader.DataLoader):
    # 重用工作器的数据加载器。
    # 使用与原始数据加载器相同的语法。
    """
    Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    # 定义了 InfiniteDataLoader 的初始化方法,接收任意数量的位置参数 ( *args ) 和关键字参数 ( **kwargs )。这些参数将被传递给父类 DataLoader 的初始化方法。
    def __init__(self, *args, **kwargs):
        # 无限循环使用 worker 的 Dataloader,继承自 DataLoader。
        """Dataloader that infinitely recycles workers, inherits from DataLoader."""
        # 调用父类 DataLoader 的初始化方法,完成标准的 DataLoader 初始化过程。
        super().__init__(*args, **kwargs)

        # object.__setattr__(name, value)
        # 在Python中, object.__setattr__() 是一个特殊方法,用于设置对象的属性。它是 object 类的一个方法,而 object 是Python中所有类的基类。 __setattr__() 方法在设置对象属性时被自动调用,但也可以在子类中被重写以自定义属性赋值的行为。
        # 参数 :
        # name :要设置的属性的名称。
        # value :属性的值。
        # 行为 :
        # 当对一个对象的属性进行赋值操作时,例如 obj.attr = value ,Python会自动调用该对象的 __setattr__() 方法。这个方法的默认实现会设置一个名为 name 的属性,其值为 value 。
        # 为什么使用 object.__setattr__ :
        # 在某些情况下,你可能需要直接调用 __setattr__() 方法,特别是当你需要绕过属性赋值的默认行为时。例如,你可能想要在设置属性之前执行一些额外的检查或操作。
        # 注意事项 :
        # 使用 object.__setattr__() 时,应该谨慎,因为它会绕过属性的正常赋值机制,包括可能的属性监视器或装饰器。
        # 在大多数情况下,直接使用 obj.attr = value 就足够了,除非有特殊需求需要自定义属性赋值的行为。

        # 通过 object.__setattr__ 方法,将 self.batch_sampler 替换为 _RepeatSampler(self.batch_sampler) 。 _RepeatSampler 是一个包装器,用于将 batch_sampler 的行为从有限迭代转换为无限迭代。 这是实现无限数据加载的关键步骤。
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        # 调用父类 DataLoader 的 __iter__ 方法,初始化一个 迭代器 self.iterator 。这个迭代器将 用于后续的无限循环 。
        self.iterator = super().__iter__()

    # 定义了 __len__ 方法,用于返回 InfiniteDataLoader 的长度。
    def __len__(self):
        # 返回批量采样器的采样器的长度。
        """Returns the length of the batch sampler's sampler."""
        # 返回 _RepeatSampler 中原始 batch_sampler 的长度。虽然 _RepeatSampler 本身是无限的,但这里返回的是原始采样器的长度,用于提供一个有限的参考值。
        return len(self.batch_sampler.sampler)

    # 定义了 __iter__ 方法,用于实现无限循环的迭代行为。
    def __iter__(self):
        # 创建一个无限重复的采样器。
        """Creates a sampler that repeats indefinitely."""
        # 循环 len(self) 次。虽然 _RepeatSampler 是无限的,但这里通过 len(self) 控制循环次数,确保每次迭代都能从 _RepeatSampler 中获取数据。
        for _ in range(len(self)):
            # 通过 next(self.iterator) 获取下一个批次的数据,并将其作为生成器的输出。
            yield next(self.iterator)

    # 定义了 __del__ 方法,用于在对象被销毁时执行清理操作。
    def __del__(self):
        # 确保工作线程被终止。
        """Ensure that workers are terminated."""
        # 开始一个 try 块,用于捕获可能的异常。
        try:
            # 检查 self.iterator 是否有 _workers 属性。如果没有,则直接返回,跳过后续清理操作。
            if not hasattr(self.iterator, "_workers"):
                return
            # 遍历 self.iterator 中的所有工作线程 _workers 。
            for w in self.iterator._workers:  # force terminate
                # 如果工作线程仍在运行,则调用 terminate() 方法强制终止线程。
                if w.is_alive():
                    w.terminate()
            # 调用 _shutdown_workers() 方法,清理工作线程。
            self.iterator._shutdown_workers()  # cleanup
        # 捕获并忽略所有异常,确保 __del__ 方法不会因异常而失败。
        except Exception:
            pass

    # 定义了一个 reset 方法,用于重置迭代器。
    def reset(self):
        # 重置迭代器。
        # 当我们想要在训练时修改数据集的设置时,这很有用。
        """
        Reset iterator.

        This is useful when we want to modify settings of dataset while training.
        """
        # 通过调用 self._get_iterator() 方法 重新获取一个新的迭代器 ,从而重置数据加载器的状态。
        self.iterator = self._get_iterator()
# InfiniteDataLoader 是一个自定义的 DataLoader 类,通过以下方式实现了无限循环的数据加载功能。无限循环的采样器:在初始化时,将 DataLoader 的 batch_sampler 替换为 _RepeatSampler 包装器,从而实现无限循环的采样行为。迭代器管理:在 __iter__ 方法中,通过循环调用 next(self.iterator) ,实现无限循环的数据生成。提供了 reset 方法,用于重置迭代器的状态。资源清理:在 __del__ 方法中,强制终止所有工作线程并清理资源,避免线程泄漏。兼容性: __len__ 方法返回原始采样器的长度,确保与标准 DataLoader 的行为兼容。通过继承 DataLoader ,复用了其大部分功能,同时添加了无限循环的特性。这种设计特别适合需要无限循环数据加载的场景,例如在训练深度学习模型时,数据集需要被不断重复使用。

3.class _RepeatSampler: 

# 这段代码定义了一个名为 _RepeatSampler 的类,它是一个简单的迭代器包装器,用于无限重复一个给定的采样器( sampler )。
# 定义了一个名为 _RepeatSampler 的类。从命名来看,它可能是一个辅助类(以单下划线开头的名称通常表示“受保护”的类或方法),用于重复采样器的行为。
class _RepeatSampler:
    # 永远重复的采样器。
    """
    Sampler that repeats forever.

    Args:
        sampler (Dataset.sampler): The sampler to repeat.
    """

    # 定义了类的初始化方法 __init__ ,接收一个参数。
    # 1.sampler :是一个可迭代对象(如生成器、列表或其他迭代器),它将被 _RepeatSampler 包装并无限重复。
    def __init__(self, sampler):
        # 初始化一个无限重复给定采样器的对象。
        """Initializes an object that repeats a given sampler indefinitely."""
        # 将传入的 sampler 存储为类的实例属性 self.sampler ,以便后续使用。
        self.sampler = sampler

    # 定义了类的 __iter__ 方法,使 _RepeatSampler 成为一个可迭代对象。当调用 iter(_RepeatSampler) 或在 for 循环中使用 _RepeatSampler 时,会调用此方法。
    def __iter__(self):
        # 迭代‘采样器’并产生其内容。
        """Iterates over the 'sampler' and yields its contents."""
        # 开始一个无限循环。这意味着 _RepeatSampler 会不断地重复 sampler 的行为,直到外部显式中断。
        while True:
            # iter(self.sampler) :获取 sampler 的迭代器。
            # yield from :将 sampler 的迭代器中的每个元素逐一生成。当 sampler 的迭代结束时, while True 循环会重新开始,从而实现无限重复。
            yield from iter(self.sampler)
# _RepeatSampler 是一个简单的迭代器包装器,其核心功能是将一个有限的采样器( sampler )转换为一个无限重复的迭代器。具体来说。输入: sampler 是一个可迭代对象,可以是生成器、列表或其他任何支持迭代的对象。行为: _RepeatSampler 使用 while True 实现无限循环。每次循环中,它通过 yield from 将 sampler 的元素逐一生成。当 sampler 的迭代结束时,循环会重新开始,从而实现无限重复。用途:这种设计通常用于需要无限重复数据采样的场景,例如在数据增强、循环训练或无限数据流的生成中。4. 特点: _RepeatSampler 是一个轻量级的包装器,不修改原始 sampler 的行为,只是无限重复其输出。它依赖于 sampler 的可迭代性,因此 sampler 必须是一个有效的可迭代对象。
# 示例用法 :
# 假设有一个简单的采样器 :
# sampler = [1, 2, 3]  # 一个简单的可迭代对象
# repeat_sampler = _RepeatSampler(sampler)
# for i, value in enumerate(repeat_sampler):
#     print(value)
#     if i >= 10:  # 手动中断,否则会无限打印
#         break
# 输出将是 :
# 1
# 2
# 3
# 1
# 2
# 3
# 1
# 2
# 3
# 1
# 在这个例子中, _RepeatSampler 将 [1, 2, 3] 无限重复,直到手动中断。

4.def seed_worker(worker_id): 

# 这段代码定义了一个名为 seed_worker 的函数,用于在多线程或多进程数据加载时为每个工作线程或进程设置随机种子。
# 定义了一个函数 seed_worker ,接收一个参数。
# 1.worker_id :表示当前工作线程或进程的唯一标识。
# # noqa 是一个注释,通常用于告诉某些代码检查工具(如 flake8 )忽略这一行的检查。
def seed_worker(worker_id):  # noqa
    # 设置数据加载器工作器种子 https://pytorch.org/docs/stable/notes/randomness.html#dataloader。
    """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
    # torch.initial_seed() :获取当前 PyTorch 的初始随机种子。
    # % 2**32 :将种子值限制在 0 到 2^32 - 1 的范围内。这是因为 NumPy 的随机种子需要是一个 32 位的整数。
    # worker_seed :最终的随机种子值,用于后续的随机数生成。
    worker_seed = torch.initial_seed() % 2**32
    # 使用 worker_seed 设置 NumPy 的随机种子。这确保了在 NumPy 的随机操作中,每个工作线程或进程都会生成独立且可复现的随机数序列。
    np.random.seed(worker_seed)
    # 使用 worker_seed 设置 Python 标准库 random 模块的随机种子。这同样确保了在 random 模块的随机操作中,每个工作线程或进程的随机行为是独立且可复现的。
    random.seed(worker_seed)
# seed_worker 函数的作用是为多线程或多进程环境中的每个工作线程或进程设置独立的随机种子。它的主要功能包括。获取初始种子:通过 torch.initial_seed() 获取 PyTorch 的初始随机种子,并将其限制在 32 位整数范围内。设置 NumPy 和 Python 的随机种子:使用相同的种子值分别设置 NumPy 和 Python random 模块的随机种子。这确保了在多线程或多进程环境中,每个工作线程或进程的随机行为是独立的,并且可以通过相同的初始种子复现。用途:这种设计通常用于深度学习中的数据加载器(如 PyTorch 的 DataLoader ),尤其是在使用多进程加载数据时。通过为每个工作进程设置独立的随机种子,可以避免随机数生成的冲突,同时保证数据增强或其他随机操作的可复现性。

5.def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): 

# 这段代码定义了一个名为 build_yolo_dataset 的函数,用于根据配置和输入参数构建 YOLO 数据集。
# 定义了一个函数 build_yolo_dataset ,接收以下参数 :
# 1.cfg :配置对象,包含数据集和训练的相关参数。
# 2.img_path :图像路径。
# 3.batch :批量大小。
# 4.data :数据配置,可能包含类别信息等。
# 5.mode :数据集模式,默认为 "train" ,表示训练模式。
# 6.rect :是否使用矩形批次,默认为 False 。
# 7.stride :模型的步幅,默认为 32 。
# 8.multi_modal :是否使用多模态数据集,默认为 False 。
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
    # 构建 YOLO 数据集。
    """Build YOLO Dataset."""
    # 根据 multi_modal 参数的值,选择使用 YOLOMultiModalDataset 或 YOLODataset 类。 如果 multi_modal 为 True ,则使用 YOLOMultiModalDataset (可能支持多模态输入,如图像和文本)。 否则,使用标准的 YOLODataset 。
    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
    # 调用选定的 dataset 类的构造函数,并传入一系列参数来初始化数据集。
    return dataset(
        # 将传入的 img_path 参数传递给数据集类,指定 图像路径 。
        img_path=img_path,
        # 从配置对象 cfg 中获取 imgsz 属性,并将其传递给数据集类,指定 图像大小 。
        imgsz=cfg.imgsz,
        # 将传入的 batch 参数传递给数据集类,指定 批量大小 。
        batch_size=batch,
        # 根据 mode 参数的值决定是否启用数据增强。 如果 mode 是 "train" ,则启用数据增强( augment=True )。 否则,禁用数据增强( augment=False )。
        augment=mode == "train",  # augmentation
        # 将 配置对象 cfg 传递给数据集类,作为超参数( hyp )。 注释中提到可能需要添加一个 get_hyps_from_cfg 函数来更清晰地提取超参数。
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function    TODO:可能添加 get_hyps_from_cfg 函数。
        # 决定是否使用矩形批次。 如果 cfg.rect 为 True ,则启用矩形批次。 否则,检查传入的 rect 参数。
        rect=cfg.rect or rect,  # rectangular batches
        # 从配置对象 cfg 中获取 cache 属性,指定是否缓存数据。 如果 cfg.cache 未定义,则默认为 None 。
        cache=cfg.cache or None,
        # 从配置对象 cfg 中获取 single_cls 属性,指定是否为单类别训练。 如果 cfg.single_cls 未定义,则默认为 False 。
        single_cls=cfg.single_cls or False,
        # 将传入的 stride 参数转换为整数,并传递给数据集类,指定 模型的步幅 。
        stride=int(stride),
        # 根据 mode 参数的值设置填充比例。 如果是训练模式( mode == "train" ),填充比例为 0.0 。 否则,填充比例为 0.5 。
        pad=0.0 if mode == "train" else 0.5,
        # 使用 colorstr 函数为日志输出添加颜色,并指定前缀为当前模式(如 "train: " 或 "val: " )。
        prefix=colorstr(f"{mode}: "),
        # 从配置对象 cfg 中获取 task 属性,并传递给数据集类,指定 任务类型 。
        task=cfg.task,
        # 从配置对象 cfg 中获取 classes 属性,并传递给数据集类,指定 类别信息 。
        classes=cfg.classes,
        # 将传入的 data 参数传递给数据集类,可能包含 数据集的配置信息 。
        data=data,
        # 根据 mode 参数的值设置数据集的使用比例。 如果是训练模式,使用 cfg.fraction (可能用于指定训练数据的子集比例)。 否则,默认为 1.0 ,表示使用全部数据。
        fraction=cfg.fraction if mode == "train" else 1.0,
    )
# build_yolo_dataset 函数的作用是根据配置和输入参数动态构建 YOLO 数据集。其主要功能包括。动态选择数据集类:根据 multi_modal 参数的值,选择使用 YOLOMultiModalDataset 或 YOLODataset 。灵活配置数据集:支持多种模式(如训练模式和验证模式)。根据模式动态调整数据增强、填充比例和数据集使用比例。从配置对象 cfg 中提取多种参数,如图像大小、超参数、缓存设置等。可扩展性:函数通过参数化设计,支持不同的数据集配置和模式。提供了对多模态数据集的支持(通过 multi_modal 参数)。用途:该函数主要用于 YOLO 模型的训练和验证阶段,用于构建适合 YOLO 模型的数据集。通过灵活的参数配置,可以适应不同的训练需求,如单类别训练、矩形批次、数据缓存等。这种设计使得数据集的构建过程更加灵活和可配置,适合在多种场景下使用 YOLO 模型进行目标检测任务。

6.def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): 

# 这段代码定义了一个名为 build_grounding 的函数,用于构建一个名为 GroundingDataset 的数据集。这个函数的结构与之前解释的 build_yolo_dataset 函数类似,但它是专门为“grounding”任务设计的,可能用于视觉-语言对齐或其他多模态任务。
# 定义了一个名为 build_grounding 的函数,接收以下参数 :
# 1.cfg :配置对象,包含数据集和训练的相关参数。
# 2.img_path :图像路径。
# 3.json_file :JSON 文件路径,可能包含标注信息或其他元数据。
# 4.batch :批量大小。
# 5.mode :数据集模式,默认为 "train" ,表示训练模式。
# 6.rect :是否使用矩形批次,默认为 False 。
# 7.stride :模型的步幅,默认为 32 。
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
    # 构建 YOLO 数据集。
    """Build YOLO Dataset."""
    # 调用 GroundingDataset 类的构造函数,并传入一系列参数来初始化数据集。 GroundingDataset 是一个专门用于 grounding 任务的数据集类。
    return GroundingDataset(
        # 将传入的 img_path 参数传递给数据集类,指定 图像路径 。
        img_path=img_path,
        # 将传入的 json_file 参数传递给数据集类,指定 JSON 文件路径 。这个文件可能包含标注信息或其他元数据。
        json_file=json_file,
        # 从配置对象 cfg 中获取 imgsz 属性,并将其传递给数据集类,指定 图像大小 。
        imgsz=cfg.imgsz,
        # 将传入的 batch 参数传递给数据集类,指定 批量大小 。
        batch_size=batch,
        # 根据 mode 参数的值决定是否启用数据增强。 如果 mode 是 "train" ,则启用数据增强( augment=True )。 否则,禁用数据增强( augment=False )。
        augment=mode == "train",  # augmentation
        # 将 配置对象 cfg 传递给数据集类,作为超参数( hyp )。 注释中提到可能需要添加一个 get_hyps_from_cfg 函数来更清晰地提取超参数。
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        # 决定是否使用矩形批次。 如果 cfg.rect 为 True ,则启用矩形批次。 否则,检查传入的 rect 参数。
        rect=cfg.rect or rect,  # rectangular batches
        # 从配置对象 cfg 中获取 cache 属性,指定是否缓存数据。 如果 cfg.cache 未定义,则默认为 None 。
        cache=cfg.cache or None,
        # 从配置对象 cfg 中获取 single_cls 属性,指定是否为单类别训练。 如果 cfg.single_cls 未定义,则默认为 False 。
        single_cls=cfg.single_cls or False,
        # 将传入的 stride 参数转换为整数,并传递给数据集类,指定 模型的步幅 。
        stride=int(stride),
        # 根据 mode 参数的值设置 填充比例 。 如果是训练模式( mode == "train" ),填充比例为 0.0 。 否则,填充比例为 0.5 。
        pad=0.0 if mode == "train" else 0.5,
        # 使用 colorstr 函数为日志输出添加颜色,并指定前缀为当前模式(如 "train: " 或 "val: " )。
        prefix=colorstr(f"{mode}: "),
        # 从配置对象 cfg 中获取 task 属性,并传递给数据集类, 指定任务类型 。
        task=cfg.task,
        # 从配置对象 cfg 中获取 classes 属性,并传递给数据集类,指定 类别信息 。
        classes=cfg.classes,
        # 根据 mode 参数的值设置数据集的使用比例。 如果是训练模式,使用 cfg.fraction (用于指定训练数据的子集比例)。 否则,默认为 1.0 ,表示使用全部数据。
        fraction=cfg.fraction if mode == "train" else 1.0,
    )
# build_grounding 函数的作用是根据配置和输入参数动态构建一个名为 GroundingDataset 的数据集,可能用于视觉-语言对齐或其他多模态任务。其主要功能包括。动态配置数据集:支持多种模式(如训练模式和验证模式)。根据模式动态调整数据增强、填充比例和数据集使用比例。从配置对象 cfg 中提取多种参数,如图像大小、超参数、缓存设置等。多模态数据支持:除了图像路径,还支持 JSON 文件路径,可能用于存储标注信息或其他元数据。可扩展性:函数通过参数化设计,支持不同的数据集配置和模式。提供了对矩形批次和数据缓存的支持。用途:该函数主要用于 grounding 任务,可能涉及视觉和语言的结合,例如视觉问答(VQA)、视觉定位或视觉-语言对齐任务。通过灵活的参数配置,可以适应不同的训练需求,如单类别训练、矩形批次、数据缓存等。这种设计使得数据集的构建过程更加灵活和可配置,适合在多种场景下使用 grounding 模型进行多模态任务。

7.def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): 

# 这段代码定义了一个名为 build_dataloader 的函数,用于根据输入参数构建一个适合训练或验证的数据加载器( DataLoader )。函数的核心功能是根据环境配置和需求,动态调整数据加载器的参数,并支持分布式训练和无限循环加载数据。
# 定义了一个名为 build_dataloader 的函数,接收以下参数 :
# 1.dataset :数据集对象,用于加载数据。
# 2.batch :批量大小。
# 3.workers :工作线程数,用于并行加载数据。
# 4.shuffle :是否打乱数据,默认为 True 。
# 5.rank :分布式训练中的进程排名,默认为 -1 (表示非分布式训练)。
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
    # 返回用于训练或验证集的 InfiniteDataLoader 或 DataLoader。
    """Return an InfiniteDataLoader or DataLoader for training or validation set."""
    # 将 批量大小 batch 限制为数据集长度 len(dataset) 的最小值,避免批量大小超过数据集的总样本数。
    batch = min(batch, len(dataset))
    # 通过 torch.cuda.device_count() 获取可用的 CUDA 设备数量(GPU 数量),并存储在变量 nd 中。
    nd = torch.cuda.device_count()  # number of CUDA devices
    # 计算 实际使用的 工作线程数 nw 。
    # os.cpu_count() 获取系统的 CPU 核心数。
    # max(nd, 1) 确保至少有一个设备(GPU 或 CPU)。
    # os.cpu_count() // max(nd, 1) 计算每个设备分配的 CPU 核心数。
    # 最终, nw 的值是上述计算结果与用户指定的 workers 的最小值。
    nw = min(os.cpu_count() // max(nd, 1), workers)  # number of workers
    # 根据 rank 的值决定是否使用分布式采样器。 如果 rank == -1 (非分布式训练), sampler 为 None 。 否则,使用 torch.utils.data.distributed.DistributedSampler ,并根据 shuffle 参数决定是否打乱数据。
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    # 创建一个 PyTorch 的 随机数生成器 generator ,用于控制数据加载器的随机行为。
    generator = torch.Generator()
    # 为随机数生成器 设置种子值 ,确保随机行为的可复现性。 6148914691236517205 是一个固定的种子值。 RANK 是当前进程的排名(可能来自分布式训练环境),确保每个进程的随机行为独立。
    generator.manual_seed(6148914691236517205 + RANK)
    # 返回一个 InfiniteDataLoader 对象,用于支持无限循环的数据加载。
    return InfiniteDataLoader(
        # 将传入的 dataset 对象传递给 InfiniteDataLoader 。
        dataset=dataset,
        # 将 批量大小 batch 传递给 InfiniteDataLoader 。
        batch_size=batch,
        # 决定是否在数据加载器中启用数据打乱。 如果 shuffle=True 且 sampler=None (非分布式训练),则启用数据打乱。 否则,禁用数据打乱(分布式训练中由 DistributedSampler 控制)。
        shuffle=shuffle and sampler is None,
        # 将计算得到的 工作线程数 nw 传递给 InfiniteDataLoader 。
        num_workers=nw,
        # 将 采样器 sampler 传递给 InfiniteDataLoader 。
        sampler=sampler,
        # 将 PIN_MEMORY 参数传递给 InfiniteDataLoader ,控制 是否使用内存锁定 ( pin_memory )。 PIN_MEMORY 是一个全局变量,通常在代码其他地方定义。
        # PIN_MEMORY -> 用于控制数据加载器是否启用内存锁定(pin memory)。
        pin_memory=PIN_MEMORY,
        # 从数据集对象中获取 collate_fn 方法(如果存在)。 collate_fn 用于将多个样本组合成一个批次。 如果数据集没有定义 collate_fn ,则默认为 None 。
        collate_fn=getattr(dataset, "collate_fn", None),
        # 将 seed_worker 函数设置为工作线程的初始化函数,用于 为每个工作线程设置独立的随机种子 。
        worker_init_fn=seed_worker,
        # 将随机数生成器 generator 传递给 InfiniteDataLoader ,用于 控制数据加载的随机行为 。
        generator=generator,
    )
# build_dataloader 函数的作用是根据输入参数动态构建一个适合训练或验证的数据加载器( InfiniteDataLoader )。其主要功能包括。动态调整参数:根据数据集大小和可用设备数量(CPU/GPU)动态调整批量大小和工作线程数。支持分布式训练,通过 DistributedSampler 控制数据分发。随机行为控制:使用随机数生成器 generator 和种子值确保数据加载的随机行为可复现。为每个工作线程设置独立的随机种子,避免随机数生成冲突。支持无限循环加载:返回的是 InfiniteDataLoader ,支持无限循环的数据加载,适合需要重复使用数据集的训练场景。灵活配置:支持自定义 collate_fn 方法,用于处理数据批次的组合。根据是否启用分布式训练动态调整数据打乱行为。用途:该函数适用于深度学习任务中的数据加载,尤其是在训练和验证阶段。通过灵活的参数配置,可以适应不同的硬件环境和训练需求。这种设计使得数据加载器的构建过程更加灵活和可扩展,适合在多种场景下使用,尤其是在分布式训练和无限循环加载数据的需求中。

# 在深度学习中,内存锁定( pin_memory ) 是一个与数据加载和 GPU 加速相关的重要概念。它主要用于优化数据从 CPU 内存传输到 GPU 内存的过程。
# 内存锁定的作用 :内存锁定( pin_memory=True )的作用是将数据存储在**页锁定内存(Pinned Memory)**中,而不是普通的 CPU 内存。页锁定内存是一种特殊的内存区域,它不会被操作系统交换到磁盘(即不会被“交换”或“分页”),并且可以直接被 GPU 访问。
# 为什么需要内存锁定?
# 提高数据传输速度 :当数据存储在普通 CPU 内存中时,操作系统可能会将内存中的数据交换到磁盘,这会导致数据访问延迟增加。 而页锁定内存(Pinned Memory)不会被交换到磁盘,因此可以更快地被 GPU 访问。
# 零拷贝传输 :使用页锁定内存时,数据可以直接从 CPU 内存传输到 GPU 内存,而不需要中间的拷贝操作。这种传输方式被称为零拷贝(Zero-Copy)。 零拷贝可以显著减少数据传输的时间和资源消耗。
# 减少 CPU 和 GPU 之间的同步开销 :当数据存储在页锁定内存中时,GPU 可以更高效地从 CPU 内存中读取数据,从而减少 CPU 和 GPU 之间的同步等待时间。
# 在 PyTorch 中的应用 :
# 在 PyTorch 中, pin_memory 是 DataLoader 的一个参数。当设置 pin_memory=True 时, DataLoader 会将数据加载到页锁定内存中,从而加速数据从 CPU 到 GPU 的传输。
# from torch.utils.data import DataLoader
# dataloader = DataLoader(
#     dataset=my_dataset,
#     batch_size=32,
#     shuffle=True,
#     pin_memory=True,  # 启用内存锁定
#     num_workers=4
# )
# 注意事项 :
# 仅适用于 GPU 训练 :内存锁定( pin_memory )仅在使用 GPU 进行训练时有效。如果仅使用 CPU 训练,启用 pin_memory 不会有任何效果。
# 内存占用增加 :使用页锁定内存会增加 CPU 内存的占用量,因为它会锁定一部分内存,防止操作系统将其交换到磁盘。如果系统内存不足,可能会导致性能下降。
# 最佳实践 :在 GPU 资源充足且系统内存足够的情况下,启用 pin_memory 可以显著提高数据加载效率。 如果系统内存有限,建议谨慎使用 pin_memory ,或者减少 num_workers 的数量以降低内存占用。
# 总结 :内存锁定( pin_memory )是一种优化技术,通过将数据存储在页锁定内存中,可以加速数据从 CPU 内存传输到 GPU 内存的过程。它特别适用于 GPU 训练场景,能够显著提高数据加载效率,减少 CPU 和 GPU 之间的同步开销。然而,它也会增加内存占用,因此需要根据实际硬件配置和需求进行合理使用。

8.def check_source(source): 

# 这段代码定义了一个名为 check_source 的函数,用于检查和处理输入的图像或视频源,并根据输入类型返回相应的标志和处理后的数据源。
# 定义了一个函数 check_source ,接收一个参数。
# 1.source :表示输入的图像或视频源。
def check_source(source):
    # 检查源类型并返回相应的标志值。
    """Check source type and return corresponding flag values."""
    # 初始化了五个布尔变量,分别用于 标记输入源的类型 。
    # webcam :是否为摄像头输入。
    # screenshot :是否为屏幕截图。
    # from_img :是否为图像文件或图像数据。
    # in_memory :是否为内存中的数据(如加载器返回的数据)。
    # tensor :是否为 PyTorch 张量。
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    # 判断输入 source 是否为字符串、整数或 Path 对象。 字符串或 Path 对象通常表示文件路径或 URL。 整数通常表示本地 USB 摄像头的设备索引。
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        # 将 source 转换为字符串,以便后续处理。
        source = str(source)
        # 判断 source 是否为图像或视频文件。 Path(source).suffix[1:] 获取文件扩展名(去掉点号)。 IMG_FORMATS | VID_FORMATS 是一个集合,包含支持的图像和视频格式。 如果文件扩展名在支持的格式中,则 is_file 为 True 。
        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
        # 判断 source 是否为 URL(支持的协议包括 HTTP、HTTPS、RTSP、RTMP 和 TCP)。
        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
        # 判断 source 是否为摄像头输入。
        # 如果 source 是数字(表示摄像头索引),则 webcam 为 True 。
        # 如果 source 以 .streams 结尾(表示流媒体输入),则 webcam 为 True 。
        # 如果 source 是 URL 且不是文件,则 webcam 为 True 。
        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
        # 判断 source 是否为屏幕截图模式( source 为 "screen" )。
        screenshot = source.lower() == "screen"
        # 如果 source 是 URL 且是文件(如在线视频或图像文件),调用 check_file 函数下载文件。
        if is_url and is_file:
            # def check_file(file, suffix="", download=True, download_dir=".", hard=True):
            # -> 用于检查文件是否存在,如果不存在则尝试下载文件,并返回文件的路径。直接返回 file 。返回下载后的文件路径,将其转换为字符串形式。如果找到文件,返回第一个匹配的文件路径。 如果未找到文件,返回空列表 [] 。
            # -> return file / return str(file) / return files[0] if len(files) else []  # return file
            source = check_file(source)  # download
    # 判断 source 是否为 LOADERS 类型( LOADERS 是一个元组,包含加载器类)。 如果是,则标记为 in_memory ,表示 数据已在内存中 。
    elif isinstance(source, LOADERS):
        # 设置 in_memory 为 True 。
        in_memory = True
    # 判断 source 是否为列表或元组。
    elif isinstance(source, (list, tuple)):
        # 调用 autocast_list 函数,将列表中的所有元素转换为 PIL 图像对象或 NumPy 数组。
        # def autocast_list(source): -> 将输入的图像源列表转换为统一的图像对象格式,以便后续处理。函数返回 files 列表,其中包含 所有转换后的图像对象 。 -> return files
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        # 设置 from_img 为 True ,表示 输入为图像数据 。
        from_img = True
    # 判断 source 是否为 PIL 图像对象或 NumPy 数组。
    elif isinstance(source, (Image.Image, np.ndarray)):
        # 设置 from_img 为 True 。
        from_img = True
    # 判断 source 是否为 PyTorch 张量。
    elif isinstance(source, torch.Tensor):
        # 设置 tensor 为 True 。
        tensor = True
    # 如果 source 不属于上述任何类型,则进入 else 分支。
    else:
        # 抛出 TypeError 异常,提示用户输入类型不支持,并建议查阅相关文档。
        raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")    # 不支持的图像类型。有关支持的类型,请参阅 https://docs.ultralytics.com/modes/predict 。

    # 返回处理后的 source 和各个标志变量。
    return source, webcam, screenshot, from_img, in_memory, tensor
# check_source 函数的作用是检查输入的图像或视频源,并根据输入类型返回相应的标志和处理后的数据源。其主要功能包括。支持多种输入类型:文件路径(图像或视频)。URL(在线图像或视频)。摄像头输入(本地 USB 摄像头或流媒体)。屏幕截图模式。内存中的数据(如加载器返回的数据)。PIL 图像对象、NumPy 数组或 PyTorch 张量。自动处理输入:如果输入是 URL 且是文件,会调用 check_file 函数下载文件。如果输入是列表或元组,会调用 autocast_list 函数将元素转换为统一的图像格式。返回标志变量:根据输入类型返回布尔标志(如 webcam 、 screenshot 、 from_img 等),便于后续处理。异常处理:如果输入类型不支持,会抛出异常并提示用户。这种设计使得函数能够灵活处理多种输入源,适用于多种场景,如图像处理、视频流处理或实时摄像头输入。

9.def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): 

# 这段代码定义了一个名为 load_inference_source 的函数,用于根据输入的图像或视频源加载数据,并返回一个适合推理(inference)的数据集对象。
# 定义了一个函数 load_inference_source ,接收以下参数 :
# 1.source :输入的图像或视频源,默认为 None 。
# 2.batch :批量大小,默认为 1 。
# 3.vid_stride :视频帧的采样间隔,默认为 1 (即每帧都采样)。
# 4.buffer :是否使用缓冲区,用于流媒体输入,默认为 False 。
def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
    # 加载用于对象检测的推理源并应用必要的转换。
    """
    Loads an inference source for object detection and applies necessary transformations.

    Args:
        source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
        batch (int, optional): Batch size for dataloaders. Default is 1.
        vid_stride (int, optional): The frame interval for video sources. Default is 1.
        buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.

    Returns:
        dataset (Dataset): A dataset object for the specified input source.
    """
    # 调用 check_source 函数,对输入的 source 进行检查,并获取以下返回值 :
    # source :处理后的输入源。
    # stream :布尔值,表示 是否为流媒体输入 (如摄像头或视频流)。
    # screenshot :布尔值,表示 是否为屏幕截图 。
    # from_img :布尔值,表示 是否为图像文件或图像数据 。
    # in_memory :布尔值,表示 输入是否已经加载到内存中 。
    # tensor :布尔值,表示 输入是否为 PyTorch 张量 。
    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
    # 根据输入源的类型,确定 source_type 。
    # 如果 in_memory 为 True ,则直接从 source 中获取 source_type 。
    # 否则,通过 SourceTypes 枚举类(或类似结构)根据 stream 、 screenshot 、 from_img 和 tensor 的值确定 source_type 。
    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)

    # Dataloader    这是一个注释,表示接下来的代码用于选择合适的 数据加载器 ( dataset )。
    # 如果输入是 PyTorch 张量( tensor=True ),则使用 LoadTensor 类加载数据。
    if tensor:
        dataset = LoadTensor(source)
    # 如果输入已经加载到内存中( in_memory=True ),则直接将 source 作为数据集对象。
    elif in_memory:
        dataset = source
    # 如果输入是流媒体( stream=True ),则使用 LoadStreams 类加载数据,并根据 vid_stride 和 buffer 参数配置视频帧的采样间隔和缓冲区。
    elif stream:
        dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
    # 如果输入是屏幕截图( screenshot=True ),则使用 LoadScreenshots 类加载数据。
    elif screenshot:
        dataset = LoadScreenshots(source)
    # 如果输入是图像文件或图像数据( from_img=True ),则使用 LoadPilAndNumpy 类加载数据。
    elif from_img:
        dataset = LoadPilAndNumpy(source)
    # 如果输入是其他类型的图像或视频文件,则使用 LoadImagesAndVideos 类加载数据,并根据 batch 和 vid_stride 参数配置批量大小和视频帧的采样间隔。
    else:
        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)

    # Attach source types to the dataset    这是一个注释,表示接下来的代码将 source_type 附加到数据集对象上。
    # 使用 setattr 函数,将 source_type 作为属性附加到 dataset 对象上,以便后续使用。
    setattr(dataset, "source_type", source_type)

    # 返回构建好的数据集对象 dataset 。
    return dataset
# load_inference_source 函数的作用是根据输入的图像或视频源加载数据,并返回一个适合推理(inference)的数据集对象。其主要功能包括。输入源检查:调用 check_source 函数,对输入源进行检查,并获取输入源的类型标志(如是否为流媒体、屏幕截图、图像文件等)。动态选择数据加载器:根据输入源的类型,选择合适的数据加载器类(如 LoadTensor 、 LoadStreams 、 LoadScreenshots 、 LoadPilAndNumpy 或 LoadImagesAndVideos )。配置数据加载器的参数(如批量大小、视频帧采样间隔、缓冲区等)。附加输入源类型:将输入源的类型( source_type )附加到数据集对象上,便于后续处理。灵活性和扩展性:支持多种输入源类型,包括图像文件、视频文件、流媒体、屏幕截图和 PyTorch 张量。支持自定义批量大小和视频帧采样间隔。用途:该函数适用于推理阶段,用于加载和处理输入数据,使其适合模型推理。通过动态选择数据加载器,可以适应多种输入场景,如实时摄像头输入、屏幕截图或本地文件。这种设计使得函数能够灵活处理多种输入源,适用于多种推理场景,如目标检测、图像分类或视频分析。


http://www.kler.cn/a/555474.html

相关文章:

  • Linux 内核是如何检测可用物理内存地址范围的?
  • Three.js 快速入门教程【三】渲染器
  • kubernetes1.28部署mysql5.7主从同步,使用Nfs制作持久卷存储,适用于centos7/9操作系统,
  • Deepseek 与 ChatGPT:AI 浪潮中的双子星较量
  • JavaScript 开发秘籍:日常总结与实战技巧-1
  • postgresql实时同步数据表mysql
  • HttpSession类的对象session:保存的数据谁有权限读取?
  • 面试基础-如何设计一个短链接系统
  • 使用 Docker-compose 部署 MySQL
  • Openai Dashboard可视化微调大语言模型
  • C++游戏开发流程图
  • idea从远程gitee拉取项目
  • SVN服务器搭建【Linux】
  • Node os模块
  • Android开发-深入解析Android中的AIDL及其应用场景
  • SpringCloud系列教程:微服务的未来(二十四)Direct交换机、Topic交换机、声明队列交换机
  • 蓝桥杯备赛 Day15 动态规划
  • STM32 HAL库UART串口数据接收实验
  • Golang访问Google Sheet
  • Java 中的内存泄漏问题及解决方案