当前位置: 首页 > article >正文

YOLOv8-ultralytics-8.2.103部分代码阅读笔记-build.py

build.py

ultralytics\data\build.py

目录

build.py

1.所需的库和模块

2.class InfiniteDataLoader(dataloader.DataLoader): 

3.class _RepeatSampler: 

4.def seed_worker(worker_id): 

5.def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): 

6.def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): 

7.def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): 

8.def check_source(source): 

9.def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): 


1.所需的库和模块

# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import dataloader, distributed

from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
from ultralytics.data.loaders import (
    LOADERS,
    LoadImagesAndVideos,
    LoadPilAndNumpy,
    LoadScreenshots,
    LoadStreams,
    LoadTensor,
    SourceTypes,
    autocast_list,
)
from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file

2.class InfiniteDataLoader(dataloader.DataLoader): 

# 这段代码定义了一个名为 InfiniteDataLoader 的类,它继承自 PyTorch 的 dataloader.DataLoader 类。这个类的目的是创建一个无限循环的数据加载器,即使数据集的大小是有限的,它也能够无限地产生数据批次。
class InfiniteDataLoader(dataloader.DataLoader):
    # 重用工作器的数据加载器。
    # 使用与原始数据加载器相同的语法。
    """
    Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    # __init__ 方法。
    def __init__(self, *args, **kwargs):
        # 无限循环使用 worker 的 Dataloader,继承自 DataLoader。
        """Dataloader that infinitely recycles workers, inherits from DataLoader."""
        # 调用父类的构造函数,传递所有传入的参数。
        super().__init__(*args, **kwargs)
        # 将 batch_sampler 属性替换为 _RepeatSampler 类的一个实例,这个实例会包装原始的 batch_sampler 。 _RepeatSampler 是一个自定义的采样器,用于实现无限循环的批次生成。
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        # 调用父类的 __iter__ 方法来初始化迭代器,并将其赋值给实例变量 iterator 。
        self.iterator = super().__iter__()

    # __len__ 方法。
    def __len__(self):
        # 返回批量采样器的采样器的长度。
        """Returns the length of the batch sampler's sampler."""
        # 返回 batch_sampler.sampler 的长度,这通常是数据集的大小。
        return len(self.batch_sampler.sampler)

    # 这段代码是 InfiniteDataLoader 类中的 __iter__ 方法的实现。这个方法使得 InfiniteDataLoader 可以被迭代,并且会无限次地产生数据批次。
    def __iter__(self):
        # 创建一个无限重复的采样器。
        """Creates a sampler that repeats indefinitely."""
        # 这行代码创建了一个无限循环,因为 __len__ 方法返回的是 self.batch_sampler.sampler 的长度,
        # 而在 InfiniteDataLoader 中, self.batch_sampler 被替换成了 _RepeatSampler 实例,它应该能够无限地产生索引,使得 len(self) 实际上是无限的。
        for _ in range(len(self)):

            # yield
            # yield 是 Python 中的一个关键字,它用于在函数中创建一个生成器(generator)。当一个函数中包含 yield 语句时,这个函数就变成了一个生成器函数,它允许你逐个产生函数的值,而不是一次性返回所有值。
            # 与 return 的区别 :
            # return 语句用于从函数返回一个值,并结束函数的执行。
            # yield 语句用于从生成器函数返回一个值,但不结束函数的执行。函数的状态被保存,以便下一次从生成器请求值时继续执行。
            # yield 的这些特性使得它在需要迭代处理数据时非常有用,特别是在数据量大或数据生成成本高的情况下。

            # 在循环内部, yield 关键字用于产生 self.iterator 的下一个元素。由于 self.iterator 是通过调用父类的 __iter__ 方法获得的,它指向一个可以产生数据批次的迭代器。
            # next(self.iterator) 会获取下一个数据批次。
            # yield 关键字使得每次产生一个批次后,迭代器都会暂停,直到下一次请求下一个元素。这样可以在每次迭代中提供一个批次,而不一次性加载所有批次到内存中。
            yield next(self.iterator)
    # 这个方法的结果是, InfiniteDataLoader 可以作为一个无限循环的数据源,不断地产生数据批次,直到程序停止请求新的批次。这对于训练机器学习模型非常有用,因为它允许模型在有限的数据集上进行多次迭代,而不需要手动重置数据加载器。
    # 需要注意的是,这种方法可能会导致无限循环,如果不当使用可能会导致程序无法停止。因此,在使用 InfiniteDataLoader 时,应该确保在训练循环中有适当的停止条件。

    # reset 方法。
    def reset(self):
        # 重置迭代器。
        # 当我们想要在训练时修改数据集的设置时,这很有用。
        """
        Reset iterator.

        This is useful when we want to modify settings of dataset while training.
        """
        # 重置迭代器为初始状态,这可能是为了在某些情况下重新开始迭代。
        self.iterator = self._get_iterator()
# 使用 InfiniteDataLoader 的好处是,它允许模型在训练时不断地遍历数据集,而不需要手动重置数据加载器。这在某些情况下可以简化代码,特别是在需要长时间训练模型的场景中。

3.class _RepeatSampler: 

# 这个 _RepeatSampler 类是一个自定义的采样器,它的作用是将传入的 sampler 包装起来,使其能够无限次地产生样本索引。
class _RepeatSampler:
    # 永远重复的采样器。
    """
    Sampler that repeats forever.

    Args:
        sampler (Dataset.sampler): The sampler to repeat.
    """

    # __init__ 方法。
    def __init__(self, sampler):
        # 初始化一个无限重复给定采样器的对象。
        """Initializes an object that repeats a given sampler indefinitely."""
        # 接受一个 sampler 对象作为参数,并将其存储在实例变量 self.sampler 中。这个 sampler 通常是 PyTorch 的 DataLoader 中用于生成批次索引的采样器。
        self.sampler = sampler

    # __iter__ 方法。这个方法使得 _RepeatSampler 可以被迭代,并且会无限次地产生样本索引。
    def __iter__(self):
        # 迭代“采样器”并产生其内容。
        """Iterates over the 'sampler' and yields its contents."""
        # 创建了一个无限循环,这意味着采样器会不断地重复产生索引,直到程序显式地停止它。
        while True:
            # 会迭代 self.sampler 并产生其所有的元素,然后 yield 关键字会将这些元素返回给迭代 _RepeatSampler 的代码。
            # 当 self.sampler 的所有元素都被产生后, iter(self.sampler) 会自然结束,但由于外层有一个无限循环,所以会再次从头开始产生 self.sampler 的元素,从而实现无限重复。
            yield from iter(self.sampler)
# 这个采样器的实现非常简洁,但它非常强大,因为它允许 DataLoader 在有限的数据集上进行无限次迭代,这对于训练深度学习模型时进行多次 epoch 迭代非常有用。
# 使用 _RepeatSampler 的一个潜在问题是,如果不正确地管理迭代过程,可能会导致无限循环。因此,在使用 InfiniteDataLoader 和 _RepeatSampler 时,需要确保在训练过程中有适当的机制来控制迭代次数,例如通过设置一个最大迭代次数或者监听某些停止信号。

4.def seed_worker(worker_id): 

# seed_worker 函数是一个用于设置 PyTorch DataLoader 工作进程随机种子的函数。这个函数的目的是为了确保在多进程数据加载时,每个工作进程能够生成可复现的随机数序列。这对于确保实验的可重复性非常重要,尤其是在使用随机数据增强或其他需要随机性的操作时。
def seed_worker(worker_id):  # noqa
    # 设置数据加载器工作者种子https://pytorch.org/docs/stable/notes/randomness.html#dataloader。
    """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
    # 这行代码获取 PyTorch 的全局随机种子,并将其对 2**32 取模,以确保得到的种子是一个32位的整数。 torch.initial_seed() 返回的是当前进程的随机种子,这个种子是由 PyTorch 的随机数生成器生成的。
    worker_seed = torch.initial_seed() % 2**32
    # 这行代码设置 NumPy 的随机种子。NumPy 也用于生成随机数,特别是在数据预处理和增强时,因此确保它的随机性与其他库一致是很重要的。
    np.random.seed(worker_seed)
    # 这行代码设置 Python 标准库 random 模块的随机种子。这同样是为了确保在使用 random 模块时生成的随机数序列是可复现的。
    random.seed(worker_seed)
# 这个函数通常作为 DataLoader 的 worker_init_fn 参数传递,这样每个工作进程在开始工作之前都会调用这个函数来设置随机种子。例如:
# data_loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)
# 这样设置后,每个工作进程都会在开始工作前调用 seed_worker 函数,确保每个进程的随机数生成是可复现的。

5.def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): 

# 这段代码定义了一个名为 build_yolo_dataset 的函数,它用于构建 YOLO(You Only Look Once)目标检测模型的数据集。这个函数根据提供的配置和参数,初始化并返回一个 YOLODataset 或 YOLOMultiModalDataset 实例。
# 1.cfg : 配置对象,包含了构建数据集所需的配置参数。
# 2.img_path : 图像文件的路径。
# 3.batch : 批处理大小,即每次迭代处理的图像数量。
# 4.data : 额外的数据,可能用于数据增强或其他目的。
# 5.mode : 模式,可以是"train"(训练)或其它值(验证/测试)。
# 6.rect : 是否使用矩形批次,这通常与批处理中的图像尺寸有关。
# 7.stride : 步长,用于确定网格单元的大小。
# 8.multi_modal : 是否是多模态数据集,如果是,将使用 YOLOMultiModalDataset 类,否则使用 YOLODataset 类。
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
    # 构建 YOLO 数据集。
    """Build YOLO Dataset."""
    # 根据 multi_modal 参数的值,确定使用 YOLOMultiModalDataset 还是 YOLODataset 类。
    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
    # 创建数据集实例,传入必要的参数,包括图像路径、图像尺寸、批处理大小、是否进行数据增强等。
    # 返回创建的数据集实例。
    return dataset(
        # img_path : 图像文件的路径,这是数据集将要加载图像的地方。
        img_path=img_path,
        #  imgsz : 图像尺寸,这是YOLO模型输入图像的尺寸,通常是一个整数或一个元组,表示图像的宽度和高度。
        imgsz=cfg.imgsz,
        # batch_size : 批处理大小,这个参数定义了每次迭代中处理的图像数量。
        batch_size=batch,
        # augment : 数据增强标志,如果模式是"train"(训练模式),则为True,表示在训练时应用数据增强技术。
        augment=mode == "train",  # augmentation
        # hyp : 超参数,这里直接传递了配置对象 cfg ,包含学习率、迭代次数等超参数。
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        # rect : 是否使用矩形批次,这通常与批处理中的图像尺寸有关,如果 cfg.rect 为True或者函数参数 rect 为True,则使用矩形批次。
        rect=cfg.rect or rect,  # rectangular batches
        # cache : 缓存设置,如果 cfg.cache 有值,则使用该值,否则为None,这可能用于控制是否缓存预处理后的图像。
        cache=cfg.cache or None,
        # single_cls : 是否是单类别检测,如果为True,则数据集中只包含一个类别。
        single_cls=cfg.single_cls or False,
        # stride : 步长,转换为整数,这个参数影响YOLO模型中的网格单元大小。
        stride=int(stride),
        # pad : 填充比例,如果是训练模式,则为0.0,否则为0.5,这个参数可能用于控制训练时的图像填充。
        pad=0.0 if mode == "train" else 0.5,
        # prefix : 日志前缀,使用 colorstr 函数格式化模式名称,这可能用于在日志输出中区分不同的模式。
        prefix=colorstr(f"{mode}: "),
        # task : 任务类型,从 cfg 中获取,可能用于区分不同的检测任务。
        task=cfg.task,
        #  classes  : 类别数量,从  cfg  中获取,表示数据集中包含的类别数量。
        classes=cfg.classes,
        # data : 额外的数据,可能用于数据增强或其他目的。
        data=data,
        # fraction : 数据集分割比例,如果是训练模式,则使用  cfg.fraction  的值,否则为1.0,这可能用于控制训练集的大小。
        fraction=cfg.fraction if mode == "train" else 1.0,
    )
# 这个函数是一个高级抽象,它允许用户通过配置文件来灵活地创建和定制YOLO数据集,以适应不同的训练和验证需求。

6.def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): 

# 这个函数 build_grounding 用于构建一个特定于“grounding”任务的YOLO数据集。Grounding通常指的是将语言描述与视觉实体联系起来的过程,这在多模态学习、视觉问答等领域中是一个常见的任务。
# 这个函数的参数和 build_yolo_dataset 函数类似,但是它创建的是 GroundingDataset 实例,这是一个专门为grounding任务定制的数据集类。
# 1.cfg : 配置对象,包含了构建数据集所需的配置参数。
# 2.img_path : 图像文件的路径,这是数据集将要加载图像的地方。
# 3.json_file : JSON文件的路径,通常包含与图像相关的标注信息,如bounding boxes、类别标签等。
# 4.batch : 批处理大小,即每次迭代处理的图像数量。
# 5.mode : 模式,可以是"train"(训练)或其它值(验证/测试)。
# 6.rect : 是否使用矩形批次,这通常与批处理中的图像尺寸有关。
# 7.stride : 步长,用于确定网格单元的大小。
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
    # 构建 YOLO 数据集。
    """Build YOLO Dataset."""
    # 函数返回 GroundingDataset 实例。
    return GroundingDataset(
        # 图像文件的路径。
        img_path=img_path,
        # JSON文件的路径,包含标注信息。
        json_file=json_file,
        # 图像尺寸,这是YOLO模型输入图像的尺寸。
        imgsz=cfg.imgsz,
        # 批处理大小。
        batch_size=batch,
        # 数据增强标志,如果是训练模式,则为True。
        augment=mode == "train",  # augmentation
        # 超参数,这里直接传递了配置对象 cfg 。
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        # 是否使用矩形批次。
        rect=cfg.rect or rect,  # rectangular batches
        # 缓存设置,如果 cfg.cache 有值,则使用该值,否则为None。
        cache=cfg.cache or None,
        # 是否是单类别检测。
        single_cls=cfg.single_cls or False,
        # 步长,转换为整数。
        stride=int(stride),
        # 填充比例,如果是训练模式,则为0.0,否则为0.5。
        pad=0.0 if mode == "train" else 0.5,
        # 日志前缀,使用 colorstr 函数格式化模式名称。
        # def colorstr(*input):
        # -> 返回值。函数通过遍历 args 中的每个元素(颜色或样式),从 colors 字典中获取对应的ANSI转义序列,并将其与传入的 string 字符串连接起来。最后,它还会添加一个 colors["end"] 序列,用于重置终端的颜色和样式到默认状态。
        # -> return "".join(colors[x] for x in args) + f"{string}" + colors["end"] 
        prefix=colorstr(f"{mode}: "),
        # 任务类型,从 cfg 中获取。
        task=cfg.task,
        # 类别数量,从 cfg 中获取。
        classes=cfg.classes,
        # 数据集分割比例,如果是训练模式,则使用 cfg.fraction 的值,否则为1.0。
        fraction=cfg.fraction if mode == "train" else 1.0,
    )
# 这个函数的目的是创建一个配置好的 GroundingDataset 对象,该对象可以在grounding任务的训练或验证过程中使用。通过传递不同的参数,可以灵活地调整数据集的行为,以适应不同的训练需求。

7.def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): 

# 这段代码是一个 Python 函数,它定义了一个名为 build_dataloader 的函数,用于构建 PyTorch 的 DataLoader 或 InfiniteDataLoader 对象。这个函数通常用于机器学习或深度学习中,以便在训练或验证模型时加载数据集。
# 定义了一个函数 build_dataloader ,它接受五个参数。
# 1.dataset :数据集对象。
# 2.batch :批次大小。
# 3.workers :加载数据的工作线程数。
# 4.shuffle :是否在每个epoch开始时打乱数据,默认为True。
# 5.rank :分布式训练中的进程排名,默认为-1,表示非分布式训练。
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
    # 返回用于训练或验证集的 InfiniteDataLoader 或 DataLoader。
    """Return an InfiniteDataLoader or DataLoader for training or validation set."""
    # 确保批次大小不会超过数据集的大小。
    batch = min(batch, len(dataset))
    # 获取当前系统中可用的 CUDA 设备数量。
    nd = torch.cuda.device_count()  # number of CUDA devices
    # 计算可用于数据加载的工作线程数,这个数字是 CPU 核心数 除以 CUDA 设备数 (或1,如果CUDA设备数为0)和传入的 workers 参数的最小值。
    nw = min(os.cpu_count() // max(nd, 1), workers)  # number of workers
    # 如果不是分布式训练( rank 为 -1),则不使用 sampler ;如果是分布式训练,则创建一个 DistributedSampler 对象,用于确保每个进程只处理数据集的一部分。
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)

    # torch.Generator(device='cpu')
    # 在PyTorch中, torch.Generator 是一个用于生成随机数的伪随机数生成器(PRNG)的类。它主要用于生成与特定种子或设备相关的随机数,以确保实验的可重复性。
    # 参数 :
    # device : 指定生成器所在的设备,可以是'cpu'或'cuda'设备。
    # 主要方法 :
    # manual_seed(seed) : 设置生成器的种子。 seed : 一个整数,用于初始化生成器。
    # seed() : 自动设置生成器的种子。 这将生成一个随机种子,确保每次运行代码时生成的随机数不同。
    # get_state() : 返回生成器的当前状态。
    # set_state(state) : 设置生成器的状态。 state : 一个张量,表示生成器的状态。
    # initial_seed() : 返回生成器的初始种子。
    # 注意事项 :
    # 当使用多个设备(如CPU和GPU)时,需要为每个设备创建一个独立的 Generator 实例,并设置不同的种子,以确保随机数的生成是独立的。
    # 在分布式训练或多线程环境中,正确管理生成器的状态非常重要,以避免随机数生成的冲突。
    # torch.Generator 提供了一种灵活的方式来控制随机数的生成,使得实验和模型训练更加可重复和可控。

    # 创建一个 PyTorch 随机数生成器对象。
    generator = torch.Generator()
    # 为随机数生成器设置种子值,这里 RANK 应该是一个变量,但在代码中没有定义,可能是一个外部定义的变量,用于确保在分布式训练中每个进程的随机性是不同的。
    generator.manual_seed(6148914691236517205 + RANK)
    # 返回一个 InfiniteDataLoader 对象,这个对象是 DataLoader 的一个变体,用于无限循环地加载数据集。
    return InfiniteDataLoader(
        # 数据集对象。
        dataset=dataset,
        # 批次大小。
        batch_size=batch,
        # 是否打乱数据。
        shuffle=shuffle and sampler is None,
        # 工作线程数。
        num_workers=nw,
        # 采样器对象。
        sampler=sampler,
        # 一个布尔值,表示是否将数据加载到 CUDA 固定内存中,以加快数据传输到 GPU 的速度。
        pin_memory=PIN_MEMORY,
        # 一个函数,用于决定如何将多个样本数据合并成一个批次。
        collate_fn=getattr(dataset, "collate_fn", None),
        # 每个工作线程启动时调用的函数,这里使用 seed_worker 函数来设置每个工作线程的随机种子。
        worker_init_fn=seed_worker,
        # 上面创建的随机数生成器对象。
        generator=generator,
    )
# 这个函数的设计目的是为了提供一个灵活且高效的数据加载器,支持批量加载、多线程和分布式训练。通过设置随机数生成器和采样器,它可以确保数据加载的随机性和分布式训练的正确性。

8.def check_source(source): 

# 这段代码定义了一个名为 check_source 的函数,其目的是检查输入的 source 参数,并返回与 source 类型相对应的标志值。这些标志值用于确定 source 是来自网络摄像头、屏幕截图、图像文件、内存中的图像、还是 PyTorch 张量。
# 参数。
# 1.source :输入源,可以是多种类型,包括字符串、整数、Path对象、图像文件、URL、列表、元组、PIL图像、NumPy数组或PyTorch张量。
def check_source(source):
    # 检查源类型并返回相应的标志值。
    """Check source type and return corresponding flag values."""
    # 定义了五个布尔变量,初始值都为 False ,用于标记 source 的类型。 分别表示是否为 摄像头 、 屏幕截图 、 图像 、 内存中的数据 和 张量。
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    # 使用 isinstance 检查 source 是否为 字符串 、 整数 或 Path 对象 。如果是,将其转换为字符串,并进行进一步的检查。
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        # is_file 检查路径后缀是否在支持的图像或视频格式中。
        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
        # is_url 检查字符串是否以特定的URL协议开头。
        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))

        # str.isnumeric()
        # 在Python中, isnumeric() 是字符串( str )类的一个方法,用于检查字符串中的所有字符是否都是数字。如果所有字符都是数字,则返回 True ;否则返回 False 。这个方法不考虑字符的编码,只关注字符是否表示数字。
        # 返回值 :
        # bool 类型,表示字符串是否只包含数字。
        # 注意事项 :
        # isnumeric() 方法只检查字符是否为数字,不考虑字符的编码或数值大小。
        # 空字符串( "" )被认为是数字,这可能与直觉不符,但符合Unicode标准。
        # 该方法不考虑数字的格式,如负号或小数点,这些都会被视作非数字字符。
        # isnumeric() 方法常用于验证输入是否为纯数字字符串,例如在处理用户输入或解析数据时。

        # webcam 标记为 True 如果 source 是数字(代表本地USB摄像头)、以特定后缀结束或是一个URL但不是文件。
        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
        # screenshot 标记为 True 如果 source 字符串等于 "screen" 。
        screenshot = source.lower() == "screen"
        # 如果 source 是一个URL且是一个文件,调用 check_file 函数 。
        if is_url and is_file:
            # def check_file(file, suffix="", download=True, download_dir=".", hard=True):
            # -> 检查文件的存在性,如果需要的话,下载该文件,并返回文件的路径。
            # ->  return file / return str(file) / return files[0] if len(files) else []  # return file
            source = check_file(source)  # download
    # 如果 source 是 LOADERS 类型,则 in_memory 标记为 True 。
    elif isinstance(source, LOADERS):
        in_memory = True
    # 如果 source 是一个列表或元组,调用 autocast_list 函数,并标记 from_img 为 True 。
    elif isinstance(source, (list, tuple)):
        # def autocast_list(source): -> 用于将不同类型源的数据合并成一个包含 NumPy 数组或 PIL 图像的列表。返回包含处理后的图像数据的 files 列表。 -> return files
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        from_img = True
    # 如果 source 是一个 Image.Image 对象或NumPy数组,标记 from_img 为 True 。
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
    # 如果 source 是一个PyTorch张量,标记 tensor 为 True 。
    elif isinstance(source, torch.Tensor):
        tensor = True
    # 如果 source 不是上述任何类型,抛出一个 TypeError 异常。
    else:
        raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")    # 不支持的图像类型。有关支持的类型,请参阅 https://docs.ultralytics.com/modes/predict.

    # 函数返回处理后的  source  和 webcam (摄像头) , screenshot (屏幕截图) , from_img (图像) , in_memory (内存中的数据) , tensor (张量) 这五个布尔值。
    return source, webcam, screenshot, from_img, in_memory, tensor
# 这个函数的目的是为进一步的处理确定输入源的类型,并根据源的类型设置相应的标志值。这样,后续的代码就可以根据这些标志值来决定如何处理输入源。

9.def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): 

# 这段代码定义了一个名为 load_inference_source 的函数,它的作用是根据提供的 source 参数加载不同类型的数据源,并返回一个相应的 dataset 对象。这个函数支持多种数据源类型,包括张量、内存中的数据、视频流、屏幕截图、图像和视频文件。
# 参数说明。
# 1.source : 数据源,可以是文件路径、URL、图像或视频流、内存中的张量等。
# 2.batch : 用于加载图像和视频时的批量大小,默认为1。
# 3.vid_stride : 视频加载时的帧间隔,默认为1。
# 4.buffer : 是否对视频流进行缓冲,默认为False。
def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
    # 加载用于对象检测的推理源并应用必要的转换。
    """
    Loads an inference source for object detection and applies necessary transformations.

    Args:
        source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
        batch (int, optional): Batch size for dataloaders. Default is 1.
        vid_stride (int, optional): The frame interval for video sources. Default is 1.
        buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.

    Returns:
        dataset (Dataset): A dataset object for the specified input source.
    """
    # 检查数据源类型。
    # 调用 check_source 函数来确定 source 的类型,并返回 source 本身以及五个布尔值,分别表示是否为 摄像头 stream 、 屏幕截图 screenshot 、 图像 from_img 、 内存中的数据 in_memory 和 张量 tensor 。
    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
    # 确定数据源类型。
    # 根据 in_memory 的值,确定 source_type 。如果 in_memory 为True,则 source_type 直接从 source 对象中获取;否则,使用 SourceTypes 类来创建 source_type 。
    # class SourceTypes:
    # -> 这个类可以用于配置预测模型的输入源,或者在需要区分不同输入类型的场景中使用。
    # -> SourceTypes(stream=False, screenshot=False, from_img=False, tensor=False)
    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)

    # Dataloader
    # 加载数据集。
    # 如果 tensor 为True,则使用 LoadTensor 类来加载张量数据。
    if tensor:
        # class LoadTensor: -> 用于处理图像数据,特别是以张量(Tensor)形式存在的图像数据。 -> def __init__(self, im0) -> None:
        dataset = LoadTensor(source)
    # 如果 in_memory 为True,则直接使用 source 作为数据集。
    elif in_memory:
        dataset = source
    # 如果 stream 为True,则使用 LoadStreams 类来加载视频流数据。
    elif stream:
        # class LoadStreams: -> 用于加载和处理视频流。 -> def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
        dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
    # 如果 screenshot 为True,则使用 LoadScreenshots 类来加载屏幕截图数据。
    elif screenshot:
        # class LoadScreenshots: -> 用于捕获屏幕截图并将屏幕内容作为 NumPy 数组返回。 -> def __init__(self, source):
        dataset = LoadScreenshots(source)
    # 如果 from_img 为True,则使用 LoadPilAndNumpy 类来加载PIL图像和NumPy数组数据。
    elif from_img:
        # class LoadPilAndNumpy: -> 用于处理 PIL 图像和 NumPy 数组格式的图像数据。 -> def __init__(self, im0):
        dataset = LoadPilAndNumpy(source)
    # 否则,使用 LoadImagesAndVideos 类来加载图像和视频文件。
    else:
        # class LoadImagesAndVideos: -> 用于加载图像和视频文件,并初始化一个数据加载器。 -> def __init__(self, path, batch=1, vid_stride=1):
        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)

    # Attach source types to the dataset

    # setattr(object, name, value)
    # setattr 是 Python 内置的一个函数,用于将属性赋值给对象。这个函数可以用来动态地设置对象的属性值,包括那些在代码运行时才知道名称的属性。
    # object :要设置属性的对象。
    # name :要设置的属性的名称,它应该是一个字符串。
    # value :要赋给属性的值。
    # 功能 :
    # setattr 函数将 value 赋给 object 的 name 指定的属性。如果 name 指定的属性在 object 中不存在,则会创建一个新的属性。返回值 setattr 函数没有返回值。
    # 注意事项 :
    # 使用 setattr 时需要注意属性名称的字符串格式,因为属性名称会被直接用作对象的属性键。
    # setattr 可以用于任何对象,包括自定义类的实例、内置类型的对象等。
    # 如果需要删除对象的属性,可以使用 delattr 函数,其用法与 setattr 类似,但是用于删除属性而不是设置属性。

    # 附加数据源类型。将 source_type 附加到 dataset 对象上。
    setattr(dataset, "source_type", source_type)

    # 返回数据集。返回构建好的 dataset 对象。
    return dataset
# 这个函数的设计体现了模块化和灵活性,可以根据不同的数据源类型选择合适的加载方式,并确保数据源类型信息的传递和存储。


http://www.kler.cn/a/428198.html

相关文章:

  • 【如何制定虚拟货币的补仓策略并计算回本和盈利】
  • Linux网络编程之---组播和广播
  • 快速排序的基本思想和java实现
  • Next.js系统性教学:全面掌握客服务端组件(Server Components)
  • ARMv8-A MacOS调试环境搭建
  • Python毕业设计选题:基于大数据的淘宝电子产品数据分析的设计与实现-django+spark+spider
  • PyCharm文件、临时文件、目录、文件夹(Directory)、软件包(Package)的区别
  • Spring Boot配置文件敏感信息加密
  • 智创 AI 新视界 -- AI 与量子计算的未来融合前景(16 - 5)
  • python拆分Excel文件
  • docker安装ddns-go(外网连接局域网)
  • JVM 参数配置详细介绍
  • C++ Learning 函数重载•引用
  • PyTorch基本使用——张量的索引操作
  • opencv光流法推测物体的运动
  • Spring Boot日志:从Logger到@Slf4j的探秘
  • ChatGPT 最新推出的 Pro 订阅计划,具备哪些能力 ?
  • uniapp 微信小程序webview 和 h5数据通信
  • 【AWS re:Invent 2024】一文了解EKS新功能:Amazon EKS Auto Mode
  • Python实现BBS论坛自动签到【steamtools论坛】