YOLOv9-0.1部分代码阅读笔记-general.py
general.py
utils\general.py
目录
general.py
1.所需的库和模块
2.def is_ascii(s=''):
3.def is_chinese(s='人工智能'):
4.def is_colab():
5.def is_notebook():
6.def is_kaggle():
7.def is_docker() -> bool:
8.def is_writeable(dir, test=False):
9.def set_logging(name=LOGGING_NAME, verbose=True):
10.def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
11.class Profile(contextlib.ContextDecorator):
12.class Timeout(contextlib.ContextDecorator):
13.class WorkingDirectory(contextlib.ContextDecorator):
14.def methods(instance):
15.def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
16.def init_seeds(seed=0, deterministic=False):
17.def intersect_dicts(da, db, exclude=()):
18.def get_default_args(func):
19.def get_latest_run(search_dir='.'):
20.def file_age(path=__file__):
21.def file_date(path=__file__):
22.def file_size(path):
23.def check_online():
24.def git_describe(path=ROOT):
25.def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
26.def check_git_info(path='.'):
27.def check_python(minimum='3.7.0'):
28.def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
29.def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
30.def check_img_size(imgsz, s=32, floor=0):
31.def check_imshow(warn=False):
32.def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
33.def check_yaml(file, suffix=('.yaml', '.yml')):
34.def check_file(file, suffix=''):
35.def check_font(font=FONT, progress=False):
36.def check_dataset(data, autodownload=True):
37.def check_amp(model):
38.def yaml_load(file='data.yaml'):
39.def yaml_save(file='data.yaml', data={}):
40.def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
41.def url2file(url):
42.def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
43.def make_divisible(x, divisor):
44.def clean_str(s):
45.def one_cycle(y1=0.0, y2=1.0, steps=100):
46.def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
47.def colorstr(*input):
48.def labels_to_class_weights(labels, nc=80):
49.def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
50.def coco80_to_coco91_class():
51.def xyxy2xywh(x):
52.def xywh2xyxy(x):
53.def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
54.def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
55.def xyn2xy(x, w=640, h=640, padw=0, padh=0):
56.def segment2box(segment, width=640, height=640):
57.def segments2boxes(segments):
58.def resample_segments(segments, n=1000):
59.def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
60.def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
61.def clip_boxes(boxes, shape):
62.def clip_segments(segments, shape):
63.def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=(), max_det=300, nm=0,):
64.def strip_optimizer(f='best.pt', s=''):
65.def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
66.def apply_classifier(x, model, img, im0):
67.def increment_path(path, exist_ok=False, sep='', mkdir=False):
68.def imread(path, flags=cv2.IMREAD_COLOR):
69.def imwrite(path, im):
70.def imshow(path, im):
1.所需的库和模块
import contextlib
import glob
import inspect
import logging
import logging.config
import math
import os
import platform
import random
import re
import signal
import sys
import time
import urllib
from copy import deepcopy
from datetime import datetime
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from tarfile import is_tarfile
from typing import Optional
from zipfile import ZipFile, is_zipfile
import cv2
import IPython
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml
from utils import TryExcept, emojis
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness
# Path.resolve(strict=False)
# resolve() 函数是 Python pathlib 模块中 Path 类的一个方法,用于将路径对象转换为绝对路径,并解析路径中的所有符号链接(symlinks),返回一个新的路径对象。
# 参数 :
# strict :布尔值,默认为 False 。如果设置为 True ,则在路径不存在时引发 FileNotFoundError 。如果为 False ,则尽可能解析路径并附加任何剩余部分,而不检查它是否存在。
# 返回值 :
# 返回一个新的 Path 对象,该对象表示原始路径的绝对路径,并且已经解析了所有符号链接。
# 功能 :
# resolve() 方法将路径对象转换为绝对路径,并解析路径中的所有符号链接。
# 它还会处理路径中的所有特殊符号,例如 . (当前目录)和 .. (上级目录)。
# 如果路径不存在且 strict 参数为 True ,则引发 FileNotFoundError 。
# 如果 strict 参数为 False ,则尽可能解析路径并附加任何剩余部分,而不检查它是否存在。
# 注意事项 :
# resolve() 方法总是返回一个绝对路径,无论传入的路径是相对路径还是绝对路径。
# 如果路径是相对路径,则会根据当前工作目录将其解析为绝对路径。
# resolve() 方法会对路径进行标准化处理,消除冗余的分隔符、处理上级目录符号(..)和当前目录符号(.),以保证返回的路径是规范化的。
# 这行代码使用了 Python 的 pathlib 模块来获取当前文件的路径。
# FILE : 这是一个变量,用于存储当前文件的路径。
# Path : 这是 pathlib 模块中的一个类,用于表示文件系统路径。
# __file__ : 这是一个特殊的 Python 变量,它包含了当前脚本的路径。如果脚本位于包中, __file__ 将是相对路径;如果脚本是直接运行的脚本, __file__ 将是绝对路径。
# resolve() : 这是 Path 对象的一个方法,用于将相对路径解析为绝对路径,并解决其中的符号链接。
# 获取当前 Python 脚本的绝对路径,并将其存储在变量 FILE 中。
FILE = Path(__file__).resolve()
# 这行代码是在 FILE 变量的基础上进一步获取当前文件的父目录的父目录,并将其赋值给 ROOT 变量。
# ROOT : 这是一个变量,用于存储当前文件的父目录的父目录(即上两级的目录)。
# FILE.parents : 这是 pathlib.Path 对象的一个属性,它返回一个包含当前路径所有父目录的列表。
# FILE.parents[1] : 这是访问 parents 列表中第二个元素(索引为 1,因为索引从 0 开始)的操作,即当前文件的父目录的父目录。
# 这行代码的作用是获取当前文件的上两级目录,并将其存储在变量 ROOT 中。
ROOT = FILE.parents[1] # YOLO root directory
# 这行代码使用了 Python 的 os 模块来获取环境变量 RANK 的值,并将其转换为整数类型。
# os.getenv() 是 os 模块中的一个函数,用于获取指定的环境变量的值。它接受两个参数 :第一个参数是环境变量的名称(在这个例子中是 'RANK' ),第二个参数是可选的,默认值(在这个例子中是 -1 )。
# 如果环境变量 RANK 存在, os.getenv('RANK', -1) 将返回该环境变量的值。如果环境变量 RANK 不存在,函数将返回 -1 作为默认值。
# int() 函数将 os.getenv() 返回的值转换为整数类型。如果返回的值不能被转换为整数(例如,如果环境变量的值不是数字), int() 函数将抛出 ValueError 异常。
# 这行代码通常用于分布式计算或并行处理环境中,其中 RANK 环境变量用于标识进程或节点的序号。例如,在 MPI(消息传递接口)程序中,每个进程通常会设置一个 RANK 环境变量,以区分不同的进程。
RANK = int(os.getenv('RANK', -1))
# 这段代码是一系列设置,用于配置一个名为 YOLOv5 的对象检测模型的环境参数。
# Settings
# 设置 NUM_THREADS 变量,它代表 YOLOv5 多进程处理时使用的线程数。 os.cpu_count() 返回当前机器的 CPU 核心数。 max(1, os.cpu_count() - 1) 确保至少有一个线程,并且不超过 CPU 核心数减一。 min(8, ...) 确保线程数不超过 8,即使 CPU 核心数减一大于 8。这样,代码可以自动适应不同机器的 CPU 核心数,同时限制最大线程数以避免过度线程竞争。
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
# os.getenv(key, default=None)
# os.getenv() 是 Python 标准库 os 模块中的一个函数,用于获取环境变量的值。环境变量是操作系统级别的变量,通常用于配置程序的运行环境。
# 参数 :
# key : 要获取的环境变量的名称(字符串)。
# default : 可选参数,如果指定的环境变量不存在,则返回该默认值。默认为 None 。
# 返回值 :
# 返回指定环境变量的值(字符串),如果环境变量不存在且未提供默认值,则返回 None ;如果提供了默认值,则返回该默认值。
# 使用场景 :
# os.getenv() 常用于读取配置参数,例如数据库连接字符串、API 密钥等,这些参数可以在操作系统环境中设置,以便在代码中使用。
# 它可以用于根据环境变量的值来调整程序的行为,例如在开发和生产环境中使用不同的配置。
# 注意事项 :
# 使用 os.getenv() 读取环境变量时,返回值始终是字符串类型。如果环境变量的值是数字或其他类型,仍然会被转换为字符串。
# 如果需要获取所有环境变量,可以使用 os.environ ,它返回一个包含所有环境变量的字典。
# os.getenv() 是处理环境变量的便捷方法,常用于配置和管理程序的运行环境。
# 设置 DATASETS_DIR 变量,它代表全局数据集目录。 os.getenv('YOLOv5_DATASETS_DIR', ...) 尝试获取名为 YOLOv5_DATASETS_DIR 的环境变量,如果未设置,则使用 ROOT.parent / 'datasets' 作为默认路径,其中 ROOT 是一个之前定义的变量,通常指向项目的根目录。 Path 是 pathlib 模块中的一个类,用于表示文件系统路径。
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
# 设置 AUTOINSTALL 变量,它代表全局自动安装模式。 os.getenv('YOLOv5_AUTOINSTALL', True) 尝试获取名为 YOLOv5_AUTOINSTALL 的环境变量,如果未设置,则默认为 True 。 str(...).lower() == 'true' 将环境变量的值转换为小写字符串,并检查是否等于 'true' 。
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
# 设置 VERBOSE 变量,它代表全局详细模式。与 AUTOINSTALL 类似, os.getenv('YOLOv5_VERBOSE', True) 尝试获取名为 YOLOv5_VERBOSE 的环境变量,如果未设置,则默认为 True 。 str(...).lower() == 'true' 将环境变量的值转换为小写字符串,并检查是否等于 'true' 。
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
# 设置 TQDM_BAR_FORMAT 变量,它定义了 tqdm 进度条的格式。 tqdm 是一个快速、可扩展的 Python 进度条库,用于在长循环中添加进度条。这个格式字符串定义了进度条的外观,包括 左侧的标签 、 进度条本身 、 当前进度 和 总进度 、以及 已过时间 。
TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
# 设置 FONT 变量,指定了默认的字体文件名。这个字体文件可能用于绘制文本,例如在图像上绘制检测到的对象的标签。
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
# 这些设置为 YOLOv5 模型提供了灵活的配置选项,允许用户根据他们的系统环境和偏好来调整行为。通过环境变量,用户可以在不修改代码的情况下改变程序的行为,这在不同的部署环境中非常有用。
# 这段代码包含了一系列的设置,用于配置不同的Python库的输出选项和线程行为。
# torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)
# torch.set_printoptions() 是 PyTorch 中的一个函数,用于设置打印张量时的选项,比如限制打印的元素数量、设置精度等。这个函数在处理大型张量或者需要更精确控制输出格式时非常有用。
# 参数 :
# precision :浮点输出的精度位数(默认值 = 4)。
# threshold :触发汇总而不是完整 repr 的数组元素总数(默认值 = 1000)。当张量的元素总数超过这个阈值时,会以省略号的形式显示,而不是显示所有元素。
# edgeitems :每个维度开头和结尾处汇总的数组项数(默认值 = 3)。当张量以省略号形式显示时,这个参数控制显示每个维度开始和结束的元素数量。
# linewidth :用于插入换行符的每行字符数(默认值 = 80)。对于超过 threshold 而折叠的张量,这个参数会被忽略。
# profile :打印选项的预设配置。可以覆盖上述任何选项。有三个选项 : default 、 short 、 full 。
# default :默认的打印选项。
# short :更简洁的打印选项,减少显示的元素数量和精度。
# full :显示更多的元素和更高的精度。
# sci_mode :启用(True)或禁用(False)科学记数法。如果指定为 None(默认),则该值由 torch._tensor_str._Formatter 定义,由框架自动选择。
# 通过使用 torch.set_printoptions() ,你可以灵活地控制 PyTorch 张量的打印输出,使其更适合你的调试和展示需求。
# 设置了PyTorch张量打印选项。 linewidth=320 指定了输出行的最大宽度, precision=5 设置了浮点数的打印精度(小数点后5位), profile='long' 指定了打印格式,这通常意味着更详细的输出。
torch.set_printoptions(linewidth=320, precision=5, profile='long')
# numpy.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, suppress=None, nanstr=None, infstr=None, formatter=None, sign=None, floatmode=None)
# 参数 :
# precision :浮点数的打印精度,默认为8位小数。如果设置为 None ,则使用 NumPy 的默认精度。
# threshold :控制当数组元素总数超过这个值时,数组将以缩略形式显示。默认值为1000。
# edgeitems :在缩略显示时,每个维度的开头和结尾显示的元素数量。默认值为3。
# linewidth :输出行的最大宽度。默认值为75。
# suppress :如果为 True ,则在显示浮点数时不显示尾随的零。默认为 False 。
# nanstr :表示 NaN 值的字符串,默认为 "nan" 。
# infstr :表示 Inf 和 -Inf 值的字符串,默认分别为 "inf" 和 "-inf" 。
# formatter :一个字典,用于定义特定数据类型的格式化函数或字符串。例如, {'float_kind': '{:12.5g}'.format} 。
# sign :控制浮点数的符号显示方式。可以是 '+' 、 '-' 或 None 。默认为 None ,表示正数不显示符号。
# floatmode :控制浮点数的显示模式。可以是 'fixed' 、 'scientific' 、 'maxprec' 或 None 。默认为 None ,表示根据数值大小自动选择显示模式。
# 通过使用 np.set_printoptions() ,可以灵活地控制 NumPy 数组的打印输出,使其更适合调试和展示需求。
# 设置了NumPy数组的打印选项。 linewidth=320 同样设置了输出行的最大宽度。 formatter={'float_kind': '{:11.5g}'.format} 定义了浮点数的格式化方式, {:11.5g} 是一个格式化字符串,意味着以一般形式(g格式)打印浮点数,保留5位有效数字,并且总宽度为11个字符。
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
# 设置了Pandas库的显示选项。 max_columns=10 限制了在输出中显示的最大列数。如果DataFrame有超过10列,那么Pandas将不会显示所有列。
pd.options.display.max_columns = 10
# cv2.setNumThreads(numThreads)
# cv2.setNumThreads() 是 OpenCV 库中的一个函数,用于设置 OpenCV 并行处理时使用的线程数。
# 参数 :
# numThreads :要设置的线程数目。这个参数控制 OpenCV 在执行并行操作时使用的线程数量。
# 作用 :
# cv2.setNumThreads() 函数的作用是控制 OpenCV 在进行并行处理时使用的线程数。这可以影响程序的性能和效率,尤其是在进行大规模图像处理或视频分析任务时。通过调整线程数,可以优化程序的并行处理能力,提高计算速度和实时性能。
# 注意事项 :
# cv2.setNumThreads() 函数是一个全局设置,意味着它将影响到之后 OpenCV 库中的所有并行处理操作。因此,在设置线程数目之后,需要谨慎选择合适的值,避免对其他部分的程序性能造成负面影响。
# 在某些情况下,比如与 PyTorch 的 DataLoader 一起使用时,为了避免潜在的死锁问题,可能需要将线程数设置为 0,禁用 OpenCV 的多线程功能。
# 设置线程数为 0 可以防止 OpenCV 使用多线程,这在某些特定的应用场景中可能是必要的。
# 通过合理配置 cv2.setNumThreads() ,可以有效地控制 OpenCV 的并行处理行为,从而优化程序的整体性能。
# 设置了OpenCV的线程数。 cv2.setNumThreads(0) 禁用了OpenCV的多线程功能。这是因为OpenCV的多线程可能与PyTorch的DataLoader不兼容,可能会导致问题。
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
# 设置了环境变量 NUMEXPR_MAX_THREADS ,用于控制NumExpr库使用的线程数。 NUM_THREADS 是之前定义的变量,表示YOLOv5多进程处理时使用的线程数。
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
# platform.system()
# platform.system() 是 Python 标准库 platform 模块中的一个函数,用于获取当前操作系统的名称。
# 返回值 :
# 返回一个字符串,表示当前操作系统的名称。 在 Windows 上,返回 'Windows' 。 在 macOS 上,返回 'Darwin' 。 在 Linux 上,返回 'Linux' 。
# 注意事项 :
# platform.system() 只返回操作系统的简单名称,不提供版本信息。
# 如果需要更详细的操作系统信息,可以使用 platform.release() , platform.version() , platform.machine() 等函数。
# platform 模块还提供了其他函数来获取系统架构、节点名称、处理器等信息。
# 其他相关函数 :
# platform.release() :返回操作系统的发布版本号。
# platform.version() :返回操作系统的版本字符串。
# platform.machine() :返回计算机的硬件类型。
# platform.node() :返回网络名称(计算机名)。
# platform.architecture() :返回机器的硬件架构。
# 使用 platform 模块可以帮助你的 Python 程序更灵活地适应不同的操作系统环境。
# 设置了环境变量 OMP_NUM_THREADS ,用于控制OpenMP并行区域的线程数。 platform.system() == 'darwin' 检查操作系统是否为macOS(darwin),如果是,则设置 OMP_NUM_THREADS 为 '1' ,因为在macOS上,多线程可能会降低性能。否则,设置为 NUM_THREADS 的值。
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
# 这些设置优化了不同库的输出显示和线程行为,以确保在多线程环境中的兼容性和性能。通过调整这些选项,可以提高代码的可读性和执行效率,特别是在处理大型数据集和进行大规模并行计算时。
2.def is_ascii(s=''):
# 这段代码定义了一个名为 is_ascii 的函数,它用于检查一个字符串是否完全由 ASCII 字符组成,即不包含任何非ASCII(如 Unicode 或 UTF-8)字符。
# 定义 is_ascii 函数,它接受一个参数。
# 1.s :这是一个默认为空字符串的参数。这个参数可以是任何类型,因为函数内部会将其转换为字符串。
def is_ascii(s=''):
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) 字符串是否全部由 ASCII(无 UTF)字符组成?(请注意,str().isascii() 在 python 3.7 中引入)。
# 将输入参数 s 转换为字符串。如果 s 是列表、元组、 None 或其他非字符串类型,这行代码确保它们被转换为字符串类型,以便进行后续的 ASCII 检查。
s = str(s) # convert list, tuple, None, etc. to str
# 这行代码执行以下操作 :
# s.encode() :将字符串 s 编码为字节串。如果 s 包含非ASCII字符,这一步可能会失败或产生错误。
# .decode('ascii', 'ignore') :将字节串解码回字符串,但只限于 ASCII 字符。非ASCII字符将被忽略(即删除)。
# len(s.encode().decode('ascii', 'ignore')) == len(s) :比较解码后的字符串长度与原始字符串长度。如果相等,说明原始字符串中不包含任何非ASCII字符,因此它完全由 ASCII 字符组成。
return len(s.encode().decode('ascii', 'ignore')) == len(s)
# 这个 is_ascii 函数提供了一个简单的方式来检查字符串是否为纯 ASCII,这在处理文本数据时非常有用,尤其是在需要确保文本兼容性或进行特定于 ASCII 的文本处理时。
3.def is_chinese(s='人工智能'):
# 这段代码定义了一个名为 is_chinese 的函数,它用于检查传入的字符串 s 是否包含至少一个中文字符。
# 定义了一个名为 is_chinese 的函数,它接受一个参数。
# 1.s :默认值为 '人工智能' 。这个参数是一个字符串,你想要检查它是否包含中文字符。
def is_chinese(s='人工智能'):
# Is string composed of any Chinese characters? 字符串是否由任意汉字组成?
# 这行代码执行了实际的检查操作。
# re.search() 是 Python 标准库 re (正则表达式)模块中的一个函数,用于在字符串中搜索匹配正则表达式的模式。
# [\u4e00-\u9fff] :这是一个 Unicode 范围,表示所有的中文字符。 \u4e00 是汉字“一”的 Unicode 编码, \u9fff 是汉字“龥”的 Unicode 编码,这个范围覆盖了绝大部分常用汉字。
# str(s) :确保输入的 s 被转换成字符串,以防传入的不是字符串类型。
# bool(...) : re.search() 函数返回一个匹配对象,如果找到了匹配项,则 bool 函数将其转换为 True ;如果没有找到匹配项,则转换为 False 。
return bool(re.search('[\u4e00-\u9fff]', str(s)))
# 这个函数只能检查字符串是否包含中文字符,但不会检查字符串是否完全由中文字符组成。
# 如果需要检查字符串是否完全由中文字符组成,可以将 bool(re.search()) 改为 bool(re.match()) ,并移除默认参数,因为 re.match() 要求匹配必须从字符串的开始位置起。
# 这个函数默认参数包含中文字符,因此在没有传入参数时,直接调用 is_chinese() 也会返回 True 。
4.def is_colab():
# 这段代码定义了一个名为 is_colab 的函数,用于检查当前的运行环境是否是 Google Colab 实例。
# 定义了一个名为 is_colab 的函数,它不接受任何参数。
def is_colab():
# Is environment a Google Colab instance? 环境是 Google Colab 实例吗?
# 这行代码执行了实际的检查操作。 sys.modules 是一个字典,包含了所有已经加载的 Python 模块。如果当前环境是 Google Colab,那么 google.colab 模块会被加载,因此 sys.modules 字典中会包含 'google.colab' 这个键。
# in 关键字用于检查 'google.colab' 是否是 sys.modules 字典的键。如果是,返回 True ;否则返回 False 。
return 'google.colab' in sys.modules
# Google Colab 环境特有的 google.colab 模块提供了一些额外的功能,比如与 Google Drive 的集成、免费的 GPU 和 TPU 资源等。通过检查这个模块是否存在,可以确定代码是否运行在 Colab 环境中,并据此调整代码的行为。
5.def is_notebook():
# 这段代码定义了一个名为 is_notebook 的函数,用于判断当前的运行环境是否是一个 Jupyter 笔记本环境。这个函数在不同的 Jupyter 笔记本环境中进行了验证,包括 Google Colab、JupyterLab、Kaggle 和 Paperspace。
# 定义了一个没有参数的函数 is_notebook 。
def is_notebook():
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
# 获取当前的 IPython 实例,并将其类型转换为字符串。 IPython.get_ipython() 是 IPython 库中的一个函数,用于获取当前的 IPython 实例。 type() 函数用于获取这个实例的类型,然后将其转换为字符串。
ipython_type = str(type(IPython.get_ipython()))
# 码检查字符串 ipython_type 中是否包含 'colab' 或 'zmqshell' 。这两个字符串是 Jupyter 笔记本环境中常见的类型标识符 :
# 'colab' :Google Colab 环境。
# 'zmqshell' :Jupyter 的 ZeroMQ 内核,用于 Jupyter 笔记本和其他类似环境中的通信。
# 如果 ipython_type 包含这两个字符串中的任何一个,函数就返回 True ,表示当前环境是一个 Jupyter 笔记本环境;否则返回 False 。
return 'colab' in ipython_type or 'zmqshell' in ipython_type
6.def is_kaggle():
# 这段代码定义了一个名为 is_kaggle 的函数,用于检查当前的运行环境是否是 Kaggle Notebook。
# 定义了一个名为 is_kaggle 的函数,它不接受任何参数。
def is_kaggle():
# Is environment a Kaggle Notebook? 环境是 Kaggle Notebook 吗?
# 这行代码执行了实际的检查操作。它使用了 os.environ.get 方法来获取环境变量的值,并检查这些值是否符合 Kaggle Notebook 的特征。
# os.environ.get('PWD') :尝试获取当前工作目录的环境变量 PWD 。在 Kaggle Notebook 中,这个值通常是 /kaggle/working 。
# os.environ.get('KAGGLE_URL_BASE') :尝试获取 Kaggle 的基础 URL 环境变量 KAGGLE_URL_BASE 。在 Kaggle Notebook 中,这个值通常是 'https://www.kaggle.com' 。
# and :逻辑与操作符,确保两个条件同时满足时才返回 True 。
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
# 函数 is_kaggle 通过检查两个环境变量来判断当前环境是否为 Kaggle Notebook。如果当前工作目录是 /kaggle/working 且 KAGGLE_URL_BASE 环境变量的值为 'https://www.kaggle.com' ,则函数返回 True ,表示当前环境是 Kaggle Notebook;否则返回 False 。
# 这个函数可以帮助开发者识别代码是否在特定的 Kaggle 环境中运行,从而进行相应的环境适配或行为调整。例如,开发者可能会根据环境不同加载不同的数据集或使用不同的资源。
7.def is_docker() -> bool:
# 这段代码定义了一个名为 is_docker 的函数,其目的是检查当前进程是否运行在 Docker 容器内。
# 定义了一个名为 is_docker 的函数,它不接受任何参数,并明确返回类型为 bool 。
def is_docker() -> bool:
# 检查该进程是否在 Docker 容器内运行。
"""Check if the process runs inside a docker container."""
# 检查 /.dockerenv 文件是否存在。Docker 容器中通常会包含这个文件,所以如果文件存在,函数直接返回 True 。
if Path("/.dockerenv").exists():
return True
# 尝试打开 /proc/self/cgroup 文件,这个文件包含了当前进程的控制组(cgroup)信息。如果任何一行中包含字符串 "docker" ,则表示进程在 Docker 的控制组下运行,函数返回 True 。
try: # check if docker is in control groups
with open("/proc/self/cgroup") as file:
return any("docker" in line for line in file)
# 如果尝试打开 /proc/self/cgroup 文件时发生 OSError 异常(例如,文件不存在或没有权限读取),函数捕获异常并返回 False 。
except OSError:
return False
# 函数 is_docker 通过两种方式检查当前进程是否运行在 Docker 容器内。检查 /.dockerenv 文件是否存在。检查 /proc/self/cgroup 文件中是否包含 "docker" 字符串。如果任一检查结果为真,则函数返回 True ,表示进程运行在 Docker 容器内;否则返回 False 。这种方法可以有效地帮助开发者识别代码是否在 Docker 环境中执行,从而进行相应的环境适配或行为调整。
8.def is_writeable(dir, test=False):
# 这段代码定义了一个名为 is_writeable 的函数,用于检查指定目录是否具有写入权限。如果需要,它还可以通过尝试创建一个测试文件来验证写入权限。
# 定义 is_writeable 函数,它接受两个参数。
# 1.dir :要检查的目录路径。
# 2.test :一个布尔值,指示是否通过创建一个测试文件来验证写入权限,默认为 False 。
def is_writeable(dir, test=False):
# Return True if directory has write permissions, test opening a file with write permissions if test=True 如果目录具有写权限,则返回 True,如果 test=True,则测试打开具有写权限的文件。
# 如果 test 参数为 False ,则不进行测试文件操作。
if not test:
# 使用 os.access() 函数检查目录是否有写入权限。 os.W_OK 是一个常量,表示检查写入权限。如果目录具有写入权限,则返回 True ;否则,返回 False 。
return os.access(dir, os.W_OK) # possible issues on Windows
# 如果 test 为 True ,则创建一个路径对象,表示在指定目录下的一个测试文件 tmp.txt 。
file = Path(dir) / 'tmp.txt'
# 开始一个 try 块,用于捕获在尝试写入文件时可能发生的任何 OSError 。
try:
# 尝试以写入模式打开测试文件。如果文件成功打开,说明目录具有写入权限。
with open(file, 'w'): # open file with write permissions
# 在 with 块中,不需要执行任何操作,因为目的只是检查文件是否可以被打开。
pass
# 在测试完成后,删除测试文件。
file.unlink() # remove file
# 如果测试文件成功打开并关闭,返回 True ,表示目录具有写入权限。
return True
# 如果发生 OSError 异常,表示目录没有写入权限或无法创建测试文件。
except OSError:
# 如果发生异常,返回 False ,表示目录没有写入权限。
return False
# 这个 is_writeable 函数提供了一种灵活的方式来检查目录的写入权限,可以选择是否进行实际的文件写入测试。这对于确保程序在尝试写入文件之前具有必要的权限非常有用。
LOGGING_NAME = "yolov5"
9.def set_logging(name=LOGGING_NAME, verbose=True):
# 这段代码定义了一个名为 set_logging 的函数,它用于配置 Python 的 logging 模块,以便为特定的名称设置日志记录。
# 函数定义。
# 1.name : 日志记录器的名称,默认为 LOGGING_NAME (LOGGING_NAME = "yolov5")。
# 2.verbose : 是否为详细模式,即是否记录详细信息,默认为 True 。
def set_logging(name=LOGGING_NAME, verbose=True):
# sets up logging for the given name 设置给定名称的日志记录。
# os.getenv(key, default=None)
# os.getenv() 是 Python 标准库 os 模块中的一个函数,用于获取环境变量的值。环境变量是操作系统级别的变量,通常用于配置程序的运行环境。
# 参数 :
# key : 要获取的环境变量的名称(字符串)。
# default : 可选参数,如果指定的环境变量不存在,则返回该默认值。默认为 None 。
# 返回值 :
# 返回指定环境变量的值(字符串),如果环境变量不存在且未提供默认值,则返回 None ;如果提供了默认值,则返回该默认值。
# 使用场景 :
# os.getenv() 常用于读取配置参数,例如数据库连接字符串、API 密钥等,这些参数可以在操作系统环境中设置,以便在代码中使用。
# 它可以用于根据环境变量的值来调整程序的行为,例如在开发和生产环境中使用不同的配置。
# 注意事项 :
# 使用 os.getenv() 读取环境变量时,返回值始终是字符串类型。如果环境变量的值是数字或其他类型,仍然会被转换为字符串。
# 如果需要获取所有环境变量,可以使用 os.environ ,它返回一个包含所有环境变量的字典。
# os.getenv() 是处理环境变量的便捷方法,常用于配置和管理程序的运行环境。
# 获取环境变量 RANK 的值,并将其转换为整数。 RANK 通常用于多 GPU 训练中标识当前进程的排名。如果环境变量不存在,则默认为 -1 。
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings 多 GPU 训练的world排名。
# logging.INFO
# logging.INFO 是一个用于指定日志级别常量的预定义属性。 logging 模块用于跟踪程序的运行状态,提供信息输出等功能,而 logging.INFO 表示信息级别的日志,通常用于输出程序的正常运行信息。
# 属性定义 :
# logging.INFO 是一个整数,其值为 20。它被用来指示日志消息的重要性级别。在 logging 模块中,日志级别从低到高依次为 DEBUG (10)、 INFO (20)、 WARNING (30)、 ERROR (40)和 CRITICAL (50)。
# 使用场景 :
# logging.INFO 通常用于输出程序运行过程中的一般信息,这些信息对于了解程序的运行状态和调试程序是有帮助的,但通常不包含敏感数据或错误信息。
# 注意事项 :
# 日志级别用于控制日志消息的输出,只有那些级别等于或高于当前设置级别的日志消息才会被输出。
# 在实际开发中,可以根据需要调整日志级别,以便更好地控制日志输出的详细程度。例如,在开发和测试阶段可能希望看到更多的信息,而在生产环境中可能只需要警告和错误信息。
# 根据 verbose 参数和 rank 值确定日志级别。如果 verbose 为 True 且 rank 为 -1 或 0 (即主进程),则设置日志级别为 INFO ,否则设置为 ERROR 。
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
# logging.config.dictConfig(configDict)
# logging.config.dictConfig() 是 Python logging 模块中的一个函数,它根据提供的字典配置来配置日志系统。这个函数允许你以结构化的字典格式定义日志记录器(loggers)、处理器(handlers)、格式化器(formatters)等组件的配置,提供了一种灵活的方式来设置日志系统。
# 参数 :
# configDict : 一个字典,包含了日志系统的配置信息。
# 配置字典的结构 :
# 配置字典 configDict 通常包含以下几个主要部分 :
# version : 配置字典的版本号,目前通常设置为 1。
# formatters : 定义 格式化器 的配置,其中的键是格式化器的名称,值是格式化器的配置。
# handlers : 定义 处理器 的配置,其中的键是处理器的名称,值是处理器的配置。
# loggers : 定义日志 记录器 的配置,其中的键是日志记录器的名称,值是日志记录器的配置。
# 函数的作用 :
# logging.config.dictConfig() 函数使用提供的配置字典来配置日志系统,这使得日志系统的配置可以与代码分离,便于管理和修改。这种方式特别适合于复杂应用程序,其中日志配置可能需要根据不同的环境或条件进行调整。
# 注意事项 :
# 配置字典中的 class 属性需要使用字符串来指定类的全路径,例如 'class': 'logging.StreamHandler' 。
# 配置字典中的 stream 属性可以指定日志输出的目标流,例如 'stream': 'ext://sys.stdout' 表示标准输出流。
# 使用 dictConfig() 配置日志系统后,可以通过 logging.getLogger() 获取配置好的日志记录器实例,并使用它来输出日志消息。
# 使用 dictConfig 方法配置日志记录器。这个方法允许使用字典格式配置日志系统。
logging.config.dictConfig({
# 配置字典的版本号。
"version": 1,
# 不禁用已经存在的日志记录器。
"disable_existing_loggers": False,
# 定义一个格式化器,用于指定日志消息的格式。
"formatters": {
name: {
"format": "%(message)s"}},
# 定义一个处理器,使用流处理器 StreamHandler ,指定格式化器和日志级别。
"handlers": {
name: {
"class": "logging.StreamHandler",
"formatter": name,
"level": level,}},
# 定义一个日志记录器,设置日志级别、处理器和禁止消息传播。
"loggers": {
name: {
"level": level,
"handlers": [name],
"propagate": False,}}})
# set_logging 函数提供了一个灵活的方式来配置日志记录,允许用户根据是否为详细模式和进程排名来设置日志级别。这个函数特别适用于需要根据不同环境(如单 GPU 或多 GPU 训练)调整日志详细程度的场景。通过这种方式,用户可以更好地控制日志输出,以便调试和监控程序的运行情况。
# def set_logging(name=LOGGING_NAME, verbose=True): -> 用于配置 Python 的 logging 模块,以便为特定的名称设置日志记录。
set_logging(LOGGING_NAME) # run before defining LOGGER
# logging.getLogger(name)
# logging.getLogger 是 Python 标准库 logging 模块中的一个函数,用于返回一个日志记录器(logger)对象。每个日志记录器都有一个名字,通常使用 __name__ 变量来指定当前模块的名称作为日志记录器的名字。
# 这个对象是用于处理和跟踪程序中事件(例如错误、警告、信息等)的工具。
# name :日志记录器的名称,默认为 __name__ ,即当前模块的名称。
# 功能 :
# getLogger 函数的主要功能是获取一个指定名称的日志记录器对象。如果该名称的日志记录器已经存在,则返回该日志记录器;如果不存在,则创建一个新的日志记录器,并将其添加到 logging 系统的日志记录器树中。
# 这行代码是 Python 中使用 logging 模块创建或获取一个日志记录器(logger)的常见方式。 LOGGER 是一个日志记录器对象,用于记录日志信息。
# logging.getLogger(name) 函数接受一个 name 参数,返回一个对应名称的日志记录器对象。如果同名的日志记录器已经存在,则返回现有的日志记录器;如果不存在,则创建一个新的日志记录器。
# LOGGER -> LOGGER 是一个日志记录器对象,用于记录日志信息。
LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == 'Windows':
for fn in LOGGER.info, LOGGER.warning:
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
10.def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
# 这段代码定义了一个名为 user_config_dir 的函数,它用于获取用户的配置目录路径。如果环境变量中指定了路径,则使用该路径;否则,它会根据不同的操作系统返回一个默认的配置目录路径,并确保该目录存在。
# dir :这个参数指定了配置目录的名称。函数会在这个目录下创建或查找配置文件。如果环境变量没有指定配置目录,这个参数的值将被用来在用户的主配置目录下创建或查找一个名为 'Ultralytics' 的目录。
# env_var :这个参数指定了一个环境变量的名称。函数会检查这个环境变量是否存在,如果存在,函数会使用这个环境变量的值作为配置目录的路径。这允许用户在系统环境变量中设置配置目录的位置,从而覆盖默认的行为。
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
# Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
# 尝试从环境变量中获取 env_var 指定的值。 os.getenv 函数会返回与环境变量名对应的值,如果该环境变量不存在,则返回 None 。
env = os.getenv(env_var)
# 如果环境变量存在(即 env 不为 None ),则使用该环境变量的值作为路径,并将其转换为 Path 对象。
if env:
path = Path(env) # use environment variable
# 如果环境变量不存在,这段代码定义了一个字典 cfg ,它包含了不同操作系统的配置目录路径。
else:
cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
# 使用 Path.home() 获取用户的主目录路径,然后根据当前操作系统(使用 platform.system() 获取)从 cfg 字典中获取相应的配置目录路径,并将其与用户的主目录路径组合。
path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
# 检查上一步得到的路径是否可写。如果不是可写的,则使用 /tmp 目录作为备选路径。这是因为在某些环境(如 Google Cloud Platform (GCP) 和 Amazon Web Services (AWS) Lambda)中,只有 /tmp 目录是可写的。然后,将 dir 参数的值追加到路径中。
# def is_writeable(dir, test=False): -> 用于检查指定目录是否具有写入权限。如果需要,它还可以通过尝试创建一个测试文件来验证写入权限。 -> return os.access(dir, os.W_OK) / return True / return False
path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
# 使用 mkdir 方法创建路径所指向的目录, exist_ok=True 参数表示如果目录已经存在,则不会抛出异常。
path.mkdir(exist_ok=True) # make if required
# 函数返回最终确定的路径对象。
return path
# 这个函数的目的是提供一个用户的配置目录路径,它会优先考虑环境变量中的设置,如果环境变量未设置,则根据不同的操作系统返回一个默认的配置目录路径,并确保该目录存在。
CONFIG_DIR = user_config_dir() # Ultralytics settings dir Ultralytics 设置目录。
11.class Profile(contextlib.ContextDecorator):
# 这段代码定义了一个名为 Profile 的类,它继承自 contextlib.ContextDecorator 。这个类可以用来测量代码块的执行时间,既可以作为一个装饰器使用,也可以作为一个上下文管理器使用。
# 定义了一个名为 Profile 的类,它继承自 contextlib.ContextDecorator 。这意味着 Profile 类可以用来作为一个装饰器或者上下文管理器。
class Profile(contextlib.ContextDecorator):
# YOLO Profile class. Usage: @Profile() decorator or 'with Profile():' context manager YOLO Profile 类。用法:@Profile() 装饰器或 'with Profile():' 上下文管理器。
# 这是 Profile 类的构造函数,它接受一个参数。
# 1.t :默认值为 0.0 。这个参数用来累积测量的时间。
def __init__(self, t=0.0):
# 将传入的参数 t 赋值给实例变量 self.t 。
self.t = t
# 检查 CUDA 是否可用(即是否有 NVIDIA GPU 并且 PyTorch 可以使用它),并将结果赋值给实例变量 self.cuda 。
self.cuda = torch.cuda.is_available()
# 这是一个特殊方法,当使用 with 语句进入上下文时被调用。
def __enter__(self):
# 调用 self.time() 方法来获取当前时间,并将其赋值给 self.start ,这表示代码块开始执行的时间。
self.start = self.time()
# 返回 self 对象,允许在 with 语句中使用 Profile 实例。
return self
# 这是一个特殊方法,当使用 with 语句退出上下文时被调用。它接受三个参数。
# 1.type 、 2.value 和 3.traceback ,这些参数提供了异常的信息(如果有的话)。
def __exit__(self, type, value, traceback):
# 再次调用 self.time() 方法来获取当前时间,并计算与 self.start 的差值,这个差值表示代码块执行的时间,赋值给 self.dt 。
self.dt = self.time() - self.start # delta-time
# 将 self.dt 的值加到 self.t 上,这样可以累积测量的时间。
self.t += self.dt # accumulate dt
# 定义了一个名为 time 的方法,用于获取当前时间。
def time(self):
# 检查 self.cuda 是否为 True 。
if self.cuda:
# torch.cuda.synchronize(device=None)
# torch.cuda.synchronize() 是 PyTorch 库中的一个函数,用于同步 CPU 和 GPU 之间的计算。
# 参数 :
# device :可选参数,指定要同步的设备。可以是 torch.device 对象、整数(表示 GPU 编号)或字符串(如 'cuda:0' )。如果省略此参数或设置为 None ,则默认同步当前设备。
# 功能 :
# torch.cuda.synchronize() 函数用于确保当前设备上所有先前提交的 CUDA 核心执行完成。这个函数会阻塞调用它的 CPU 线程,直到所有 CUDA 核心完成执行为止。
# 作用 :
# 由于 GPU 计算是异步的,提交到 GPU 的计算任务会被放入队列中,程序不会等待 GPU 完成计算就继续执行后续代码。这可能会导致一个问题 :在 GPU 计算尚未完成时,CPU 就开始访问 GPU 的计算结果,此时可能会得到错误的结果。
# 为了避免这种情况,可以使用 torch.cuda.synchronize() 函数来同步 CPU 和 GPU 之间的计算,确保在继续执行后续代码之前 GPU 的计算已经完成。
# 典型使用场景 :
# 性能测试 :在测量 GPU 上运行的操作(如模型推理)所需时间时,需要在操作前后分别调用 torch.cuda.synchronize() 来确保测量时间包括了 GPU 上所有操作的执行时间。
# 确保数据一致性 :在需要确保 GPU 计算结果已经准备好被 CPU 使用的场景下,比如在训练深度神经网络时,在每个 epoch 结束后计算验证集的误差,此时需要使用 torch.cuda.synchronize() 来同步 CPU 和 GPU 之间的计算,以确保得到正确的结果。
# torch.cuda.synchronize() 是一个重要的函数,用于确保在继续执行后续代码之前 GPU 的计算已经完成,从而避免因异步计算导致的数据不一致问题。
# 如果 self.cuda 为 True ,则调用 torch.cuda.synchronize() 来确保所有 CUDA 核心的命令都完成。这是必要的,因为 GPU 操作可能是异步的,我们需要确保时间测量的准确性。
torch.cuda.synchronize()
# 返回当前的系统时间,这是通过 Python 标准库中的 time 模块的 time 函数获取的。
return time.time()
# Profile 类提供了一个简单的方式来测量代码块的执行时间,同时考虑到了 CUDA 的异步行为,以确保时间测量的准确性。
12.class Timeout(contextlib.ContextDecorator):
# 这段代码定义了一个名为 Timeout 的类,它是一个上下文装饰器,用于设置一个操作的超时时间。如果操作超过了指定的时间限制,它将引发一个 TimeoutError 。
# 定义了一个名为 Timeout 的类,它继承自 contextlib.ContextDecorator 。这意味着 Timeout 可以被用作装饰器和上下文管理器。
class Timeout(contextlib.ContextDecorator):
# YOLO Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager YOLO 超时类。
# 初始化方法。
# 这是 Timeout 类的初始化方法,它接受以下参数 :
# 1.seconds :超时时间,以秒为单位。
# 2.timeout_msg :超时时的错误消息,默认为空字符串。
# 3.suppress_timeout_errors :一个布尔值,指示是否抑制超时错误,默认为 True 。
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
# seconds 被转换为整数, timeout_message 和 suppress 被设置为相应的值。
self.seconds = int(seconds)
self.timeout_message = timeout_msg
self.suppress = bool(suppress_timeout_errors)
# 超时处理方法。
# 这是一个私有方法,用作信号处理函数。当超时发生时(即 SIGALRM 信号被触发时),这个方法会被调用,并引发一个 TimeoutError 。
def _timeout_handler(self, signum, frame):
raise TimeoutError(self.timeout_message)
# 这段代码是 Timeout 类的 __enter__ 方法的实现,它是上下文管理器协议的一部分。当使用 with 语句和 Timeout 类时, __enter__ 方法会在代码块执行前被调用。这个方法的作用是设置超时处理机制。
# 定义了 __enter__ 方法,它是上下文管理器的一部分,用于初始化资源或执行一些设置。
def __enter__(self):
# 检查当前操作系统是否不是 Windows。这是因为 Windows 系统不支持 UNIX 信号(如 SIGALRM ),所以这部分代码在 Windows 上不会执行。
if platform.system() != 'Windows': # not supported on Windows
# signal.signal(signalnum, handler)
# signal.signal() 是 Python 标准库 signal 模块中的一个函数,用于设置信号处理器。信号是操作系统用来与进程通信的一种机制,它们可以响应各种事件,如终端输入、程序错误、外部中断等。
# 参数 :
# signalnum :信号的编号,比如 signal.SIGINT 、 signal.SIGALRM 等。
# handler :当信号被接收时将被调用的处理器。这个参数可以是以下几种 :一个函数,它将被用作信号的处理器。 signal.SIG_IGN ,表示忽略该信号。 signal.SIG_DFL ,表示使用默认的处理方式(通常是终止进程)。
# 返回值 :
# 返回先前的信号处理器,这样你可以在之后恢复它。
# 注意事项 :
# signal.signal() 只在支持信号的系统中有效,比如 Unix 和 Linux 系统。在 Windows 上,支持的信号有限。
# 信号处理器应该尽可能简单,因为它们运行在不同的上下文中,复杂的操作可能会导致不可预测的行为。
# 信号处理器不能被传递给 signal.signal() 的任何参数。
# 信号处理器不应该调用任何可能引起信号自身的函数,这可能会导致递归调用和程序崩溃。
# signal.SIGALRM = 14
# signal.SIGALRM 是 Python signal 模块中的一个常量,它代表一个特定的信号,用于在 Unix 和类 Unix 系统中处理定时器超时。当一个定时器设置通过 signal.alarm() 函数设置并且到期时, SIGALRM 信号会被发送给进程。
# 这个值可能在不同的系统上有所不同,但在大多数 Unix 系统中, SIGALRM 的值是 14。
# 用途 :
# SIGALRM 信号用于实现闹钟定时器功能。可以使用 signal.alarm() 函数设置一个定时器,当定时器到期时, SIGALRM 信号会被触发。可以为这个信号设置一个信号处理器,以便在信号被触发时执行特定的操作。
# 注意事项 :
# SIGALRM 信号只在支持信号的系统中有效,比如 Unix 和 Linux 系统。在 Windows 上,这个信号不可用。
# 信号处理器应该尽可能简单,因为它们运行在不同的上下文中,复杂的操作可能会导致不可预测的行为。
# 使用 signal.alarm() 设置的定时器是一次性的,一旦触发 SIGALRM 信号后,定时器就会停止,除非你再次调用 signal.alarm() 设置新的定时器。
# 使用 signal.signal 函数设置 SIGALRM 信号的处理函数。 signal.SIGALRM 是一个信号,用于通知进程一个定时器已经到期。 self._timeout_handler 是 Timeout 类中定义的一个方法,它将在 SIGALRM 信号被触发时执行。
signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
# signal.alarm(seconds)
# signal.alarm() 是 Python 标准库 signal 模块中的一个函数,它用于设置一个定时器,当指定的时间过去后, SIGALRM 信号会被发送给进程。这个函数主要用于 Unix 和类 Unix 系统,因为 Windows 不支持 SIGALRM 信号。
# 参数 :
# seconds :一个整数,表示定时器的秒数。如果参数为 0,则取消任何已经设置的定时器。
# 返回值 :
# signal.alarm() 在 Python 3 中不返回任何值(在 Python 2 中返回先前的定时器值)。
# 注意事项 :
# signal.alarm() 设置的定时器是一次性的,一旦触发 SIGALRM 信号后,定时器就会停止,除非你再次调用 signal.alarm() 设置新的定时器。
# 如果程序在执行一个阻塞操作时 SIGALRM 信号被触发,信号处理器会在阻塞操作完成后立即被调用。
# 在多线程程序中, signal.alarm() 只会影响主线程,因为 SIGALRM 信号不能被线程捕获。
# 在某些系统中, signal.alarm() 设置的定时器精度可能受到系统定时器分辨率的限制,因此实际的定时精度可能不如期望的精确。
# 使用 signal.alarm 函数设置一个定时器,当指定的秒数( self.seconds )过去后, SIGALRM 信号将被发送给进程。这个函数启动了一个倒计时,当倒计时结束时,如果代码块还没有完成执行, SIGALRM 信号将被触发,从而调用 self._timeout_handler 方法。
signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
# __enter__ 方法在 Timeout 上下文管理器被进入时设置了一个定时器,用于在指定的超时时间后触发 SIGALRM 信号。如果操作系统不是 Windows,它还会设置一个信号处理函数来处理这个信号。
# 当 SIGALRM 信号被触发时, _timeout_handler 方法将被调用,从而引发一个 TimeoutError ,除非 suppress_timeout_errors 被设置为 True ,这种情况下异常将被抑制。
# 这个方法是实现超时功能的关键部分,它确保了即使被装饰的函数或 with 代码块执行时间过长,程序也能以受控的方式响应。
# 这段代码是 Timeout 类的 __exit__ 方法的实现,它是上下文管理器协议的一部分。当使用 with 语句和 Timeout 类时, __exit__ 方法会在代码块执行后被调用,无论是正常结束还是由于异常。这个方法的作用是清理设置的超时处理机制。
# 定义了 __exit__ 方法,它是上下文管理器的一部分,用于清理资源或执行一些退出时的操作。这个方法接受三个参数。
# 1.exc_type :异常类型。
# 2.exc_val :异常值。
# 3.exc_tb :异常的 traceback。
def __exit__(self, exc_type, exc_val, exc_tb):
# 检查当前操作系统是否不是 Windows。这是因为 Windows 系统不支持 UNIX 信号(如 SIGALRM ),所以这部分代码在 Windows 上不会执行。
if platform.system() != 'Windows':
# 使用 signal.alarm(0) 取消任何计划的 SIGALRM 信号。如果倒计时还未结束,这将阻止 SIGALRM 信号被发送。
signal.alarm(0) # Cancel SIGALRM if it's scheduled
# 这行代码检查两个条件 : self.suppress 是否为 True ,表示是否抑制超时错误。 exc_type 是否为 TimeoutError ,表示是否捕获到了超时异常。如果这两个条件都满足, __exit__ 方法返回 True ,这将抑制 TimeoutError 异常,防止它被外部捕获和处理。
if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
return True
# __exit__ 方法在 Timeout 上下文管理器退出时执行以下操作 :如果操作系统不是 Windows,取消任何计划的 SIGALRM 信号。 如果设置了抑制超时错误,并且捕获到了 TimeoutError ,则抑制这个异常。
# 这个方法确保了即使被装饰的函数或 with 代码块执行超时,程序也能以受控的方式响应,并且在适当的情况下抑制异常。这提供了一种灵活的方式来处理可能由于超时而引发的异常。
13.class WorkingDirectory(contextlib.ContextDecorator):
# 这段代码定义了一个名为 WorkingDirectory 的类,它是一个上下文装饰器,用于在代码块执行期间临时改变当前工作目录,并在代码块执行完毕后恢复原来的工作目录。
# 定义了一个名为 WorkingDirectory 的类,它继承自 contextlib.ContextDecorator 。这意味着 WorkingDirectory 可以被用作装饰器和上下文管理器。
class WorkingDirectory(contextlib.ContextDecorator):
# Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager 用法 :@WorkingDirectory(dir) 装饰器或 'with WorkingDirectory(dir):' 上下文管理器。
# 初始化方法。
# 这是 WorkingDirectory 类的初始化方法,它接受一个参数。
# 1.new_dir :即你想要切换到的新目录路径。
def __init__(self, new_dir):
# 将传入的 new_dir 保存为实例变量,表示新的工作目录。
self.dir = new_dir # new dir
# current_working_directory = Path.cwd()
# 在Python中, cwd() 函数是 pathlib 模块中的一个方法,用于获取当前工作目录。这个方法是 Path 类的一个实例方法, Path 类是 pathlib 模块中用于处理文件系统路径的类。
# 功能描述 :
# Path.cwd() 方法返回一个 Path 对象,该对象代表当前工作目录的路径。
# 返回值 :
# Path.cwd() 方法返回的是 Path 对象,这个对象提供了许多方法来操作路径,例如 .resolve() 可以获取路径的绝对路径, .as_posix() 可以将路径转换为跨平台的字符串形式等。
# 异常处理 :
# Path.cwd() 方法通常不会抛出异常,因为它只是返回当前工作目录的路径。但是,如果系统出现问题,导致无法确定当前工作目录,可能会抛出异常,这种情况下应该进行异常处理。
# 获取当前工作目录的绝对路径,并保存为实例变量 cwd 。
self.cwd = Path.cwd().resolve() # current dir
# 进入上下文管理器方法。
# 这是上下文管理器的进入方法。当使用 with 语句和 WorkingDirectory 类时, __enter__ 方法会在代码块执行前被调用。
def __enter__(self):
# 调用 os.chdir() 函数改变当前工作目录到 self.dir 。
os.chdir(self.dir)
# 退出上下文管理器方法。
# 这是上下文管理器的退出方法。当代码块执行完毕后,无论是正常结束还是由于异常, __exit__ 方法都会被调用。
def __exit__(self, exc_type, exc_val, exc_tb):
# 调用 os.chdir() 函数恢复原来的工作目录到 self.cwd 。
os.chdir(self.cwd)
# WorkingDirectory 类提供了一个方便的方式来临时改变当前工作目录,确保代码块执行完毕后目录能够被恢复,这在处理需要在特定目录下运行的文件操作时非常有用。
14.def methods(instance):
# 这段代码定义了一个名为 methods 的函数,其目的是获取一个类实例的所有可调用方法。
# 定义了一个名为 methods 的函数,它接受一个参数。
# 1.instance :这个参数应该是一个类的实例。
def methods(instance):
# Get class/instance methods 获取类/实例方法。
# dir(object=None)
# dir() 是 Python 中的一个内置函数,用于获取对象的属性列表,包括方法、类属性以及实例属性。这个函数可以不带参数调用,也可以带一个参数。
# 参数 :
# object :(可选)要检查的对象。如果未提供参数或参数为 None ,默认返回当前局部符号表中的名称。
# 返回值 :
# 返回一个包含对象属性和方法名称的列表(字符串类型)。列表中的名称以字母顺序排列。
# 用法示例 :
# 无参数调用 : print(dir()) # 无参数调用 dir(),返回当前局部作用域中的变量和函数名称。
# 带参数调用 : print(dir(example_instance)) # 使用 dir() 获取实例的所有属性和方法。
# 注意事项 :
# dir() 返回的列表中包含了对象的所有属性和方法,包括继承自父类的属性和方法。
# dir() 返回的列表中也包括了 Python 的特殊方法(如 __init__ 和 __str__ )。
# dir() 返回的列表中不包括类的私有属性和方法(以单下划线 _ 开头的属性和方法)和双重私有属性和方法(以双下划线 __ 开头的属性和方法),除非它们在 __dict__ 中明确定义。
# dir() 可以用于动态地探索和调用对象的属性和方法,特别是在交互式编程和调试中非常有用。
# callable(object)
# callable() 是 Python 中的一个内置函数,用于判断一个对象是否是可调用的。如果对象可以被调用,比如函数、方法、类或者实现了 __call__ 方法的实例,那么 callable() 函数返回 True ;否则返回 False 。
# 参数 :
# object :要检查的对象。
# 返回值 :
# 如果对象是可调用的,则返回 True 。 如果对象不是可调用的,则返回 False 。
# 注意事项 :
# callable() 函数只能用于判断对象是否可被调用,它不能用于判断对象是否可被实例化,即使类是可调用的(即它可以被实例化), callable() 也会返回 False 。
# 有些对象虽然不是函数或实现了 __call__ 方法的实例,但它们可能仍然有 __call__ 属性, callable() 会检查这个属性是否存在并且是否可调用。
# callable() 函数是一个轻量级的函数,它不尝试实际调用对象,只是检查对象是否具有可调用的性质。
# 使用列表推导式来构建一个方法名的列表。
# dir(instance) 返回实例的所有属性和方法的名称列表。对于每个名称 f , getattr(instance, f) 获取对应的属性或方法。 callable(getattr(instance, f)) 检查这个属性或方法是否是可调用的,即是否是一个函数或方法。 not f.startswith("__") 确保排除掉 Python 特殊方法(即以双下划线开始和结束的方法,如 __init__ 和 __str__ )。
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
# methods 函数提供了一个简单的方式来获取一个类实例的所有可调用方法的名称,不包括特殊方法。这对于需要动态调用或检查对象方法的场合非常有用。
15.def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# 这段代码定义了一个名为 print_args 的函数,其目的是打印函数的参数。这个函数可以显示当前函数的参数,或者如果提供了参数字典,也可以显示任意函数的参数。
# 定义了 print_args 函数,接受三个参数。
# 1.args :一个可选的字典,包含要打印的参数名和值。默认为 None 。
# 2.show_file :一个布尔值,指示是否打印文件名。默认为 True 。
# 3.show_func :一个布尔值,指示是否打印函数名。默认为 False 。
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# Print function arguments (optional args dict) 打印函数参数(可选参数字典)。
# frame = inspect.currentframe()
# inspect.currentframe() 是 Python 标准库 inspect 模块中的一个函数,它用于获取当前栈帧(stack frame)的对象。栈帧对象表示当前正在执行的函数调用的上下文,包括代码对象、局部变量、参数等信息。
# 这行代码会返回当前执行的栈帧对象。栈帧对象提供了多个方法和属性,可以用来检查函数调用的详细信息,例如 :
# f_back :返回调用当前栈帧的上一个栈帧对象,即调用者的栈帧。
# f_code :返回一个代码对象,包含了当前栈帧的代码信息,如代码行号、变量名等。
# f_locals :返回一个字典,包含了当前栈帧的局部变量。
# f_globals :返回一个字典,包含了当前栈帧的全局变量。
# inspect.currentframe() 通常用于调试和日志记录,可以帮助开发者了解代码的执行流程和上下文环境。不过,需要注意的是,频繁地使用 inspect.currentframe() 可能会对性能产生影响,因为它涉及到栈帧的检查和信息的提取。
# 获取当前函数调用的上一个帧(即调用 print_args 的函数)。
x = inspect.currentframe().f_back # previous frame
# frame_info = inspect.getframeinfo(frame, context=1)
# inspect.getframeinfo() 是 Python 标准库 inspect 模块中的一个函数,它用于获取指定栈帧的详细信息。这个函数需要一个栈帧对象作为参数,通常与 inspect.currentframe() 结合使用来获取当前的栈帧信息。
# 参数 :
# frame :一个栈帧对象,通常通过 inspect.currentframe() 获取。
# context :一个整数,表示返回的帧信息中应包含多少层上级帧的信息。默认值为 0,意味着只返回传入的帧信息。
# 返回值 :
# 该函数返回一个命名元组( FrameInfo ),其中包含了以下属性 :
# filename :栈帧中的代码文件名。
# lineno :栈帧中的当前行号。
# function :栈帧中的当前函数名。
# code_context :一个包含栈帧代码上下文的列表,每个元素是一个元组,包含行号和代码行。
# 注意事项 :
# inspect.getframeinfo() 可以用于调试和分析程序的运行状态,尤其是在需要获取函数调用堆栈时。
# 由于获取栈帧信息可能会有一定的性能开销,因此建议在调试时使用,而不是在生产环境中频繁使用。
# code_context 属性可以用于获取代码的上下文,这对于生成更详细的调试信息非常有用。
# inspect.getframeinfo() 是 Python 反射(reflection)功能的一部分,它允许程序在运行时检查和修改自身的结构和行为。
# 获取上一个帧的信息,包括文件名、函数名等。
file, _, func, _, _ = inspect.getframeinfo(x)
# 如果没有提供 args 字典,则自动获取当前函数的参数。
if args is None: # get args automatically
# arg_values = inspect.getargvalues(frame)
# inspect.getargvalues() 是 Python inspect 模块中的一个函数,用于获取指定栈帧中函数的参数信息。这个函数返回一个命名元组,其中包含了函数的参数名称、 * 和 ** 参数的名字、关键字参数的名字,以及局部变量的字典。
# 参数 :
# frame :一个栈帧对象,通常通过 inspect.currentframe() 或 inspect.stack() 获取。
# 返回值 :
# 该函数返回一个命名元组 ArgInfo ,包含以下属性 :
# args :一个列表,包含所有位置参数的名称。
# varargs :一个字符串,表示 * 形式的参数名称,或者 None 如果没有。
# keywords :一个字符串,表示 ** 形式的参数名称,或者 None 如果没有。
# locals :一个字典,包含函数调用时的局部变量。
# 注意事项 :
# inspect.getargvalues() 可以用于调试和分析程序的运行状态,尤其是在需要获取函数调用参数时。
# 由于获取栈帧信息可能会有一定的性能开销,因此建议在调试时使用,而不是在生产环境中频繁使用。
# 在 Python 3.3 及以上版本中, inspect.getargvalues() 函数因疏忽在 Python 3.5 中被错误地标记为弃用,但实际上它仍然是一个有效的函数。
# inspect.getargvalues() 是 Python 反射(reflection)功能的一部分,它允许程序在运行时检查和修改自身的结构和行为。
# 使用 inspect.getargvalues 获取参数信息。
args, _, _, frm = inspect.getargvalues(x)
# 从 frm 中筛选出实际定义的参数。
args = {k: v for k, v in frm.items() if k in args}
# 尝试将文件路径转换为相对于 ROOT 的相对路径。
try:
# 将文件路径转换为相对路径,并移除文件后缀。
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
# 如果 ROOT 不是文件路径的前缀,则只使用文件名。
except ValueError:
# 获取文件名(不含扩展名)。
file = Path(file).stem
# 构建要打印的前缀字符串,包括文件名和函数名(如果相应的标志为 True )。
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
# 使用 LOGGER 打印参数信息,参数名和值以 key=value 格式显示,并且使用 colorstr 函
# 数可能对输出进行着色。
# def colorstr(*input): -> 用于给字符串添加 ANSI 转义代码,从而在支持 ANSI 颜色代码的终端中输出彩色文本。构建并返回最终的字符串。 -> return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
# 这个函数非常有用,特别是在调试时,可以快速查看函数调用时传递的参数。通过 show_file 和 show_func 参数,可以控制输出信息的详细程度。
16.def init_seeds(seed=0, deterministic=False):
# 这段代码定义了一个名为 init_seeds 的函数,其目的是在不同的库中初始化随机种子,以确保实验的可重复性。
# 定义了一个名为 init_seeds 的函数,它接受两个参数。
# 1.seed :用于初始化随机种子的整数值,默认为 0。
# 2.deterministic :一个布尔值,指示是否启用确定性算法,默认为 False。
def init_seeds(seed=0, deterministic=False):
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html 初始化随机数生成器(RNG)种子https://pytorch.org/docs/stable/notes/randomness.html 。
# 设置随机种子。
# 为 Python 内置的 random 模块设置随机种子。
random.seed(seed)
# 为 NumPy 库设置随机种子。
np.random.seed(seed)
# 为 PyTorch 库设置随机种子。
torch.manual_seed(seed)
# 为 PyTorch 的 CUDA(GPU)设置随机种子。
torch.cuda.manual_seed(seed)
# 对于多 GPU 环境,为所有 GPU 设置随机种子。这个函数是异常安全的,意味着即使有 GPU 失败,它也会继续执行。
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 自动批处理问题 https://github.com/ultralytics/yolov5/issues/9287 。
# 启用确定性算法。
# 如果 deterministic 参数为 True 并且 PyTorch 版本至少为 1.12.0,则启用确定性算法。
# def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False): -> 用于检查当前安装的软件包版本是否满足指定的最低版本要求。 函数返回 result ,即版本检查的结果。 -> return result
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
# 启用 PyTorch 的确定性算法。
torch.use_deterministic_algorithms(True)
# 设置 PyTorch 的 cuDNN 后端为确定性模式,这有助于确保每次运行代码时,得到的结果是一致的。
torch.backends.cudnn.deterministic = True
# 设置环境变量 CUBLAS_WORKSPACE_CONFIG ,这会影响 NVIDIA 的 cuBLAS 库的行为,减少内存使用并提高确定性。
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
# 设置环境变量 PYTHONHASHSEED ,这会影响 Python 中字典的哈希值,确保字典的键值对顺序是一致的。
os.environ['PYTHONHASHSEED'] = str(seed)
# init_seeds 函数是一个实用工具,用于在不同的库中初始化随机种子,并在需要时启用确定性算法。这有助于确保实验和模型训练的可重复性,特别是在需要确保不同运行之间的结果一致性时。通过设置随机种子和确定性算法,可以减少随机性对实验结果的影响。
17.def intersect_dicts(da, db, exclude=()):
# 这段代码定义了一个名为 intersect_dicts 的函数,其目的是计算两个字典 da 和 db 中具有匹配键和形状的元素的交集,同时排除指定的键。
# 定义了一个名为 intersect_dicts 的函数,它接受三个参数。
# 1.da :第一个字典。
# 2.db :第二个字典。
# 3.exclude :一个元组,包含需要从交集中排除的键,默认为空元组。
def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values 匹配键和形状的字典交集,省略“exclude”键,使用 da 值。
# 计算交集。
# 这行代码使用字典推导式来构建一个新的字典,包含满足以下条件的键值对 :
# k in db :键 k 同时存在于 da 和 db 中。
# all(x not in k for x in exclude) :键 k 不包含在 exclude 元组中的任何元素。
# v.shape == db[k].shape : da 中的值 v 的形状与 db 中对应值的形状相同。
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
# intersect_dicts 函数提供了一种方便的方式来计算两个字典的交集,同时考虑键的存在性、排除特定键和值的形状匹配。这在处理需要比较两个数据结构中相同形状数组的场景时非常有用,例如在机器学习或数据分析中比较模型参数。
18.def get_default_args(func):
# 这段代码定义了一个名为 get_default_args 的函数,其目的是获取一个函数的默认参数值。
# 定义了一个名为 get_default_args 的函数,它接受一个参数。
# 1.func :这个参数是一个函数。
def get_default_args(func):
# Get func() default arguments 获取 func() 默认参数。
# inspect.signature(func, *, follow=True, bound=True)
# inspect.signature() 是 Python 标准库 inspect 模块中的一个函数,用于获取一个函数、方法或可调用对象的签名。函数签名包括函数的参数名称、参数类型、参数默认值以及返回类型等信息。
# 参数 :
# func :要检查的函数、方法或可调用对象。
# follow :一个布尔值,默认为 True 。如果为 True ,并且 func 是一个绑定方法,则返回该方法的 __func__ 的签名。
# bound :一个布尔值,默认为 True 。如果为 True ,则在返回的签名中包含 bound 参数。
# 返回值 :
# 返回一个 Signature 对象,该对象包含了函数的签名信息。
# Signature 对象的属性和方法 :
# Signature 对象提供了以下方法和属性 :
# parameters :一个 OrderedDict ,其键是参数名,值是 Parameter 对象,包含了参数的详细信息。
# return_annotation :返回值的注解。
# bind() :创建一个新的 Signature 对象,该对象绑定到特定的实例。
# replace() :创建一个新的 Signature 对象,该对象是当前对象的副本,但可以替换某些属性。
# 注意事项 :
# inspect.signature() 可以用于任何可调用对象,包括函数、方法、类以及其他实现了 __call__ 方法的对象。
# 使用 inspect.signature() 可以帮助动态地分析和调用函数,特别是在需要处理函数参数的复杂场景中,例如在构建装饰器、框架或测试工具时。
# 使用 inspect 模块的 signature 函数来获取 func 函数的签名。函数签名包含了函数的所有参数信息,包括参数名、参数类型、默认值等。
signature = inspect.signature(func)
# inspect.Parameter
# inspect.Parameter 是一个类,它是 Python 标准库 inspect 模块的一部分。这个类表示函数参数的详细信息,包括参数的名称、类型注解、默认值、是否接受位置参数或关键字参数等信息。
# inspect.Parameter 的构造函数不是直接公开给用户使用的,而是在 inspect 模块内部使用,当调用 inspect.signature() 函数时,返回的 Signature 对象中的 parameters 属性会包含 Parameter 实例。
# 以下是 inspect.Parameter 类的主要属性和方法 :
# 属性 :
# empty :一个特殊的类级标记,用来指定没有默认值和注解。
# name :参数的名称,是一个字符串,并且必须是一个有效的 Python 标识符。
# default :参数的默认值。如果参数没有默认值,这个属性被设置为 Parameter.empty 。
# annotation :参数的注解。如果参数没有注解,这个属性被设置为 Parameter.empty 。
# kind :描述参数如何绑定到函数参数。可能的值包括 :
# POSITIONAL_ONLY :只能作为位置参数提供。
# POSITIONAL_OR_KEYWORD :可以作为位置参数或关键字参数提供(这是 Python 函数中标准绑定行为)。
# VAR_POSITIONAL :未绑定到其他参数的一组位置参数。对应于 Python 函数定义中的 *args 参数。
# KEYWORD_ONLY :必须作为关键字参数提供。关键字参数是那些出现在 * 或 *args 条目之后的参数。
# VAR_KEYWORD :未绑定到其他参数的一组关键字参数。对应于 Python 函数定义中的 **kwargs 参数。
# 方法 :
# replace() :创建一个修改后的 Parameter 对象的副本。这个方法接受参数,允许修改 Parameter 对象的属性,如 name 、 kind 、 default 、 annotation 等,并返回一个新的 Parameter 对象。
# 注意事项 :
# inspect.Parameter 对象是不可变的。如果你需要修改一个 Parameter 对象,你应该使用 Parameter.replace() 方法来创建一个新的修改过的副本。
# Parameter 类提供了丰富的信息,可以帮助你深入了解函数参数的特性,这在动态分析和调用函数时非常有用。
# 使用字典推导式来构建一个新的字典,包含 func 函数中所有具有默认值的参数。
# signature.parameters 是一个字典,其键是参数名,值是 inspect.Parameter 对象,包含了参数的详细信息。
# k: v.default :字典推导式中的表达式, k 是参数名, v.default 是参数的默认值。
# if v.default is not inspect.Parameter.empty :这个条件过滤出那些具有默认值的参数,排除那些没有默认值的参数。 inspect.Parameter.empty 是一个特殊的单例对象,用于表示参数没有默认值。
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
# get_default_args 函数提供了一种方便的方式来获取一个函数的默认参数值。这在需要动态调用函数或检查函数签名时非常有用,尤其是在调试、测试或构建函数调用的高级工具时。
19.def get_latest_run(search_dir='.'):
# 这段代码定义了一个名为 get_latest_run 的函数,其目的是在指定目录及其子目录中查找最新的 last.pt 文件。这个文件通常用于存储训练模型的最新状态,以便后续可以从中断的地方恢复训练。
# 定义了一个名为 get_latest_run 的函数,它接受一个参数。
# 1.search_dir :表示要搜索的目录,默认为当前目录( . )。
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from) 返回 /runs 中最新的“last.pt”的路径(即 --resume from)。
# 查找最新的 last.pt 文件。
# 使用 glob 模块的 glob 函数来查找所有匹配 last*.pt 模式的文件。 recursive=True 参数表示搜索是递归的,即会搜索 search_dir 目录及其所有子目录。
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
# max(iterable, *args, key=None, default=None)
# max() 是 Python 中的一个内置函数,用于返回多个参数中的最大值。这个函数可以比较任何可比较的数据类型的值,包括数字、字符串、元组等,只要它们是可比较的。
# 参数 :
# iterable :第一个参数,可以是任何可迭代对象,如列表、元组、集合等,从中找出最大值。
# \*args :零个或多个参数,这些参数与 iterable 中的元素进行比较。
# key :(可选)一个函数,它会被用来在比较前对每个元素进行转换。这个函数接收一个参数,返回一个用于比较的值。
# default :(可选)如果 iterable 是空的,并且提供了 default ,则返回 default 值。在 Python 3.8 及以上版本中可用。
# 返回值 :
# 返回 iterable 和 args 中的最大值。
# 注意事项 :
# 如果 iterable 中的元素不可比较, max() 函数会抛出 TypeError 。
# 如果 iterable 是空的,并且没有提供 default 参数,在 Python 3.8 之前的版本中, max() 函数会抛出 ValueError 。从 Python 3.8 开始,可以指定 default 参数来处理空的可迭代对象。
# os.path.getctime(path)
# os.path.getctime() 是 Python 标准库 os.path 模块中的一个函数,它用于获取文件的创建时间。在 Unix 系统中,这通常是文件的 inode 创建时间,而在 Windows 系统中,这通常是文件的实际创建时间。
# 参数 :
# path :文件的路径,可以是相对路径或绝对路径。
# 返回值 :
# 返回指定文件的创建时间,以自纪元(Epoch,即 1970 年 1 月 1 日 00:00:00 UTC)以来的秒数表示。
# 注意事项 :
# os.path.getctime() 返回的时间是一个浮点数,在 Unix 系统中,这个时间的精度是纳秒级别的。
# 在不同的操作系统上, os.path.getctime() 的表现可能有所不同。在 Unix 系统上,它返回的是文件 inode 的创建时间,这可能不是文件的实际创建时间。在 Windows 系统上,它返回的是文件的实际创建时间。
# 如果文件不存在,调用 os.path.getctime() 将抛出 FileNotFoundError 异常。
# 如果需要更详细的时间信息,可以使用 datetime 模块将返回的时间戳转换为可读的日期和时间格式。
# 返回最新的文件路径。
# 检查 last_list 是否非空。如果找到了匹配的文件,使用 max 函数和 os.path.getctime 作为键来确定最新的文件, os.path.getctime 返回文件的创建时间。如果没有找到文件,函数返回一个空字符串。
return max(last_list, key=os.path.getctime) if last_list else ''
# get_latest_run 函数提供了一种方便的方式来查找最新的 last.pt 文件,这在训练中断后恢复模型时非常有用。通过递归搜索指定目录,这个函数可以快速定位到最新的训练状态文件。
20.def file_age(path=__file__):
# 这段代码定义了一个名为 file_age 的函数,其目的是计算自指定文件上次更新以来经过的天数。
# 定义了一个名为 file_age 的函数,它接受一个参数。
# 1.path :默认值为 __file__ ,这是一个特殊的 Python 变量,它包含了当前文件的路径。
def file_age(path=__file__):
# Return days since last file update 返回自上次文件更新以来的天数。
# 计算了文件的“年龄”,即自文件上次修改以来经过的时间。
# datetime.now() 返回当前的日期和时间。
# datetime.fromtimestamp() 将自纪元以来的秒数转换为 datetime 对象。这里使用的是 Path(path).stat().st_mtime ,它获取文件的最后修改时间(mtime),这是一个自纪元以来的秒数。
# 两个 datetime 对象相减得到一个 timedelta 对象,表示两个时间点之间的差异。
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
# 返回 timedelta 对象的 days 属性,即自文件上次更新以来经过的整天数。注释掉的部分 + dt.seconds / 86400 可以用来返回包括小数部分的天数,即包括小时、分钟和秒在内的完整天数。
return dt.days # + dt.seconds / 86400 # fractional days
# file_age 函数提供了一种方便的方式来计算文件自上次更新以来经过的天数。这对于确定文件的新旧程度或计划清理旧文件时非常有用。通过简单地调用这个函数并传递文件路径,你可以快速获取文件的年龄信息。
21.def file_date(path=__file__):
# 这段代码定义了一个名为 file_date 的函数,其目的是返回指定文件的修改日期,格式为 年-月-日 。
# 定义了 file_date 函数,接受一个参数。
# 1.path :该参数指定了文件的路径,默认为 __file__ ,即当前脚本文件的路径。
def file_date(path=__file__):
# Return human-readable file modification date, i.e. '2021-3-26' 返回人类可读的文件修改日期,例如“2021-3-26”。
# Path(path) 创建一个 pathlib.Path 对象,用于处理文件路径。
# .stat() 获取文件的状态信息,类似于 Unix 的 stat 系统调用。
# .st_mtime 获取文件的最后修改时间,以 Unix 时间戳(自 1970 年 1 月 1 日以来的秒数)表示。
# datetime.fromtimestamp() 将 Unix 时间戳转换为 datetime 对象。
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
# 返回格式化的日期字符串,格式为 年-月-日 。
# datetime 对象属性 : year 年份。 month 月份(1-12)。 day 月份中的日(1-31)。
return f'{t.year}-{t.month}-{t.day}'
# 这个函数非常有用,特别是在需要记录文件最后修改时间的场景,例如日志记录、备份操作或者版本控制。通过这个函数,你可以轻松地获取并显示文件的修改日期,以便于追踪文件的更新历史。
22.def file_size(path):
# 这段代码定义了一个名为 file_size 的函数,用于计算指定文件或目录的大小,以兆字节(MB)为单位。
# 定义了一个名为 file_size 的函数,它接受一个参数。
# 1.path :表示要计算大小的文件或目录的路径。
def file_size(path):
# Return file/dir size (MB) 返回文件/目录大小(MB)。
# 定义了 mb 变量,用于将字节转换为兆字节(Mebibyte,1 MiB = 1024 * 1024 字节)。这里使用的是二进制前缀(MiB),而不是十进制前缀(MB)。
mb = 1 << 20 # bytes to MiB (1024 ** 2)
# 将 path 参数转换为 Path 对象,以便使用 pathlib 模块提供的方法。
path = Path(path)
# 如果 path 是一个文件,使用 path.stat() 获取文件的状态信息,其中 st_size 属性表示文件大小(字节)。然后将其转换为兆字节并返回。
if path.is_file():
return path.stat().st_size / mb
# 如果 path 是一个目录,使用 path.glob('**/*') 递归地获取目录中所有文件的路径。然后,对每个文件使用 f.stat().st_size 获取其大小,并累加这些大小。最后,将总大小转换为兆字节并返回。
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
# 如果 path 既不是文件也不是目录,函数返回 0.0 。
else:
return 0.0
# file_size 函数提供了一种方便的方式来计算文件或目录的大小。它递归地计算目录中所有文件的大小,对于单个文件直接返回其大小。这个函数可以用于各种场景,比如监控磁盘使用情况、清理大文件等。
23.def check_online():
# 这段代码定义了一个名为 check_online 的函数,其目的是检查设备是否连接到互联网。这个函数尝试连接到一个外部服务器来验证网络连接是否可用。
# 定义了一个名为 check_online 的函数,它不接受任何参数。
def check_online():
# Check internet connectivity 检查互联网连接。
# 导入 Python 的 socket 模块,这个模块提供了访问套接字的方法,可以用于网络通信。
import socket
# 在 check_online 函数内部定义了一个名为 run_once 的嵌套函数,用于执行单次的网络连接检查。
# 这段代码定义了一个名为 run_once 的函数,其目的是尝试检查互联网连接是否可用。这个函数尝试连接到一个外部服务器来验证网络连接是否正常。
# 定义了一个名为 run_once 的函数,不接受任何参数。
def run_once():
# Check once
# 检查互联网连接。
# 使用 try - except 语句来执行一次网络连接检查。
try:
# try 块中, socket.create_connection(("1.1.1.1", 443), 5) 尝试创建一个到指定 IP 地址 "1.1.1.1" 和端口 443 的网络连接。这个 IP 地址是 Cloudflare 提供的一个公共 DNS 服务器地址,端口 443 是 HTTPS 服务的标准端口。连接尝试的超时时间设置为 5 秒。
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility 检查主机可访问性。
# 如果连接成功,函数返回 True ,表示当前设备可以访问互联网。
return True
# except 块中,如果 socket.create_connection 抛出 OSError 异常(可能是因为网络不可达、连接超时或其他网络错误),函数捕获这个异常并返回 False ,表示当前设备无法访问互联网。
except OSError:
return False
# run_once 函数提供了一种简单的方式来尝试检查设备的互联网连接状态。通过尝试连接到一个外部服务器,这个函数可以帮助确定设备是否能够访问互联网。这对于需要网络连接的应用程序来说非常有用,比如在执行网络操作之前验证网络状态。
# check_online 函数调用 run_once 两次,并使用逻辑或操作符 or 来组合结果。如果至少有一次连接成功, check_online 函数就返回 True ,表示设备连接到互联网。这种方法增加了对间歇性网络连接问题的鲁棒性。
return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues 检查两次以增强对间歇性连接问题的稳健性。
# check_online 函数提供了一种简单的方式来检查设备的互联网连接状态。通过尝试连接到一个外部服务器,并执行两次检查以增加鲁棒性,这个函数可以帮助确定设备是否能够访问互联网。这对于需要网络连接的应用程序来说非常有用,比如在执行网络操作之前验证网络状态。
24.def git_describe(path=ROOT):
# 这段代码定义了一个名为 git_describe 的函数,其目的是获取 Git 仓库的描述信息。这个描述信息通常包含了最近的标签(tag)和自那个标签以来的提交次数,以及当前的提交哈希。
# 定义了 git_describe 函数,接受一个参数。
# 1.path :该参数指定了 Git 仓库的路径,默认为 ROOT (通常是一个全局变量,指向项目的根目录)。
def git_describe(path=ROOT): # path must be a directory
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe 返回人类可读的 git 描述,即 v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe。
# 开始一个 try-except 块,用于捕获并处理可能发生的异常。
try:
# 断言指定路径下的 .git 目录存在,确保给定的路径是一个 Git 仓库。
assert (Path(path) / '.git').is_dir()
# subprocess.check_output(cmd, *args, **kwargs)
# check_output() 函数是 Python 标准库 subprocess 模块中的一个函数,它用于执行指定的命令并获取命令的输出。如果命令执行成功, check_output() 会返回命令的输出;如果命令执行失败(即返回非零退出状态),则会抛出一个 CalledProcessError 异常。
# 参数说明 :
# cmd :要执行的命令,可以是字符串或者字符串列表。如果是字符串,会被 shell 解释,这与 shell=True 相同;如果是字符串列表,则直接传递给底层的 execvp() 函数。
# *args :传递给 subprocess.Popen() 的其他参数。
# **kwargs :传递给 subprocess.Popen() 的其他关键字参数。常用的关键字参数包括 :
# shell :如果为 True ,则 cmd 会被 shell 解释。默认为 False 。
# stdout :子进程的 stdout 管道。默认为 subprocess.PIPE ,即捕获输出。
# stderr :子进程的 stderr 管道。默认为 subprocess.PIPE ,即捕获错误输出。
# universal_newlines 或 text :如果设置为 True ,则 check_output() 返回一个字符串而不是字节对象。在 Python 3.7 及更高版本中, text 参数被引入, universal_newlines 被废弃。
# 返回值 :
# 返回执行命令后的标准输出(stdout)。
# 异常 :
# 如果命令返回非零退出状态, check_output() 会抛出 subprocess.CalledProcessError 异常。
# 使用 check_output 函数执行 Git 命令,获取仓库的描述信息。
# shell=True 允许命令在 shell 中执行。 .decode() 将命令输出的字节串解码为字符串。 [:-1] 去除字符串末尾的换行符。
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
# 如果在执行上述命令时发生任何异常(例如,路径不是 Git 仓库,或者 Git 命令失败),则捕获异常。
except Exception:
# 如果发生异常,则返回一个空字符串。
return ''
# 这个函数非常有用,特别是在需要自动化构建和版本控制的项目中,可以快速获取当前代码的状态和版本信息。通过这个函数,你可以在日志、错误报告或者构建的文档中包含 Git 描述信息,以便于追踪和记录。
25.def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
# 这段代码定义了一个名为 check_git_status 的函数,它被两个装饰器 @TryExcept() 和 @WorkingDirectory(ROOT) 修饰。这个函数的目的是检查当前 Git 仓库的状态,如果代码过时,则推荐执行 git pull 来更新代码。
# 函数使用 @TryExcept() 装饰器来捕获并处理可能发生的异常。
# class TryExcept(contextlib.ContextDecorator):
# -> 这个上下文管理器的作用是在代码块执行过程中捕获异常,并在异常发生时打印一条包含自定义消息和异常值的消息。
# -> def __init__(self, msg=''):
@TryExcept()
# 使用 @WorkingDirectory(ROOT) 装饰器来确保函数执行时在指定的根目录 ROOT 下。
# class WorkingDirectory(contextlib.ContextDecorator):
# -> 它是一个上下文装饰器,用于在代码块执行期间临时改变当前工作目录,并在代码块执行完毕后恢复原来的工作目录。
# -> def __init__(self, new_dir):
@WorkingDirectory(ROOT)
# 函数 check_git_status 接受两个参数。
# 1.repo 和 2.branch :分别表示 GitHub 上的 仓库 和 分支 ,默认值分别为 'WongKinYiu/yolov9' 和 'main' 。
def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
# YOLO status check, recommend 'git pull' if code is out of date
# 这段代码是 check_git_status 函数的一部分,它执行了几个关键步骤来准备和验证 Git 仓库的状态检查。
# 构建 GitHub 仓库 URL。使用 f-string(格式化字符串字面量)构建了当前要检查的 GitHub 仓库的 URL。变量 repo 是一个包含用户名和仓库名的字符串,例如 'WongKinYiu/yolov9' 。
url = f'https://github.com/{repo}'
# 构建更新信息消息。创建了一个消息 msg ,它包含了一个链接到 GitHub 仓库的 URL,以便用户可以查看更新。
msg = f', for updates see {url}' # ,有关更新,请参阅 {url}。
# 准备状态消息。调用 colorstr 函数,它用于给字符串添加颜色。这里它被用来创建一个带有前缀 'github: ' 的字符串 s ,这个字符串将用于构建后续的状态消息。
# def colorstr(*input): -> 用于给字符串添加 ANSI 转义代码,从而在支持 ANSI 颜色代码的终端中输出彩色文本。构建并返回最终的字符串。 -> return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
s = colorstr('github: ') # string
# 验证当前目录是 Git 仓库。使用 assert 语句来验证当前目录是否包含 .git 子目录,这是 Git 仓库的标志。如果 .git 目录不存在,断言失败,并显示消息 s 加上 'skipping check (not a git repository)' 和 msg 。
assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg # 跳过检查(不是 git 存储库)。
# 验证设备是否在线。使用 assert 语句来验证设备是否在线,通过调用 check_online 函数。如果设备不在线,断言失败,并显示消息 s 加上 'skipping check (offline)' 和 msg 。
# def check_online(): -> 检查设备是否连接到互联网。调用 run_once 两次,并使用逻辑或操作符 or 来组合结果。如果至少有一次连接成功, check_online 函数就返回 True ,表示设备连接到互联网。 -> return run_once() or run_once()
assert check_online(), s + 'skipping check (offline)' + msg # 跳过检查(离线)。
# 这些代码行共同确保了在执行 Git 状态检查之前,当前目录是一个 Git 仓库,并且设备可以访问互联网。如果这些条件不满足,将显示适当的消息,并根据断言失败提前退出函数。这是一种常见的错误处理和预检查方法,用于在执行更复杂的操作之前确保所有先决条件都已满足。
# 这段代码是 check_git_status 函数的一部分,用于检查当前 Git 仓库的远程分支状态,并确定本地分支是否是最新的。
# subprocess.check_output(cmd, *args, **kwargs)
# check_output() 函数是 Python 标准库 subprocess 模块中的一个函数,它用于执行指定的命令并获取命令的输出。如果命令执行成功, check_output() 会返回命令的输出;如果命令执行失败(即返回非零退出状态),则会抛出一个 CalledProcessError 异常。
# 参数说明 :
# cmd :要执行的命令,可以是字符串或者字符串列表。如果是字符串,会被 shell 解释,这与 shell=True 相同;如果是字符串列表,则直接传递给底层的 execvp() 函数。
# *args :传递给 subprocess.Popen() 的其他参数。
# **kwargs :传递给 subprocess.Popen() 的其他关键字参数。常用的关键字参数包括 :
# shell :如果为 True ,则 cmd 会被 shell 解释。默认为 False 。
# stdout :子进程的 stdout 管道。默认为 subprocess.PIPE ,即捕获输出。
# stderr :子进程的 stderr 管道。默认为 subprocess.PIPE ,即捕获错误输出。
# universal_newlines 或 text :如果设置为 True ,则 check_output() 返回一个字符串而不是字节对象。在 Python 3.7 及更高版本中, text 参数被引入, universal_newlines 被废弃。
# 返回值 :
# 返回执行命令后的标准输出(stdout)。
# 异常 :
# 如果命令返回非零退出状态, check_output() 会抛出 subprocess.CalledProcessError 异常。
# 分割远程仓库信息。使用 subprocess.check_output 执行 git remote -v 命令,获取当前 Git 仓库的所有远程仓库信息。 shell=True 参数允许直接在 shell 中执行命令。然后,使用 re.split 按空格分割输出字符串,得到一个包含远程仓库信息的列表 splits 。
splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
# 检查特定的远程仓库是否存在。
# 检查列表 splits 中是否包含指定的 repo 。
matches = [repo in s for s in splits]
if any(matches):
# 如果找到了匹配的远程仓库, remote 变量将被设置为该仓库的名称。
remote = splits[matches.index(True) - 1]
else:
# 如果没有找到, remote 被设置为 'ultralytics' ,并且使用 check_output 执行 git remote add 命令来添加一个新的远程仓库。
remote = 'ultralytics'
check_output(f'git remote add {remote} {url}', shell=True)
# 更新远程仓库信息。执行 git fetch 命令来更新远程仓库 remote 的信息, timeout=5 参数设置超时时间为 5 秒。
check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
# 获取当前检出的分支名称。执行 git rev-parse --abbrev-ref HEAD 命令来获取当前检出的分支名称,并去除两端的空白字符。
local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
# 计算落后的提交数。执行 git rev-list 命令来计算本地分支 local_branch 相对于远程分支 remote}/{branch} 落后的提交数,并将结果转换为整数。
n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
# 这些代码行共同完成了以下任务 :获取远程仓库信息,并检查是否存在指定的远程仓库。 如果不存在指定的远程仓库,则添加它。 更新远程仓库的信息。 获取当前检出的分支名称。 计算本地分支相对于远程分支落后的提交数。这些步骤为后续的代码提供了必要的信息,以便确定是否需要执行 git pull 来更新代码。
# 这段代码是 check_git_status 函数的最后部分,用于根据本地分支与远程分支的同步状态来构建消息,并记录日志。
# 检查落后的提交数。检查变量 n (本地分支落后远程分支的提交数)是否大于0。如果大于0,表示本地代码库不是最新的。
if n > 0:
# 构建更新提示消息。
# 根据远程仓库的名称来决定使用哪个 git pull 命令。如果远程仓库名称是 origin (Git默认的远程仓库名称),则使用简单的 git pull 命令。否则,使用指定远程仓库和分支的 git pull 命令。
pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
# 将一个警告消息添加到字符串 s 。消息中包含了落后的提交数 n ,以及如何更新代码的指令。如果 n 大于1,则在“commit”后面加上“s”来正确地显示复数形式。
s += f"⚠️ YOLO is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update." # ⚠️ YOLO 已过期 {n} 次提交{'s' * (n > 1)}。使用 `{pull}` 或 `git clone {url}` 进行更新。
# 构建已同步消息。
# 如果 n 不大于0,表示本地代码库是最新的,代码将添加一个表示已同步的消息到字符串 s 。
else:
s += f'up to date with {url} ✅' # 了解 {url} 的最新动态✅。
# 记录日志。使用 LOGGER 对象的 info 方法记录信息级别的日志,日志内容为字符串 s 。 LOGGER 是一个配置好的日志记录器实例。
LOGGER.info(s)
# 这段代码通过比较本地分支和远程分支的提交数来决定是否需要更新代码,并构建相应的消息。如果代码过时,它提供了如何更新的指令;如果代码是最新的,它确认了代码的同步状态。这些信息被记录在日志中,以便于开发者了解代码库的状态。这种方法有助于自动化地管理和维护代码库,确保使用的是最新的代码。
# check_git_status 函数提供了一种自动化的方式来检查当前 Git 仓库的状态,并在代码过时时提供更新建议。这个函数可以集成到更大的系统中,以确保使用的代码始终是最新的。通过装饰器 @TryExcept() 和 @WorkingDirectory(ROOT) ,函数还增加了错误处理和目录管理的功能。
26.def check_git_info(path='.'):
# 这段代码定义了一个名为 check_git_info 的函数,用于检查当前 Git 仓库的信息,包括远程仓库 URL、当前分支名称和最新的提交哈希值。这个函数被 @WorkingDirectory(ROOT) 装饰器修饰,确保在指定的根目录 ROOT 下执行。
# 函数使用 @WorkingDirectory(ROOT) 装饰器,确保在执行时切换到 ROOT 目录。
# class WorkingDirectory(contextlib.ContextDecorator):
# -> 它是一个上下文装饰器,用于在代码块执行期间临时改变当前工作目录,并在代码块执行完毕后恢复原来的工作目录。
# -> def __init__(self, new_dir):
@WorkingDirectory(ROOT)
# 函数 check_git_info 接受一个参数。
# 1.path :表示要检查的目录路径,默认为当前目录( . )。
def check_git_info(path='.'):
# YOLO git info check, return {remote, branch, commit} YOLO git info 检查,返回 {remote, branch, commit} 。
# 检查 Git 仓库信息。
# 调用 check_requirements 函数,用于确保 gitpython 包已经安装。 gitpython 是一个 Python 包,提供了对 Git 仓库的访问和操作接口。
check_requirements('gitpython')
# 导入 git 模块,这是 gitpython 包的一部分,用于与 Git 仓库交互。
import git
# 尝试创建一个 git.Repo 对象,它表示指定路径下的 Git 仓库。
try:
repo = git.Repo(path)
# 如果成功创建 git.Repo 对象,获取远程仓库 origin 的 URL,并移除末尾的 .git 。
remote = repo.remotes.origin.url.replace('.git', '') # i.e. 'https://github.com/WongKinYiu/yolov9'
# 获取当前 HEAD 提交的哈希值。
commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
# 尝试获取当前活动的分支名称。如果当前 HEAD 处于分离状态(即不是在任何分支上),则会捕获 TypeError 异常,并设置 branch 为 None 。
try:
branch = repo.active_branch.name # i.e. 'main'
except TypeError: # not on any branch
branch = None # i.e. 'detached HEAD' state
# 返回一个包含 远程仓库 URL 、 当前分支名称 和 提交哈希值 的字典。
return {'remote': remote, 'branch': branch, 'commit': commit}
# 如果路径不是一个 Git 仓库目录,会捕获 InvalidGitRepositoryError 异常。
except git.exc.InvalidGitRepositoryError: # path is not a git dir
# 返回一个所有值为 None 的字典。
return {'remote': None, 'branch': None, 'commit': None}
# check_git_info 函数提供了一种方便的方式来获取当前 Git 仓库的关键信息。这个函数可以用于确保代码运行在一个有效的 Git 仓库中,并且可以获取仓库的状态信息,这对于版本控制和同步非常有用。通过装饰器 @WorkingDirectory(ROOT) ,函数还确保了在指定的根目录下执行,增加了代码的灵活性和可移植性。
27.def check_python(minimum='3.7.0'):
# 这段代码定义了一个名为 check_python 的函数,其目的是检查当前运行的 Python 版本是否满足指定的最低版本要求。
# 定义了 check_python 函数,它接受一个参数。
# 1.minimum :该参数指定了所需的最低 Python 版本,默认值为 '3.7.0' 。
def check_python(minimum='3.7.0'):
# Check current python version vs. required python version
# 调用了 check_version 函数,传入了四个参数 :
# platform.python_version() :这是一个函数调用,返回当前运行的 Python 版本字符串。
# minimum :这是函数 check_python 的参数,表示所需的最低 Python 版本。
# name='Python ' :这是一个字符串,表示在日志或错误消息中使用的名称,这里指 Python 。
# hard=True :这是一个布尔值参数,当设置为 True 时,如果当前版本不满足最低版本要求,函数将抛出异常。
check_version(platform.python_version(), minimum, name='Python ', hard=True)
28.def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
# 这段代码定义了一个名为 check_version 的函数,它用于检查当前安装的软件包版本是否满足指定的最低版本要求。
# 函数定义。
# 1.current : 当前安装的版本,默认为 '0.0.0' 。
# 2.minimum : 所需的最低版本,默认为 '0.0.0' 。
# 3.name : 版本名称的前缀,默认为 'version ' 。
# 4.pinned : 是否严格匹配版本,即不允许任何偏离,默认为 False 。
# 5.hard : 是否在版本不满足要求时抛出异常,默认为 False 。
# 6.verbose : 是否在版本不满足要求时打印警告信息,默认为 False 。
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
# Check version vs. required version
# pkg_resources.parse_version(version_string)
# pkg.parse_version() 是 Python pkg_resources 模块中的一个函数,用于解析版本号字符串,并将其转换为可以比较的 Version 对象。这个函数非常适用于版本控制和管理,因为它能够识别并正确比较复杂的版本号,包括预发布版本和后发布版本。
# 参数 :
# version_string :一个表示版本号的字符串。
# 返回值 :
# 返回一个 Version 对象,该对象包含了版本号的各个部分,如主版本号、次版本号、修订版本号等,并且可以用于版本比较。
# 使用列表推导式和 pkg.parse_version 函数来解析 current 和 minimum 参数,将它们转换为可以比较的版本对象。
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
# 根据 pinned 参数的值来决定是比较当前版本和最小版本是否相等,还是检查当前版本是否大于或等于最小版本。
result = (current == minimum) if pinned else (current >= minimum) # bool
# 构造一个警告信息字符串,指出所需的最低版本和当前安装的版本。
s = f'WARNING ⚠️ {name}{minimum} is required by YOLO, but {name}{current} is currently installed' # string 警告 ⚠️ YOLO 需要 {name}{minimum},但当前已安装 {name}{current}。
# 如果 hard 参数为 True 。
if hard:
# 则使用 assert 语句来确保 result 为 True ,如果不是,则抛出异常,异常信息是 emojis(s) , emojis 是一个函数,能够将字符串转换为带有表情符号的警告信息。
# def emojis(str=''): -> 返回一个在特定平台(特别是Windows)上安全的、不包含emoji的字符串版本。 -> return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
assert result, emojis(s) # assert min requirements met
# 如果 verbose 参数为 True 且 result 为 False ,则使用日志记录器 LOGGER 来记录警告信息。
if verbose and not result:
LOGGER.warning(s)
# 函数返回 result ,即版本检查的结果。
return result
# check_version 函数用于检查当前版本是否满足最低要求,并且可以根据参数的不同,以不同的方式处理版本不符合要求的情况。
29.def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
# 这段代码是一个名为 check_requirements 的Python函数,它用于检查是否安装了满足YOLO要求的依赖项。如果某些依赖项未安装或版本不兼容,函数会尝试自动安装它们。
# @TryExcept() :这是一个装饰器,用于捕获函数执行过程中的异常。
# class TryExcept(contextlib.ContextDecorator):
# -> 这个上下文管理器的作用是在代码块执行过程中捕获异常,并在异常发生时打印一条包含自定义消息和异常值的消息。
# -> def __init__(self, msg=''):
@TryExcept()
# 定义函数 check_requirements ,接受四个参数。
# 1.requirements :依赖文件或依赖列表,默认为 ROOT 目录下的 requirements.txt 文件。
# 2.exclude :需要排除的依赖项列表,默认为空。
# 3.install :是否自动安装未满足的依赖项,默认为 True 。
# 4.cmds :安装依赖时附加的命令行参数,默认为空字符串。
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
# Check installed dependencies meet YOLO requirements (pass *.txt file or list of packages or single package str) 检查已安装的依赖项是否满足 YOLO 要求(传递 *.txt 文件或包列表或单个包 str)。
# 定义一个前缀字符串,用于日志输出,颜色为红色,字体加粗。
prefix = colorstr('red', 'bold', 'requirements:')
# 检查Python版本是否符合要求。
# def check_python(minimum='3.7.0'): -> 检查当前运行的 Python 版本是否满足指定的最低版本要求。
check_python() # check python version
# 如果 requirements 是 Path 对象(即文件路径)。
if isinstance(requirements, Path): # requirements.txt file
# Path.resolve(strict=False)
# resolve() 函数是 Python pathlib 模块中 Path 类的一个方法,用于将路径对象转换为绝对路径,并解析路径中的所有符号链接(symlinks),返回一个新的路径对象。
# 参数 :
# strict :布尔值,默认为 False 。如果设置为 True ,则在路径不存在时引发 FileNotFoundError 。如果为 False ,则尽可能解析路径并附加任何剩余部分,而不检查它是否存在。
# 返回值 :
# 返回一个新的 Path 对象,该对象表示原始路径的绝对路径,并且已经解析了所有符号链接。
# 功能 :
# resolve() 方法将路径对象转换为绝对路径,并解析路径中的所有符号链接。
# 它还会处理路径中的所有特殊符号,例如 . (当前目录)和 .. (上级目录)。
# 如果路径不存在且 strict 参数为 True ,则引发 FileNotFoundError 。
# 如果 strict 参数为 False ,则尽可能解析路径并附加任何剩余部分,而不检查它是否存在。
# 注意事项 :
# resolve() 方法总是返回一个绝对路径,无论传入的路径是相对路径还是绝对路径。
# 如果路径是相对路径,则会根据当前工作目录将其解析为绝对路径。
# resolve() 方法会对路径进行标准化处理,消除冗余的分隔符、处理上级目录符号(..)和当前目录符号(.),以保证返回的路径是规范化的。
# 解析文件路径。
file = requirements.resolve()
# 确保文件存在。
assert file.exists(), f"{prefix} {file} not found, check failed." # 未找到 {prefix} {file},检查失败。
# 打开文件。
with file.open() as f:
# pkg.parse_requirements(str, session=None, name=None, requirements=None)
# pkg_resources.parse_requirements() 函数是 Python setuptools 包中的一个函数,它用于解析 requirements.txt 文件或其他字符串形式的依赖要求。这个函数将字符串形式的依赖要求解析为 pkg_resources.Requirement 对象列表,这些对象包含了依赖包的名称和版本信息。
# 参数说明 :
# str :可以是一个包含依赖要求的字符串,或者是一个文件对象,该文件对象包含 requirements.txt 文件的内容。
# session :(可选)一个 pkg_resources.WorkingSet 对象,或者任何具有 fetch 方法的对象。这个参数用于处理需要下载的依赖。
# name :(可选)正在构建的包的名称,如果适用的话。
# requirements :(可选)一个 Requirement 对象列表,新解析的依赖要求将被添加到这个列表中。
# 返回值 :
# 返回一个 Requirement 对象列表,每个对象代表一个解析出的依赖要求。
# 解析文件中的依赖项,排除指定的依赖项。
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
# 如果 requirements 是字符串。
elif isinstance(requirements, str):
# 将字符串转换为列表。
requirements = [requirements]
# 初始化一个空字符串,用于存储未满足的依赖项。
s = ''
# 初始化一个计数器,用于记录未满足的依赖项数量。
n = 0
# 遍历依赖项列表。
for r in requirements:
# 尝试检查依赖项是否已安装。
try:
# 检查依赖项是否满足。
pkg.require(r)
# 捕获版本冲突或未找到分发包的异常。
except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
# 将未满足的依赖项添加到字符串 s 中。
s += f'"{r}" '
# 增加未满足的依赖项计数。
n += 1
# 如果存在未满足的依赖项,且 install 为 True ,且环境变量 AUTOINSTALL 为 True 。
if s and install and AUTOINSTALL: # check environment variable
# 记录日志,尝试自动更新。
LOGGER.info(f"{prefix} YOLO requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...") # 未找到 {prefix} YOLO 要求{'s' * (n > 1)} {s},正在尝试自动更新...
# 尝试自动安装依赖项。
try:
# assert check_online(), "AutoUpdate skipped (offline)"
# subprocess.check_output(cmd, *args, **kwargs)
# check_output() 函数是 Python 标准库 subprocess 模块中的一个函数,它用于执行指定的命令并获取命令的输出。如果命令执行成功, check_output() 会返回命令的输出;如果命令执行失败(即返回非零退出状态),则会抛出一个 CalledProcessError 异常。
# 参数说明 :
# cmd :要执行的命令,可以是字符串或者字符串列表。如果是字符串,会被 shell 解释,这与 shell=True 相同;如果是字符串列表,则直接传递给底层的 execvp() 函数。
# *args :传递给 subprocess.Popen() 的其他参数。
# **kwargs :传递给 subprocess.Popen() 的其他关键字参数。常用的关键字参数包括 :
# shell :如果为 True ,则 cmd 会被 shell 解释。默认为 False 。
# stdout :子进程的 stdout 管道。默认为 subprocess.PIPE ,即捕获输出。
# stderr :子进程的 stderr 管道。默认为 subprocess.PIPE ,即捕获错误输出。
# universal_newlines 或 text :如果设置为 True ,则 check_output() 返回一个字符串而不是字节对象。在 Python 3.7 及更高版本中, text 参数被引入, universal_newlines 被废弃。
# 返回值 :
# 返回执行命令后的标准输出(stdout)。
# 异常 :
# 如果命令返回非零退出状态, check_output() 会抛出 subprocess.CalledProcessError 异常。
# 执行 pip install 命令并记录输出。
LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
# 确定依赖项来源。
source = file if 'file' in locals() else requirements
# 记录更新成功的日志信息。
# def colorstr(*input): -> 用于给字符串添加 ANSI 转义代码,从而在支持 ANSI 颜色代码的终端中输出彩色文本。构建并返回最终的字符串。 -> return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" # {prefix} {n} package{'s' * (n > 1)} 根据 {source} {prefix} 更新 ⚠️ {colorstr('bold', '重新启动运行时或重新运行命令以使更新生效')}。
# 输出日志信息。
LOGGER.info(s)
# 捕获其他异常。
except Exception as e:
# 记录警告日志。
LOGGER.warning(f'{prefix} ❌ {e}')
# 这个函数的主要功能是检查和安装YOLO所需的依赖项,确保环境配置正确。它使用 pkg_resources 模块来解析和检查依赖项,使用 subprocess 模块来执行 pip install 命令。此外,它还使用 colorstr 函数来美化日志输出。
30.def check_img_size(imgsz, s=32, floor=0):
# 这段代码定义了一个名为 check_img_size 的函数,其目的是验证图像尺寸是否是某个步长的倍数,这在计算机视觉任务中,尤其是在目标检测模型(如YOLO系列)中,是一个常见的要求。
# 定义了一个名为 check_img_size 的函数,它接受三个参数。
# 1.imgsz :图像的尺寸,可以是一个整数或一个列表/元组,表示图像的宽度和高度。
# 2.s :步长,默认为32,这是模型中用于确定特征图尺寸的参数。
# 3.floor :地板值,默认为0,用于确保尺寸不会低于这个值。
def check_img_size(imgsz, s=32, floor=0):
# Verify image size is a multiple of stride s in each dimension 验证图像大小是否是每个维度的 stride 的倍数。
# 验证图像尺寸。
# 如果 imgsz 是一个整数,表示图像尺寸是一个值(例如640),将计算一个新尺寸,这个新尺寸是步长 s 的倍数,并且不小于 floor 值。 make_divisible 函数用于确保给定的尺寸是步长的倍数。
if isinstance(imgsz, int): # integer i.e. img_size=640
# def make_divisible(x, divisor): -> 将输入值 x 调整到最接近的、可以被 divisor 整除的数。计算 x 除以 divisor 的结果,并使用 math.ceil 函数向上取整到最近的整数。 -> return math.ceil(x / divisor) * divisor
new_size = max(make_divisible(imgsz, int(s)), floor)
# 如果 imgsz 是一个列表或元组,表示图像尺寸是宽度和高度的一对值,将计算新的宽度和高度,确保它们都是步长 s 的倍数,并且不小于 floor 值。
else: # list i.e. img_size=[640, 480]
imgsz = list(imgsz) # convert to list if tuple
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
# 日志警告。
# 如果新尺寸 new_size 与原始尺寸 imgsz 不同,将记录一个警告日志,提示图像尺寸需要是最大步长 s 的倍数,并显示更新后的尺寸。
if new_size != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}') # 警告 ⚠️ --img-size {imgsz} 必须是最大步幅 {s} 的倍数,更新为 {new_size}。
# 返回新尺寸。函数返回新计算的尺寸。
return new_size
# check_img_size 函数确保图像尺寸符合模型的要求,即图像尺寸必须是步长的倍数。这对于确保模型能够正确处理输入图像非常重要。如果提供的尺寸不符合要求,函数将调整尺寸并记录警告日志。这个函数可以用于在图像处理或模型训练之前验证和调整图像尺寸。
31.def check_imshow(warn=False):
# 这段代码定义了一个名为 check_imshow 的函数,其目的是检查当前环境是否支持图像显示,特别是在使用 OpenCV 的 cv2.imshow() 函数时。这个检查在某些情况下非常有用,比如在非交互式环境(如某些服务器或 Docker 容器)中,图像显示可能不被支持。
# 定义了一个名为 check_imshow 的函数,它接受一个参数。
# 1.warn :这是一个布尔值,指示是否在环境不支持图像显示时记录警告日志,默认为 False 。
def check_imshow(warn=False):
# Check if environment supports image displays 检查环境是否支持图像显示。
# 开始了一个 try 块,用于尝试执行后面的代码,并捕获可能发生的任何异常。
try:
# 使用 assert 语句来确保当前环境不是 Jupyter 笔记本(通过调用 is_notebook() 函数)和 Docker 容器(通过调用 is_docker() 函数)。如果任何一个条件为真,将引发 AssertionError 异常。
assert not is_notebook()
assert not is_docker()
# 尝试使用 OpenCV 的 cv2.imshow() 函数显示一个测试图像。这里创建了一个尺寸为 1x1 的黑色图像(使用 np.zeros((1, 1, 3)) )。
cv2.imshow('test', np.zeros((1, 1, 3)))
# cv2.waitKey(delay)
# cv2.waitKey() 是 OpenCV 库中的一个函数,它用于等待特定的毫秒数,在这个时间段内,它允许其他窗口消息被处理,特别是在创建窗口和处理键盘或鼠标事件时非常有用。
# 参数 :
# delay :等待的毫秒数。如果设置为 0,则 cv2.waitKey() 会无限期等待用户的按键输入;如果设置为负数,则不会等待用户的任何输入,直接返回。
# 返回值 :
# 返回值是等待期间按下的键的 ASCII 码。如果没有任何按键被按下,返回 -1 。
# 注意事项 :
# cv2.waitKey() 在不同的操作系统上的行为可能略有不同。
# 在某些情况下, cv2.waitKey() 可能需要与其他 OpenCV 函数(如 cv2.setMouseCallback() )一起使用,以处理更复杂的交互。
# 当使用 cv2.waitKey() 时,确保在程序结束前调用 cv2.destroyAllWindows() 来关闭所有 OpenCV 创建的窗口,以避免资源泄露。
# 调用 OpenCV 的 cv2.waitKey() 函数,等待 1 毫秒,以便操作系统处理任何挂起的事件,如窗口更新。
cv2.waitKey(1)
# 调用 OpenCV 的 cv2.destroyAllWindows() 函数,关闭所有 OpenCV 创建的窗口。
cv2.destroyAllWindows()
# 再次调用 cv2.waitKey() 函数,确保所有窗口都已正确关闭。
cv2.waitKey(1)
# 如果以上代码都成功执行,函数返回 True ,表示环境支持图像显示。
return True
# 异常处理。如果 try 块中的代码抛出任何异常, except 块将捕获这个异常。
except Exception as e:
# 如果 warn 参数为 True ,则记录一条警告日志,说明环境不支持 cv2.imshow() 或 PIL Image.show() ,并显示异常信息。
if warn:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}') # 警告⚠️ 环境不支持 cv2.imshow() 或 PIL Image.show()\n{e} 。
# 在异常情况下,函数返回 False ,表示环境不支持图像显示。
return False
# check_imshow 函数用于检测当前环境是否能够支持图像显示功能。如果环境不支持,函数将返回 False 并(如果 warn 为 True )记录一条警告日志。这个函数对于确定是否能够在当前环境中显示图像非常有用,特别是在自动化脚本或非图形用户界面环境中。
32.def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
# 这段代码定义了一个名为 check_suffix 的函数,它用于检查一个或多个文件的后缀是否符合指定的后缀列表。
# 函数定义。
# 1.file : 要检查的文件路径或文件路径列表,默认为 'yolo.pt' 。
# 2.suffix : 可接受的文件后缀列表,默认为 ('pt',) ,表示只接受 .pt 后缀的文件。
# 3.msg : 自定义的错误消息前缀,默认为空字符串。
def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
# Check file(s) for acceptable suffix 检查文件是否有可接受的后缀。
# 确保 file 和 suffix 参数都不为空。
if file and suffix:
# 检查 suffix 是否为字符串类型,如果是,则将其转换为列表,以便后续迭代。
if isinstance(suffix, str):
# 将单个后缀字符串包装成列表。
suffix = [suffix]
# 遍历 file 参数,如果 file 是列表或元组,直接迭代;如果不是,将 file 视为单个文件路径,包装成列表进行迭代。
for f in file if isinstance(file, (list, tuple)) else [file]:
# 获取文件路径 f 的后缀,并转换为小写。
s = Path(f).suffix.lower() # file suffix
# 检查后缀 s 是否非空。
if len(s):
# 使用 assert 语句断言文件的后缀 s 是否在可接受的后缀列表 suffix 中。如果不在,抛出 AssertionError 异常,并显示自定义的错误消息。
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}" # {msg}{f} 可接受的后缀是 {suffix}。
# check_suffix 函数提供了一个简单的方式来验证文件的后缀是否符合预期。这个函数特别适用于在处理文件输入时,确保用户提供的文件具有正确的格式。通过自定义错误消息,函数增强了代码的健壮性和用户友好性。
33.def check_yaml(file, suffix=('.yaml', '.yml')):
# 这段代码定义了一个名为 check_yaml 的函数,其目的是检查一个 YAML 文件是否存在,如果不存在且文件是一个网址,则下载该文件。这个函数还检查文件名后缀是否符合 YAML 文件的标准后缀。
# 定义了 check_yaml 函数,接受两个参数。
# 1.file :要检查的 YAML 文件路径或网址。
# 2.suffix :一个元组,包含 YAML 文件可能使用的后缀名,默认为 ('.yaml', '.yml') 。
def check_yaml(file, suffix=('.yaml', '.yml')):
# Search/download YAML file (if necessary) and return path, checking suffix 搜索/下载 YAML 文件(如有必要)并返回路径,检查后缀。
# 调用 check_file 函数,并传入 file 和 suffix 参数。 check_file 函数会检查文件是否存在,如果不存在且是一个网址,则下载文件,并确保文件后缀符合传入的 suffix 参数。
return check_file(file, suffix)
# 这个函数可以用于在执行其他操作(如加载配置文件)之前,确保 YAML 文件可用且格式正确。通过调用这个函数,你可以简化代码,因为它封装了文件检查和下载的逻辑。
34.def check_file(file, suffix=''):
# 这段代码定义了一个名为 check_file 的函数,其目的是检查一个文件是否存在,如果不存在且文件是一个网址,则下载该文件;如果文件是一个 ClearML 数据集 ID,则检查 ClearML 是否已安装;如果都不是,则在指定的目录中搜索文件。
# 定义了 check_file 函数,接受两个参数。
# 1.file :要检查的文件路径或网址。
# 2.suffix :可选参数,用于指定文件后缀。
def check_file(file, suffix=''):
# Search/download file (if necessary) and return path 搜索/下载文件(如有必要)并返回路径。
# 一个调用,用于检查文件后缀是否符合预期。
# def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''): -> 用于检查一个或多个文件的后缀是否符合指定的后缀列表。
check_suffix(file, suffix) # optional
# 确保 file 参数是一个字符串。
file = str(file) # convert to str()
# 检查文件是否存在或者 file 参数是否为空字符串。
if os.path.isfile(file) or not file: # exists
# 如果文件存在或者 file 参数为空,函数返回 file 。
return file
# 检查 file 是否是一个以 http:// 或 https:// 开头的网址。如果是网址,函数会下载该文件。
elif file.startswith(('http:/', 'https:/')): # download
# 保存原始网址。
url = file # warning: Pathlib turns :// -> :/
# 解析网址,去掉查询参数,并获取文件名。
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
# 如果文件已存在本地,则打印信息并返回文件路径。
if os.path.isfile(file):
LOGGER.info(f'Found {url} locally at {file}') # file already exists 在本地的 {file} 上找到了 {url}。
# 如果文件不存在,则下载文件,并断言下载成功。
else:
LOGGER.info(f'Downloading {url} to {file}...') # 正在下载 {url} 至 {file}...
torch.hub.download_url_to_file(url, file)
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check 文件下载失败:{url}。
return file
# 检查 file 是否是一个以 clearml:// 开头的 ClearML 数据集 ID。如果是 ClearML 数据集 ID,函数会检查 ClearML 是否已安装,并返回 file 。
elif file.startswith('clearml://'): # ClearML Dataset ID
assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'." # ClearML 未安装,因此无法使用 ClearML 数据集。尝试运行“pip install clearml”。
return file
# 如果 file 不是文件路径、网址或 ClearML 数据集 ID,则在指定目录中搜索文件。
else: # search
files = []
# 搜索的目录包括 'data' 、 'models' 和 'utils' 。
for d in 'data', 'models', 'utils': # search directories
# 使用 glob.glob 搜索匹配的文件,并断言至少找到一个文件,且只有一个匹配的文件。
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
assert len(files), f'File not found: {file}' # assert file was found 未找到文件:{file}。
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique 多个文件与“{file}”匹配,请指定确切路径:{files}。
# 返回找到的文件路径。
return files[0] # return file
# 这个函数提供了一个通用的方法来处理文件路径,无论是本地文件、网址还是 ClearML 数据集 ID,都能确保返回一个有效的文件路径。此外,它还提供了错误处理,如果文件不存在或有多个文件匹配,会抛出异常。
35.def check_font(font=FONT, progress=False):
# 这段代码定义了一个名为 check_font 的函数,其目的是检查指定的字体文件是否存在于本地,如果不存在,则从网络上下载该字体文件。
# 定义 check_font 函数,它接受两个参数。
# 1.font :字体文件的路径,默认为 FONT ,这里 FONT 应该是一个在代码其他地方定义的变量,指向一个特定的字体文件。
# 2.progress :一个布尔值,指示在下载过程中是否显示进度条,默认为 False 。
def check_font(font=FONT, progress=False):
# Download font to CONFIG_DIR if necessary 如果需要的话将字体下载到 CONFIG_DIR。
# 将 font 参数转换为 Path 对象,以便使用路径操作。
font = Path(font)
# 构造字体文件在本地配置目录 CONFIG_DIR 中的完整路径。
file = CONFIG_DIR / font.name
# 检查原始字体文件和本地配置目录中的字体文件是否都不存在。
if not font.exists() and not file.exists():
# 如果字体文件不存在,则构造一个 URL,指向 Ultralytics 网站上的字体文件。
url = f'https://ultralytics.com/assets/{font.name}'
# 记录一条信息日志,告知用户正在从 URL 下载字体文件到本地路径。
LOGGER.info(f'Downloading {url} to {file}...') # 正在下载 {url} 至 {file}...
# 使用 PyTorch 的 torch.hub.download_url_to_file 函数下载字体文件。这个函数接受 URL、文件路径和进度条显示选项作为参数。
torch.hub.download_url_to_file(url, str(file), progress=progress)
36.def check_dataset(data, autodownload=True):
# 这段代码定义了一个名为 check_dataset 的函数,其目的是检查、下载(如果需要)并解压数据集,确保数据集在本地可用。
# 这是 check_dataset 函数的定义,它接受两个参数。
# 1.data :数据集的配置信息,可以是一个字符串、 Path 对象或一个包含数据集信息的字典。
# 2.autodownload :一个布尔值,指示是否自动下载缺失的数据集,默认为 True 。
def check_dataset(data, autodownload=True):
# Download, check and/or unzip dataset if not found locally 如果本地未找到数据集,请下载、检查和/或解压缩数据集。
# Download (optional)
# 初始化一个空字符串 extract_dir ,用于存储数据集解压后的目录路径。
extract_dir = ''
# is_zip = zipfile.is_zipfile(filename)
# is_zipfile() 函数是 Python zipfile 模块中的一个函数,用于检查一个文件是否是有效的 ZIP 文件格式。
# 参数 :
# filename : 要检查的文件的路径,可以是字符串、文件对象或路径对象。
# 返回值 :
# is_zipfile() 函数返回一个布尔值。 如果文件是有效的 ZIP 文件,则返回 True 。 如果文件不是有效的 ZIP 文件或文件不存在,则返回 False 。
# is_zipfile() 函数的实现依赖于文件的“魔术数字”(文件开头的字节序列),这是许多文件格式用来标识自己的一种方式。ZIP 文件的魔术数字是 PK ( 0x50 0x4B ),这个序列出现在所有 ZIP 文件的开头。
# 如果一个文件以这个序列开头, is_zipfile() 函数就会返回 True ,表明该文件是一个 ZIP 文件。这个函数在处理文件上传、归档和解压缩任务时非常有用,因为它可以帮助程序确定如何处理特定的文件。
# is_tar = tarfile.is_tarfile(name)
# is_tarfile() 函数是 Python tarfile 模块中的一个函数,用于检查一个文件是否是有效的 tar 归档文件格式。
# 参数 :
# name : 要检查的文件的路径,可以是字符串、文件对象或路径对象。
# 返回值 :
# is_tarfile() 函数返回一个布尔值。 如果文件是有效的 tar 归档文件,则返回 True 。 如果文件不是有效的 tar 归档文件或文件不存在,则返回 False 。
# is_tarfile() 函数的实现依赖于文件的“魔术数字”(文件开头的字节序列),这是许多文件格式用来标识自己的一种方式。tar 归档文件的魔术数字是 ustar (在文件的 257 字节处),这个序列出现在所有 tar 归档文件的特定位置。
# 如果一个文件在指定位置包含这个序列, is_tarfile() 函数就会返回 True ,表明该文件是一个 tar 归档文件。这个函数在处理文件上传、归档和解压缩任务时非常有用,因为它可以帮助程序确定如何处理特定的文件。
# 检查 data 是否是一个字符串或 Path 对象,并且是否指向一个 ZIP 或 TAR 文件。
if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
# 如果 data 是一个压缩文件,使用 download 函数下载文件,并解压到指定目录。
# def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3): -> 用于多线程文件下载和解压缩的工具函数,常用于在 data.yaml 文件中设置的自动下载功能。
download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
# 查找解压后的目录中的 YAML 配置文件,并更新 data 变量。
data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
# 更新 extract_dir 为 YAML 文件的父目录,并设置 autodownload 为 False 。
extract_dir, autodownload = data.parent, False
# Read yaml (optional)
# 如果 data 是一个字符串或 Path 对象,使用 yaml_load 函数加载 YAML 文件内容。
if isinstance(data, (str, Path)):
# def yaml_load(file='data.yaml'): -> 用于从 YAML 文件中安全地加载数据。返回值。函数返回从 YAML 文件中加载的 Python 对象,通常是字典或列表,具体取决于 YAML 文件的内容。 -> return yaml.safe_load(f)
data = yaml_load(data) # dictionary
# Checks
# 检查 data 字典中是否包含必要的键 : train 、 val 和 names 。
for k in 'train', 'val', 'names':
# def emojis(str=''): -> 返回一个在特定平台(特别是Windows)上安全的、不包含emoji的字符串版本。 -> return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
assert k in data, emojis(f"data.yaml '{k}:' field missing ❌") # data.yaml '{k}:' 字段缺失 ❌。
# 如果 data['names'] 是列表或元组,将其转换为字典。
if isinstance(data['names'], (list, tuple)): # old array format
data['names'] = dict(enumerate(data['names'])) # convert to dict
# 确保 data['names'] 的键是整数。
assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car' # data.yaml 名称键必须是整数,即 2:car 。
# 更新 data 字典,设置类别数量 nc 。
data['nc'] = len(data['names'])
# 这段代码负责处理数据集路径,确保它们是绝对路径,并且如果需要的话,对数据集中的 train 、 val 和 test 键进行路径前缀处理。
# Resolve paths
# 创建一个 Path 对象 path ,其值由以下条件决定 :如果 extract_dir 非空,则使用 extract_dir 。 如果 extract_dir 为空,尝试从 data 字典中获取 'path' 键的值。 如果两者都为空,则使用空字符串。
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
# path_obj.is_absolute()
# is_absolute() 是 Python pathlib 模块中 Path 类的一个方法,用于检查一个路径对象是否表示一个绝对路径。
# 参数 :
# path_obj : 一个 Path 对象,表示要检查的路径。
# 返回值 :
# is_absolute() 方法返回一个布尔值 :如果路径是绝对的,返回 True 。 如果路径是相对的,返回 False 。
# is_absolute() 方法在处理文件和目录路径时非常有用,尤其是在你需要确定一个路径是否需要进一步处理以转换为绝对路径时。
# 检查 path 是否为绝对路径。
if not path.is_absolute():
# 如果 path 不是绝对路径,将其与 ROOT 路径(通常是项目的根目录)结合,并使用 resolve() 方法获取绝对路径。
path = (ROOT / path).resolve()
# 更新 data 字典中的 'path' 键,存储绝对路径。
data['path'] = path # download scripts
# 遍历 data 字典中的 'train' 、 'val' 和 'test' 键。
for k in 'train', 'val', 'test':
# 检查 data 字典中是否存在当前键 k 。
if data.get(k): # prepend path
# 检查 data[k] 是否为字符串类型。
if isinstance(data[k], str):
# 如果是字符串,将 path 与 data[k] 结合,并使用 resolve() 方法获取绝对路径。
x = (path / data[k]).resolve()
# 如果得到的路径 x 不存在,并且 data[k] 以 '../' 开头,尝试去掉 '../' 并重新获取绝对路径。
if not x.exists() and data[k].startswith('../'):
x = (path / data[k][3:]).resolve()
# 更新 data 字典中的当前键 k ,存储绝对路径的字符串表示。
data[k] = str(x)
# 如果 data[k] 不是字符串类型(即它是一个列表),则对列表中的每个元素执行相同的路径处理。
else:
# 对列表中的每个元素 x ,将其与 path 结合,并使用 resolve() 方法获取绝对路径,然后将结果转换为字符串列表。
data[k] = [str((path / x).resolve()) for x in data[k]]
# 这段代码确保了数据集中的所有路径都是绝对路径,这对于后续的数据加载和处理非常重要,特别是当数据集目录结构发生变化或者在不同的机器上运行时。通过这种方式,可以确保程序总是能够正确地找到数据集文件。
# 这段代码是 check_dataset 函数的一部分,它负责解析 YAML 文件中的配置信息,并根据配置处理数据集的下载和验证。
# Parse yaml
# 使用字典推导式从 data 字典中获取 'train' 、 'val' 、 'test' 和 'download' 键的值。
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
# 检查是否存在验证集路径( val )。
if val:
# 将验证集路径转换为 Path 对象,并解析为绝对路径。
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
# 检查所有验证集路径是否存在。
if not all(x.exists() for x in val):
# 如果存在不存在的路径,记录一条日志信息。
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]) # 未找到数据集⚠️,缺少路径 %s 。
# 如果没有提供下载链接( s )或 autodownload 为 False ,则抛出异常。
if not s or not autodownload:
raise Exception('Dataset not found ❌') # 未找到数据集❌。
# 记录当前时间,用于计算下载时间。
t = time.time()
# 如果下载链接是一个以 .zip 结尾的 HTTP URL。
if s.startswith('http') and s.endswith('.zip'): # URL
# 获取 URL 的文件名。
f = Path(s).name # filename
LOGGER.info(f'Downloading {s} to {f}...') # 正在下载 {s} 至 {f}...
# 使用 PyTorch 的 torch.hub.download_url_to_file 函数下载文件。
torch.hub.download_url_to_file(s, f)
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
# 调用 unzip_file 函数解压下载的文件。
# def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): -> 解压缩一个 .zip 文件到指定的路径,并排除包含特定字符串的文件。
unzip_file(f, path=DATASETS_DIR) # unzip
# 删除下载的 ZIP 文件。
Path(f).unlink() # remove zip
r = None # success
# 如果下载链接是一个以 'bash ' 开头的字符串,表示一个 Bash 脚本。
elif s.startswith('bash '): # bash script
LOGGER.info(f'Running {s} ...') # 正在运行 {s} ...
# 执行 Bash 脚本并获取返回值。
r = os.system(s)
# 如果下载链接是一个 Python 脚本。
else: # python script
# exec(object, globals=None, locals=None)
# 在Python中, exec() 函数用于执行存储在字符串或对象中的Python代码。这个函数非常强大,但也需谨慎使用,因为它会执行任意的代码,这可能导致安全风险。
# 参数说明 :
# object :必需,要执行的代码,可以是字符串或代码对象。
# globals :可选,用于执行代码时的全局变量字典。如果为 None ,则使用当前环境的全局变量。
# locals :可选,用于执行代码时的局部变量字典。如果为 None ,则使用 globals 作为局部变量环境。
# 返回值 :
# exec() 函数没有返回值(即返回 None ),因为它直接执行代码,而不是返回执行结果。
# 安全注意事项 :
# 由于 exec() 可以执行任意代码,因此它可能会被用来执行恶意代码。因此,只有在完全信任代码来源的情况下才应该使用 exec() ,并且永远不要对用户提供的输入使用 exec() ,除非经过了严格的验证和清理。
# 执行 Python 脚本,并传入 data 字典。
r = exec(s, {'yaml': data}) # return None
# 计算下载和解压所花费的时间。
dt = f'({round(time.time() - t, 1)}s)'
# 根据执行结果,生成一条成功或失败的消息。
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" # 成功✅ {dt},保存至 {colorstr('bold', DATASETS_DIR)} 。 / 失败 {dt} ❌。
# 记录下载结果的日志信息。
LOGGER.info(f"Dataset download {s}") # 数据集下载 {s} 。
# 调用 check_font 函数检查字体文件是否存在,如果需要,下载字体文件。
# def check_font(font=FONT, progress=False): -> 检查指定的字体文件是否存在于本地,如果不存在,则从网络上下载该字体文件。
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
# 返回包含数据集信息的字典。
return data # dictionary
# 这段代码的目的是确保数据集的完整性和可用性。如果数据集缺失或不完整,它将尝试自动下载数据集。此外,它还处理了不同类型下载链接的情况,包括 HTTP URL、Bash 脚本和 Python 脚本。最后,它还检查了字体文件,确保数据集的标签可以正确显示。
# 这个函数的目的是确保数据集在本地可用,无论是通过下载、解压还是路径解析。它还检查数据集的完整性,并在必要时执行下载脚本。这对于确保数据集准备阶段的顺利进行至关重要。
37.def check_amp(model):
# 这段代码定义了一个名为 check_amp 的函数,用于检查 PyTorch 自动混合精度(AMP)功能是否在模型上正确工作。如果 AMP 功能正常,函数返回 True ;如果出现问题,则返回 False 并记录警告日志。
# 定义了一个名为 check_amp 的函数,它接受一个参数。
# 1.model :表示要检查的 PyTorch 模型。
def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation 检查 PyTorch 自动混合精度 (AMP) 功能。操作正确则返回 True 。
# 导入依赖。导入 AutoShape 和 DetectMultiBackend 类,这些类用于调整模型输入尺寸和处理多后端模型。
from models.common import AutoShape, DetectMultiBackend
# 辅助函数 amp_allclose 。
# 定义了一个名为 amp_allclose 的辅助函数,它接受两个参数。
# 1.model 和 2.im :分别表示 模型 和 输入图像 。
def amp_allclose(model, im):
# All close FP32 vs AMP results
# 创建一个 AutoShape 实例,用于自动调整模型输入尺寸。
# class AutoShape(nn.Module):
# -> AutoShape 类实现了一个输入鲁棒性的模型包装器,用于处理不同格式的输入数据(如 OpenCV、NumPy、PIL 或 PyTorch 张量),并执行预处理、推理和非最大抑制(NMS)。
# -> def __init__(self, model, verbose=True):
m = AutoShape(model, verbose=False) # model
# 使用 FP32 精度(即标准的浮点数精度)进行推理,并获取结果。
a = m(im).xywhn[0] # FP32 inference
# 设置 AutoShape 实例的 amp 属性为 True ,启用 AMP 功能。
m.amp = True
# 使用 AMP 进行推理,并获取结果。
b = m(im).xywhn[0] # AMP inference
# torch.allclose(actual, desired, rtol=1e-05, atol=1e-08, equal_nan=False)
# torch.allclose() 是 PyTorch 库中的一个函数,用于检查两个张量(tensor)是否在元素级别上近似相等。这个函数主要用于数值计算中,由于浮点数的精度限制,直接比较两个浮点数是否相等可能不总是可靠,因此使用 torch.allclose() 来检查它们是否足够接近。
# 参数 :
# actual :实际的张量。
# desired :期望的张量。
# rtol (relative tolerance):相对容差,默认值为 1e-05 。
# atol (absolute tolerance):绝对容差,默认值为 1e-08 。
# equal_nan :是否将 NaN 视为相等,默认为 False 。
# 返回值 :
# 返回一个布尔值,如果两个张量在指定的容差内近似相等,则返回 True ;否则返回 False 。
# 注意事项 :
# 如果两个张量的形状不同, torch.allclose() 会直接返回 False 。
# rtol 和 atol 参数用于控制近似相等的严格程度。 rtol 是相对容差,适用于比例误差; atol 是绝对容差,适用于固定误差。
# 如果设置 equal_nan=True ,则两个张量中对应位置的 NaN 值也会被视为相等。
# torch.allclose() 在机器学习和数值计算中非常有用,特别是在比较模型预测结果和真实值时,或者在验证算法实现的正确性时。
# 比较 FP32 和 AMP 结果的形状是否相同,以及它们的值是否在 10% 的绝对容差内接近。
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
# 定义一个前缀字符串,用于日志消息。
prefix = colorstr('AMP: ')
# 获取模型参数所在的设备。
device = next(model.parameters()).device # get model device
# 如果设备是 CPU 或 MPS(苹果的 Metal Performance Shaders),则返回 False ,因为 AMP 只在 CUDA 设备上使用。
if device.type in ('cpu', 'mps'):
return False # AMP only used on CUDA devices
# 定义一个路径,指向用于测试的图像文件。
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
# 尝试加载本地图像文件,如果文件不存在且在线检查通过,则尝试从网上加载图像,否则创建一个全 1 的数组作为输入图像。
# def check_online(): -> 检查设备是否连接到互联网。函数调用 run_once 两次,并使用逻辑或操作符 or 来组合结果。如果至少有一次连接成功, check_online 函数就返回 True ,表示设备连接到互联网。 -> return run_once() or run_once()
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
# 尝试执行 amp_allclose 函数,如果成功,则记录信息日志并返回 True 。
try:
# assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolo.pt', device), im)
LOGGER.info(f'{prefix}checks passed ✅') # {prefix}检查已通过✅。
return True
# 如果 amp_allclose 函数执行过程中出现任何异常,则记录警告日志,并返回 False 。
except Exception:
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}') # {prefix}检查失败❌,禁用自动混合精度。请参阅{help_url}。
return False
# check_amp 函数用于验证模型是否能够在 AMP 模式下正确工作。这个功能对于在支持 AMP 的硬件上加速训练和推理非常有用。如果检查失败,函数将提供帮助链接,以便用户查找解决方案。
38.def yaml_load(file='data.yaml'):
# 这段代码定义了一个名为 yaml_load 的函数,它用于从 YAML 文件中安全地加载数据。
# 函数定义。
# 1.file : YAML 文件的路径,默认为 'data.yaml' 。
def yaml_load(file='data.yaml'):
# Single-line safe yaml loading
# 使用 with 语句打开文件,确保文件在操作完成后正确关闭。
with open(file, errors='ignore') as f:
# yaml.safe_load(stream, Loader=None, object_pairs_hook=None, version=(1, 2), pure=False, preserve_quotes=False)
# yaml.safe_load() 是 PyYAML 库中的一个函数,用于安全地解析 YAML 文档。
# 参数 :
# stream :包含 YAML 文档的输入流,可以是文件对象、字符串等。
# Loader :指定一个 Loader 类来加载 YAML 文档,默认为 SafeLoader ,它只认识标准 YAML 标签,不能构造任意的 Python 对象。
# object_pairs_hook :一个函数,用于处理键值对序列,以便在加载映射时自定义 Python 对象的构造。
# version :指定 YAML 的版本,默认为 (1, 2) ,表示支持的 YAML 版本。
# pure :如果为 True ,则只使用 Python 的内置类型来加载 YAML 文档。
# preserve_quotes :如果为 True ,则保持字符串的引号。
# 返回值 :
# 返回一个 Python 对象,该对象是由 YAML 文档中的第一个文档构建的。如果没有文档,则返回 None 。
# 为什么使用 safe_load 而不是 load :
# safe_load 只能加载符合 YAML 规范的数据,不会执行 YAML 文件中的任何代码,因此更安全。
# load 函数可以处理更多的数据类型,包括 Python 对象和函数,如果 YAML 文档包含恶意代码,使用 load 可能会导致安全风险。
# 注意事项 :
# 始终使用 safe_load 来解析不受信任的 YAML 文档,以确保安全。
# 从受信任的来源获取 YAML 文档,以避免潜在的安全问题。
# 定期更新 PyYAML 库,以获得最新的安全补丁。
# 使用 yaml.safe_load 函数从打开的文件对象 f 中加载 YAML 数据。 yaml.safe_load 是 PyYAML 库中的一个函数,它解析 YAML 文件并返回 Python 对象。
# 返回值。函数返回从 YAML 文件中加载的 Python 对象,通常是字典或列表,具体取决于 YAML 文件的内容。
return yaml.safe_load(f)
# 这个函数假设你已经安装了 PyYAML 库。如果没有安装,你需要先运行 pip install PyYAML 来安装它。
# 使用 safe_load 而不是 load 是因为 safe_load 不会执行 YAML 文件中的任何代码,这使得它更安全,尤其是在处理不可信的输入时。
# 参数 errors='ignore' 用于处理文件编码问题,但在实际使用中可能需要根据具体情况调整错误处理策略。
39.def yaml_save(file='data.yaml', data={}):
# 这段代码定义了一个名为 yaml_save 的函数,其目的是将 Python 字典以 YAML 格式安全地保存到文件中。
# 定义了一个名为 yaml_save 的函数,它接受两个参数。
# 1.file :要保存的文件名,默认为 'data.yaml' 。
# 2.data :要保存的数据,是一个字典,默认为空字典。
def yaml_save(file='data.yaml', data={}):
# Single-line safe yaml saving 单行安全 yaml 保存。
# 开指定的 file 文件用于写入,并返回一个文件对象 f 。 with 语句确保文件在操作完成后正确关闭。
with open(file, 'w') as f:
# 使用 yaml.safe_dump 函数将 data 字典保存到文件中。 yaml.safe_dump 是 PyYAML 库中的一个函数,用于将 Python 对象转换为 YAML 格式并写入文件。
# 字典推导式 {k: str(v) if isinstance(v, Path) else v for k, v in data.items()} 遍历 data 字典中的所有项,如果值是 Path 对象,则将其转换为字符串,否则直接使用该值。
# f 是文件对象,表示要写入的目标文件。
# sort_keys=False 参数指定在写入文件时不排序字典的键。
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
# yaml_save 函数提供了一种方便的方式来将 Python 字典保存为 YAML 文件,这对于配置文件和数据集信息的存储非常有用。函数使用 yaml.safe_dump 确保写入过程的安全性,并处理了 Path 对象到字符串的转换,使得路径信息能够正确保存。
40.def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
# 这段代码定义了一个名为 unzip_file 的函数,其目的是解压缩一个 .zip 文件到指定的路径,并排除包含特定字符串的文件。
# 这是 unzip_file 函数的定义,它接受三个参数。
# 1.file :要解压缩的 .zip 文件的路径。
# 2.path :解压缩的目标路径,默认为 None 。
# 3.exclude :一个元组,包含要排除的文件名字符串,默认为 ('.DS_Store', '__MACOSX') ,这些通常是 macOS 创建的隐藏文件。
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
# Unzip a *.zip file to path/, excluding files containing strings in exclude list
# 检查 path 参数是否为 None 。
if path is None:
# 如果 path 为 None ,则将 path 设置为 .zip 文件的父目录。
path = Path(file).parent # default path
# 使用 ZipFile 类(来自 zipfile 模块)打开 .zip 文件,并将其别名为 zipObj 。
with ZipFile(file) as zipObj:
# 遍历 .zip 文件中所有文件的名称。
for f in zipObj.namelist(): # list all archived filenames in the zip
# 对于每个文件名 f ,检查它是否不包含 exclude 列表中的任何字符串。
if all(x not in f for x in exclude):
# 如果文件名不包含要排除的字符串,则使用 zipObj.extract 方法将文件解压缩到指定的 path 路径。
zipObj.extract(f, path=path)
# 这个函数的目的是解压缩 .zip 文件,同时排除那些不需要的文件,这在处理包含大量文件的压缩包时非常有用,尤其是当某些文件(如系统隐藏文件)不需要被解压时。通过这种方式,可以确保只有需要的文件被解压到目标路径。
41.def url2file(url):
# 这段代码定义了一个名为 url2file 的函数,其目的是将一个 URL 转换为对应的文件名。
# 定义了一个名为 url2file 的函数,它接受一个参数。
# 1.url :表示要转换的 URL 地址。
def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt 将 URL 转换为文件名,即 https://url.com/file.txt?auth -> file.txt 。
# 首先将 url 参数转换为 Path 对象,然后转换为字符串。这样做可以处理路径中的一些特殊字符。接着,使用 replace 方法将 :/ 替换为 :// ,以确保 URL 的格式正确。
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
# 这行代码执行以下操作 :
# 使用 urllib.parse.unquote 函数对 URL 进行解码,将 URL 中的百分号编码转换回普通字符。
# 将解码后的 URL 转换为 Path 对象。
# 使用 Path 对象的 name 属性获取 URL 的基本名称,即路径中的最后一部分。
# 使用 split 方法将名称按 ? 分割,并返回第一部分,即不包含任何查询参数的文件名。
# 返回值。函数返回从 URL 转换得到的文件名。
return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
# url2file 函数提供了一种方便的方式来从 URL 中提取文件名。这个函数对于处理下载文件、保存网络资源等情况非常有用,能够确保从 URL 中正确提取出文件名,并且处理了 URL 编码和特殊字符。
42.def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
# 这段代码定义了一个名为 download 的函数,它是一个用于多线程文件下载和解压缩的工具函数,常用于在 data.yaml 文件中设置的自动下载功能。
# 这是 download 函数的定义,它接受以下参数 :
# 1.url :要下载的文件的 URL 或文件路径。
# 2.dir :下载文件的目录,默认为当前目录( . )。
# 3.unzip :一个布尔值,指示是否解压缩下载的文件,默认为 True 。
# 4.delete :一个布尔值,指示是否在解压缩后删除原始压缩文件,默认为 True 。
# 5.curl :一个布尔值,指示是否使用 curl 命令下载文件,默认为 False 。
# 6.threads :下载线程的数量,默认为 1。
# 7.retry :下载失败时的重试次数,默认为 3。
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
# Multithreaded file download and unzip function, used in data.yaml for autodownload 多线程文件下载和解压功能,用于data.yaml自动下载。
# 定义一个内部函数 download_one ,用于下载单个文件。
def download_one(url, dir):
# Download 1 file
# 初始化一个标志变量 success ,用于跟踪下载是否成功。
success = True
# 检查 url 是否是一个已存在的文件路径。
if os.path.isfile(url):
# 如果 url 不是一个已存在的文件路径,则创建一个 Path 对象 f ,表示下载文件的目标路径。
f = Path(url) # filename
else: # does not exist
f = dir / Path(url).name
# 记录一条日志信息,表示开始下载文件。
LOGGER.info(f'Downloading {url} to {f}...') # 正在下载 {url} 至 {f}...
# 进行最多 retry + 1 次的下载尝试。
for i in range(retry + 1):
# 如果使用 curl ,则执行 curl 命令下载文件。
if curl:
s = 'sS' if threads > 1 else '' # silent
r = os.system(
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
success = r == 0
else:
# 如果不使用 curl ,则使用 PyTorch 的 torch.hub.download_url_to_file 函数下载文件。
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
success = f.is_file()
if success:
break
elif i < retry:
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') # ⚠️ 下载失败,重试 {i + 1}/{retry} {url}...
else:
LOGGER.warning(f'❌ Failed to download {url}...') # ❌ 无法下载 {url}...
# 如果下载成功且 unzip 为 True ,则检查文件是否需要解压缩。
if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
LOGGER.info(f'Unzipping {f}...')
# 如果文件是 ZIP 格式,调用 unzip_file 函数解压缩文件。
if is_zipfile(f):
# def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): -> 解压缩一个 .zip 文件到指定的路径,并排除包含特定字符串的文件。
unzip_file(f, dir) # unzip
# 如果文件是 TAR 格式,使用 tar 命令解压缩文件。
elif is_tarfile(f):
os.system(f'tar xf {f} --directory {f.parent}') # unzip
# 如果文件是 GZ 格式,使用 tar 命令解压缩文件。
elif f.suffix == '.gz':
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
# 如果 delete 为 True ,则删除原始压缩文件。
if delete:
f.unlink() # remove zip
# 将 dir 参数转换为 Path 对象。
dir = Path(dir)
# 创建下载目录。
dir.mkdir(parents=True, exist_ok=True) # make directory
# 如果线程数大于 1,则使用多线程下载。
if threads > 1:
# pool = ThreadPool(processes=5)
# 在Python中, multiprocessing.pool.ThreadPool 是 multiprocessing 模块中的 pool 子模块提供的类之一,它用于创建一个线程池,以便并行执行多个线程任务。这个类是 multiprocessing 库的一部分,该库提供了一种方式来并行化程序,利用多核处理器的能力。
# 参数 :
# processes :参数指定线程池中的线程数量。
# 返回 :
# pool :创建一个线程池实例。
# 方法 ThreadPool 提供了以下方法 :
# apply_async(func, args=(), kwds={}) :异步地将一个函数 func 应用到 args 和 kwds 参数上,并返回一个 AsyncResult 对象。
# map(func, iterable, chunksize=1) :将一个函数 func 映射到一个迭代器 iterable 的所有元素上,并返回一个迭代器。
# close() :关闭线程池,不再接受新的任务。
# join() :等待线程池中的所有任务完成。
# terminate() :立即终止线程池中的所有任务。
# ThreadPool 是 multiprocessing 库中用于并行执行 I/O 密集型任务的工具,它允许程序利用多核处理器的能力来提高性能。需要注意的是, multiprocessing 库中的进程是系统级的进程,与线程相比,它们有更大的开销,但也能提供更好的并行性能。
# 创建一个线程池。
pool = ThreadPool(threads)
# pool.imap(func, iterable, chunksize=None)
# imap() 方法用于将一个可迭代的输入序列分块分配到线程池中的线程进行处理,并将结果返回一个迭代器。这个方法特别适用于需要顺序处理输入和输出的场景。
# 参数 :
# func :一个函数,它将被调用并传入 iterable 中的每个项目。
# iterable :一个可迭代对象,其元素将被传递给 func 函数。
# chunksize :(可选)一个整数,指定了每个任务传递给 func 的项目数量。默认值为 1,意味着每个任务只包含一个项目。如果设置为大于 1 的值,那么 func 将接收到一个包含多个项目的列表。
# 返回值 :
# 返回一个 Iterator ,它生成每个输入元素经过 func 处理后的结果。
# 特点 :
# 结果的顺序与输入序列的顺序相同。
# 如果任何一个任务因为异常而终止, imap() 会立即抛出异常。
# 它允许主线程在子线程完成工作之前继续执行,而不是等待所有任务完成。
# ThreadPool.imap() 是处理 I/O 密集型任务或者需要顺序处理结果的并发任务的有用工具。与之相对的是 imap_unordered() ,它同样返回一个迭代器,但是结果的顺序可能与输入序列不同,适用于不在乎结果顺序的场景。
# 使用线程池下载文件。
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
# 关闭线程池。
pool.close()
# 等待线程池中的所有线程完成。
pool.join()
# 如果线程数为 1,则逐个下载文件。
else:
for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir)
# 这个函数的目的是提供一个灵活的文件下载和解压缩工具,支持多线程下载和自动解压缩,适用于大规模数据集的自动下载和准备。
43.def make_divisible(x, divisor):
# 这段代码定义了一个名为 make_divisible 的函数,其目的是将输入值 x 调整到最接近的、可以被 divisor 整除的数。这个函数在深度学习中常用于计算网络中某些层的输出通道数,以确保输出尺寸是某个特定值的倍数,从而满足某些硬件或算法的要求。
# 这行定义了一个名为 make_divisible 的函数,它接受两个参数。
# 1.x :需要被调整的数值。
# 2.divisor :除数。
def make_divisible(x, divisor):
# Returns nearest x divisible by divisor 返回最接近除数的 x。
# 检查 divisor 是否是一个 PyTorch 张量( torch.Tensor )。 isinstance 函数用于检查 divisor 的类型是否与 torch.Tensor 匹配。
if isinstance(divisor, torch.Tensor):
# 如果 divisor 是一个 PyTorch 张量,那么取张量中的最大值,并将其转换为整数。这是因为在某些情况下, divisor 可能不是一个标量,而是一个包含多个值的张量,我们需要从中选择一个值来作为除数。这里选择的是张量中的最大值。
divisor = int(divisor.max()) # to int
# 计算 x 除以 divisor 的结果,并使用 math.ceil 函数向上取整到最近的整数。 math.ceil 函数的作用是返回大于或等于给定数字的最小整数。然后,将这个整数乘以 divisor ,以确保结果能够被 divisor 整除。最后,返回这个值。
return math.ceil(x / divisor) * divisor
# make_divisible 函数的目的是确保给定的数值 x 能够被 divisor 整除。如果 divisor 是一个张量,它会取张量中的最大值作为除数,然后计算出最接近 x 的能被这个除数整除的数值,并返回这个数值。这个函数在深度学习中常用于确保某些参数(如通道数)能够被特定的数值整除,以满足某些硬件或优化的要求。
44.def clean_str(s):
# 这段代码定义了一个名为 clean_str 的函数,其目的是清除字符串 s 中的特殊字符,并用下划线 _ 替换它们。这个函数使用 Python 的 re 模块(正则表达式模块)来实现。
# 定义了一个函数 clean_str ,它接受一个参数。
# 1.s :即需要被清理的字符串。
def clean_str(s):
# Cleans a string by replacing special characters with underscore _ 通过使用下划线 _ 替换特殊字符来清理字符串。
# 使用 re.sub() 函数来替换字符串中匹配特定模式的字符。
# pattern : 这是一个正则表达式模式,用于匹配所有需要被替换的特殊字符。方括号 [] 内的字符表示匹配任何一个列出的特殊字符。
# repl : 这是替换字符串,即所有匹配到的模式都会被替换成这个字符串。在这里,所有特殊字符都会被替换为下划线 _ 。
# string : 这是需要被搜索和替换的原始字符串,即函数参数 s 。
# 返回值。函数返回替换后的字符串。
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
# clean_str 函数的主要作用是去除字符串中可能导致处理问题的字符,如管道符、@ 符号、感叹号、货币符号等,确保字符串更加标准化,便于后续的处理和使用。
45.def one_cycle(y1=0.0, y2=1.0, steps=100):
# 这段代码定义了一个名为 one_cycle 的函数,它用于生成一个周期性的函数,该函数在给定的步数内从 y1 值平滑过渡到 y2 值,并再过渡回 y1 值。这种类型的函数常用于机器学习中的学习率调度,特别是在 One Cycle 策略中。
# 定义了一个名为 one_cycle 的函数,它接受三个参数。
# 1.y1 :周期开始和结束时的值,默认为 0.0 。
# 2.y2 :周期中间的最大值,默认为 1.0 。
# 3.steps :完成一个完整周期所需的步数,默认为 100 。
def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf 从 y1 到 y2 的正弦斜坡的 lambda 函数 https://arxiv.org/pdf/1812.01187.pdf 。
# 返回一个 lambda 函数,这个匿名函数接受一个参数 x ,表示当前的步数。 lambda 函数的计算过程如下 :
# x * math.pi / steps :计算 x 步在周期中的比例,并将其映射到 0 到 π 的范围内。
# math.cos(...) :计算上述比例的余弦值。
# 1 - math.cos(...) :从 1 减去余弦值,得到一个从 1 到 -1 再到 1 的曲线。
# (1 - math.cos(...)) / 2 :将上述曲线的振幅调整到 0 到 1 的范围内。
# ((1 - math.cos(...)) / 2) * (y2 - y1) :将调整后的曲线乘以 y2 和 y1 之间的差值,以映射到所需的值域。
# + y1 :将上述结果加上 y1 ,以确保函数的起始值为 y1 。
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
# one_cycle 函数提供了一种方便的方式来生成一个周期性函数,该函数可以用于各种需要周期性变化的场景,特别是在机器学习中调整学习率。通过调整 y1 、 y2 和 steps 参数,可以灵活地控制周期的形状和长度。
46.def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
# 这段代码定义了一个名为 one_flat_cycle 的函数,它生成一个 lambda 函数用于实现一个具有平坦峰值的周期性变化。这种周期性变化在到达中点后保持峰值,直到周期结束。
# 定义了一个名为 one_flat_cycle 的函数,它接受三个参数。
# 1.y1 :周期开始和结束时的值,默认为 0.0 。
# 2.y2 :周期中间的最大值,默认为 1.0 。
# 3.steps :完成一个完整周期所需的步数,默认为 100 。
def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf 从 y1 到 y2 的正弦斜坡的 lambda 函数 https://arxiv.org/pdf/1812.01187.pdf 。
#return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 。
# 返回一个 lambda 函数,这个匿名函数接受一个参数 x ,表示当前的步数。 lambda 函数的计算过程如下 :
# (x - (steps // 2)) :计算 x 相对于周期中点的位置。
# * math.pi / (steps // 2) :将相对位置映射到 0 到 π 的范围内。
# math.cos(...) :计算上述映射值的余弦值。
# 1 - math.cos(...) :从 1 减去余弦值,得到一个从 1 到 -1 再到 1 的曲线。
# (1 - math.cos(...)) / 2 :将上述曲线的振幅调整到 0 到 1 的范围内。
# ((1 - math.cos(...)) / 2) * (y2 - y1) :将调整后的曲线乘以 y2 和 y1 之间的差值,以映射到所需的值域。
# + y1 :将上述结果加上 y1 ,以确保函数的起始值为 y1 。
# if (x > (steps // 2)) else y1 :如果 x 大于周期的一半,则应用上述计算,否则保持 y1 值。
return lambda x: ((1 - math.cos((x - (steps // 2)) * math.pi / (steps // 2))) / 2) * (y2 - y1) + y1 if (x > (steps // 2)) else y1
# one_flat_cycle 函数提供了一种方便的方式来生成一个具有平坦峰值的周期性函数,该函数可以用于各种需要周期性变化的场景,特别是在机器学习中调整学习率。通过调整 y1 、 y2 和 steps 参数,可以灵活地控制周期的形状和长度。与 one_cycle 函数不同, one_flat_cycle 函数在周期的后半部分保持峰值,直到周期结束。
47.def colorstr(*input):
# 这段代码定义了一个名为 colorstr 的函数,它用于给字符串添加 ANSI 转义代码,从而在支持 ANSI 颜色代码的终端中输出彩色文本。
# 这行定义了一个函数 colorstr ,它接受任意数量的参数。
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
# 这是一个字典,定义了各种颜色和样式的 ANSI 转义代码。
colors = {
'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'underline': '\033[4m'}
# 构建并返回最终的字符串。它首先通过列表推导式和 join 函数将所有颜色和样式的 ANSI 代码连接起来,然后加上要着色的字符串 string ,最后加上 colors['end'] 来重置样式,确保之后的输出不会受到颜色代码的影响。
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
# 注意事项 :
# 这个函数依赖于终端支持 ANSI 转义代码。在不支持这些代码的环境中(比如某些旧的终端或者非终端环境),这些代码可能不会有任何效果,或者会显示为乱码。
# colors['end'] 用于重置样式,确保在着色文本之后,终端的样式不会被改变。
# 函数的设计允许灵活地组合颜色和样式,只需按顺序传入相应的参数即可。
48.def labels_to_class_weights(labels, nc=80):
# 这段代码定义了一个名为 labels_to_class_weights 的函数,其目的是根据训练标签计算每个 类别 的权重,这通常用于在目标检测任务中平衡类别不均衡的问题。
# 定义了一个名为 labels_to_class_weights 的函数,它接受两个参数。
# 1.labels :一个包含训练标签的列表或数组,其中每个标签包含类别信息和其他属性(如边界框坐标)。
# 2.nc :数据集中的类别总数,默认为 80。
def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
# 检查标签。如果传入的 labels 中的第一个元素是 None ,表示没有加载任何标签。
if labels[0] is None: # no labels loaded
# 函数返回一个空的 PyTorch 张量。
return torch.Tensor()
# 合并标签。使用 np.concatenate 将 labels 中的所有数组合并成一个大数组。对于 COCO 数据集,每个标签的形状可能是 (5,) ,包含 [class, x, y, w, h] 。
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
# 提取类别信息。从合并后的标签数组中提取类别信息,并将它们转换为整数类型。
classes = labels[:, 0].astype(int) # labels = [class xywh]
# 计算每个类别的出现次数。使用 np.bincount 计算每个类别的出现次数, minlength=nc 确保权重数组的长度至少为 nc ,即类别总数。
weights = np.bincount(classes, minlength=nc) # occurrences per class
# Prepend gridpoint count (for uCE training)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
# 处理空类别。将权重数组中为 0 的元素替换为 1,以避免除以零的错误。
weights[weights == 0] = 1 # replace empty bins with 1
# 计算类别权重。
# 计算每个类别的权重,即每个类别的目标数量的倒数。
weights = 1 / weights # number of targets per class
# 将权重数组归一化,使得所有权重的和为 1。
weights /= weights.sum() # normalize
# 转换为 PyTorch 张量。将 NumPy 数组转换为 PyTorch 张量,并确保数据类型为浮点数。
return torch.from_numpy(weights).float()
# labels_to_class_weights 函数根据训练标签计算每个类别的权重,这些权重可以用于损失函数中,以平衡不同类别之间的样本数量差异。这对于提高模型在类别不均衡数据集上的性能非常有用。函数首先检查标签,然后合并标签,计算每个类别的出现次数,处理空类别,计算权重,并最终将权重归一化并转换为 PyTorch 张量。
49.def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# 这段代码定义了一个名为 labels_to_image_weights 的函数,其目的是根据训练标签计算每个 图像 的权重。这个权重通常是根据每个图像中各类别出现的次数和预定义的类别权重计算得出的。
# 定义了一个名为 labels_to_image_weights 的函数,它接受三个参数。
# 1.labels :一个包含训练标签的列表或数组,其中每个标签包含类别信息和其他属性(如边界框坐标)。
# 2.nc :数据集中的类别总数,默认为 80。
# 3.class_weights :一个数组,包含每个类别的权重,默认为 80 个元素的数组,所有元素都为 1。
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class_weights and image contents 根据 class_weights 和图像内容生成图像权重。
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
# 计算每个类别的出现次数。
# 使用列表推导式遍历 labels 中的每个标签数组,对每个数组中的第一个元素(即类别索引)使用 np.bincount 函数计算每个类别的出现次数。 minlength=nc 确保每个计数数组的长度至少为 nc ,即类别总数。结果是一个 NumPy 数组,其中每一行代表一个图像中各类别出现的次数。
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
# 计算图像权重。
# 这行代码计算每个图像的权重 :
# class_weights.reshape(1, nc) :将 class_weights 重塑为一个列向量。
# * class_counts :将重塑后的类别权重与每个图像的类别计数相乘,得到每个类别对每个图像权重的贡献。
# .sum(1) :沿着第二个轴(即类别轴)求和,得到每个图像的总权重。
# 返回值。函数返回一个数组,其中每个元素代表对应图像的权重。
return (class_weights.reshape(1, nc) * class_counts).sum(1)
# labels_to_image_weights 函数根据训练标签和预定义的类别权重计算每个图像的权重。这个权重可以用于损失函数中,以平衡不同图像中各类别样本数量的差异。这对于提高模型在类别不均衡数据集上的性能非常有用。函数首先计算每个类别的出现次数,然后根据类别权重计算每个图像的权重,并最终返回这些权重。
50.def coco80_to_coco91_class():
# 这段代码定义了一个名为 coco80_to_coco91_class 的函数,其目的是将 COCO 数据集中的 80 个类别索引(用于验证 2014 数据集)转换为论文中使用的 91 个类别索引。这个转换通常用于将 COCO 数据集的类别索引与 YOLOv3 等模型的类别索引对齐,因为这些模型可能在训练时使用了 COCO 的 91 个类别。
# 定义了一个没有参数的函数 coco80_to_coco91_class 。
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
# 返回一个列表,包含从 1 到 90 的整数,其中每个整数代表 COCO 数据集中的一个类别索引。这个列表中的索引对应于 COCO 数据集的 91 个类别,其中一些索引被跳过了,因为 COCO 数据集中的某些类别在 YOLO 模型中没有被使用。
return [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
# coco80_to_coco91_class 函数提供了一个简单的索引映射列表,用于将 COCO 数据集的 80 个类别索引转换为 91 个类别索引。这个函数对于在不同模型和数据集之间迁移和对齐类别索引非常有用。通过使用这个映射列表,开发者可以确保类别索引的一致性,从而在不同的模型和数据集之间进行有效的比较和迁移。
51.def xyxy2xywh(x):
# 这段代码定义了一个名为 xyxy2xywh 的函数,它用于将边界框的坐标从 xyxy 格式(即左上角和右下角坐标)转换为 xywh 格式(即中心点坐标和宽高)。
# 定义 xyxy2xywh 函数,它接受一个参数。
# 1.x :这个参数是一个包含边界框坐标的数组或张量。
def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right 将 nx4 个框从 [x1, y1, x2, y2] 转换为 [x, y, w, h],其中 xy1=左上角,xy2=右下角。
# 创建 x 的一个副本。如果 x 是一个 PyTorch 张量,则使用 clone() 方法;如果 x 是一个 NumPy 数组,则使用 np.copy() 方法。这样做是为了避免修改原始输入数据。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 计算边界框的中心点 x 坐标。这是通过取左上角 x 坐标( x[..., 0] )和右下角 x 坐标( x[..., 2] )的平均值得到的。
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
# 计算边界框的中心点 y 坐标。这是通过取左上角 y 坐标( x[..., 1] )和右下角 y 坐标( x[..., 3] )的平均值得到的。
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
# 计算边界框的宽度。这是通过取右下角 x 坐标( x[..., 2] )和左上角 x 坐标( x[..., 0] )的差值得到的。
y[..., 2] = x[..., 2] - x[..., 0] # width
# 计算边界框的高度。这是通过取右下角 y 坐标( x[..., 3] )和左上角 y 坐标( x[..., 1] )的差值得到的。
y[..., 3] = x[..., 3] - x[..., 1] # height
# 返回转换后的边界框坐标数组或张量。
return y
# 这个函数在目标检测和图像处理中非常有用,因为它提供了一种将边界框坐标从一种格式转换为另一种格式的简便方法。这种转换在不同的计算机视觉任务中是常见的,例如在非极大值抑制(NMS)或在将检测结果可视化时。
52.def xywh2xyxy(x):
# 这段代码是一个函数 xywh2xyxy ,它的作用是将边界框的坐标从 (x, y, w, h) 格式转换为 (x1, y1, x2, y2) 格式。这里的 (x, y) 表示边界框的中心点坐标, w 和 h 分别表示边界框的宽度和高度。转换后的 (x1, y1) 是边界框的左上角坐标, (x2, y2) 是右下角坐标。
# 这行定义了一个名为 xywh2xyxy 的函数,它接受一个参数。
# 1.x :这个参数是一个包含边界框坐标的数组。
def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
# 创建了一个新的变量 y ,它是用来存储转换后的坐标。如果输入 x 是一个 PyTorch 张量( torch.Tensor ),则使用 clone() 方法来复制这个张量,这样可以避免修改原始数据。如果 x 不是 PyTorch 张量,那么它是一个 NumPy 数组,此时使用 np.copy(x) 来复制数组。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 计算边界框左上角的 x 坐标。原始的 x 坐标是边界框中心的 x 坐标,宽度 w 的一半从中心向左移动,因此用中心的 x 坐标减去宽度的一半。
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
# 计算边界框左上角的 y 坐标。原始的 y 坐标是边界框中心的 y 坐标,高度 h 的一半从中心向上移动,因此用中心的 y 坐标减去高度的一半。
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
# 计算边界框右下角的 x 坐标。与左上角的 x 坐标计算相反,这里需要从中心的 x 坐标向右移动宽度的一半。
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
# 计算边界框右下角的 y 坐标。与左上角的 y 坐标计算相反,这里需要从中心的 y 坐标向下移动高度的一半。
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
# 这行代码表示函数执行完毕后返回转换后的坐标数组 y 。
return y
# 这个函数通过简单的算术运算将边界框的中心点坐标和宽高转换成了左上角和右下角的坐标,这是计算机视觉中常见的坐标转换操作。
53.def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
# 这段代码定义了一个名为 xywhn2xyxy 的函数,它用于将边界框的坐标从中心点坐标加宽高( x, y, w, h )的格式转换为左上角和右下角坐标( x1, y1, x2, y2 )的格式。
# 这行定义了一个名为 xywhn2xyxy 的函数,它接受五个参数。
# 1.x :边界框的坐标数组,格式为 [x, y, w, h] ,其中 x 和 y 是中心点坐标, w 和 h 是宽度和高度。
# 2.w 和 3.h :图像的宽度和高度,默认分别为640。
# 4.padw 和 5.padh :图像的左右和上下padding值,默认均为0。
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
# 创建 x 的一个副本。如果 x 是一个 PyTorch 张量,则使用 clone() 方法;否则,使用 NumPy 的 copy() 方法。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 计算左上角的 x 坐标。公式 (x - w/2) 将中心点的 x 坐标转换为左上角的 x 坐标,然后乘以图像宽度 w 并加上 padw 。
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
# 计算左上角的 y 坐标。公式 (x - h/2) 将中心点的 y 坐标转换为左上角的 y 坐标,然后乘以图像高度 h 并加上 padh 。
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
# 计算右下角的 x 坐标。公式 (x + w/2) 将中心点的 x 坐标转换为右下角的 x 坐标,然后乘以图像宽度 w 并加上 padw 。
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
# 计算右下角的 y 坐标。公式 (x + h/2) 将中心点的 y 坐标转换为右下角的 y 坐标,然后乘以图像高度 h 并加上 padh 。
y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
# 返回转换后的边界框坐标数组 y ,格式为 [x1, y1, x2, y2] 。
return y
# 这个 xywhn2xyxy 函数将边界框的坐标从中心点坐标加宽高格式转换为左上角和右下角坐标格式,这在图像处理和目标检测任务中非常常见。这种转换有助于与不同的图像处理库和深度学习框架兼容,因为不同的库和框架可能使用不同的坐标格式。
54.def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
# 这段代码定义了一个名为 xyxy2xywhn 的函数,它将边界框的坐标从 (x1, y1, x2, y2) 格式(其中 (x1, y1) 是左上角坐标, (x2, y2) 是右下角坐标)转换为 (x, y, w, h) 格式(其中 (x, y) 是边界框的中心点坐标, w 和 h 是边界框的宽度和高度),并且将坐标归一化到 [0, 1] 范围内。
# 这是 xyxy2xywhn 函数的定义。
# 1.x :是输入的边界框坐标数组。
# 2.w 和 3.h :分别是图像的 宽度 和 高度 ,默认为 640。
# 4.clip :是一个布尔值,指示是否需要将边界框限制在图像尺寸内。
# 5.eps :是一个用于数值稳定性的小值,默认为 0.0。
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
# 如果 clip 为 True ,则执行边界框的裁剪操作。
if clip:
# 使用 clip_boxes 函数将边界框限制在图像尺寸内,减去 eps 以避免精确边界问题。
# def clip_boxes(boxes, shape): -> 用是将边界框(boxes)的坐标限制在给定图像形状(shape)的范围内,确保边界框不会超出图像的边界。
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
# 创建输入数组 x 的副本,如果 x 是 PyTorch 张量则使用 clone() ,如果是 NumPy 数组则使用 copy() 。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 计算边界框的中心 x 坐标,并将其归一化到 [0, 1] 范围内。
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
# 计算边界框的中心 y 坐标,并将其归一化到 [0, 1] 范围内。
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
# 计算边界框的宽度,并将其归一化到 [0, 1] 范围内。
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
# 计算边界框的高度,并将其归一化到 [0, 1] 范围内。
y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
# 返回转换后的边界框坐标数组。
return y
# 这个函数在目标检测和图像分割等计算机视觉任务中非常有用,因为它提供了一种将边界框坐标转换为更常用格式的方法,并且归一化使得这些坐标与图像尺寸无关,便于模型处理。
55.def xyn2xy(x, w=640, h=640, padw=0, padh=0):
# 这段代码定义了一个名为 xyn2xy 的函数,它将归一化的分割点坐标转换为像素坐标。
# 这是 xyn2xy 函数的定义,它接受以下参数。
# 1.x :归一化的分割点坐标数组。
# 2.w 和 3.h :图像的宽度和高度,默认为 640。
# 4.padw 和 5.padh :水平和垂直的填充值,默认为 0。
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
# Convert normalized segments into pixel segments, shape (n,2)
# 创建输入数组 x 的副本。如果 x 是 PyTorch 张量,则使用 clone() 方法;如果是 NumPy 数组,则使用 copy() 函数。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 将归一化的 x 坐标( x[..., 0] )转换为像素坐标,方法是乘以图像宽度 w 并加上水平填充 padw 。
y[..., 0] = w * x[..., 0] + padw # top left x
# 将归一化的 y 坐标( x[..., 1] )转换为像素坐标,方法是乘以图像高度 h 并加上垂直填充 padh 。
y[..., 1] = h * x[..., 1] + padh # top left y
# 返回转换后的像素坐标数组 y 。
return y
# 这个函数在处理图像分割任务时非常有用,特别是在需要将归一化的坐标转换为实际像素坐标时。通过这种方式,可以确保分割点坐标与图像的实际尺寸相匹配,从而进行准确的分割。
56.def segment2box(segment, width=640, height=640):
# 这段代码定义了一个名为 segment2box 的函数,它用于将线段标签转换为边界框标签。这个转换过程包括找到线段在图像中的最小和最大 x、y 坐标,从而确定一个包含该线段的边界框。
# 这行定义了一个名为 segment2box 的函数,它接受三个参数。
# 1.segment :线段标签。
# 2.width :图像宽度,默认为640。
# 3.height :图像高度,默认为640。
def segment2box(segment, width=640, height=640):
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) 将 1 个段标签转换为 1 个框标签,应用图像内部约束,即 (xy1, xy2, ...) 到 (xyxy)。
# 将线段的坐标转置,并将 x 坐标和 y 坐标分别赋值给 x 和 y 。
x, y = segment.T # segment xy
# 创建一个布尔数组 inside ,用于标记线段上的点是否在图像范围内。
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
# 只保留在图像范围内的 x 和 y 坐标。
x, y, = x[inside], y[inside]
# 执行以下操作 :
# 如果存在有效的 x 坐标(即 any(x) 为真),则计算 x 坐标的最小值和最大值,以及 y 坐标的最小值和最大值,形成一个边界框,并返回这个边界框的坐标。
# 如果不存在有效的 x 坐标(即线段完全在图像范围外),则返回一个全零的数组,表示没有有效的边界框。
# 返回值。返回一个包含边界框坐标的数组,格式为 [x_min, y_min, x_max, y_max] 。
return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
# 这个 segment2box 函数将线段标签转换为边界框标签,确保了线段在图像范围内。这个转换对于目标检测任务中的数据处理非常有用,尤其是在处理线段标注的数据时。通过这种方式,可以将线段数据转换为更常用的边界框格式,以便进行进一步的处理和分析。
57.def segments2boxes(segments):
# 这段代码定义了一个名为 segments2boxes 的函数,它将分割标签(segment labels)转换为边界框标签(box labels)。
# 这是 segments2boxes 函数的定义,它接受一个参数。
# 1.segments :这是一个包含分割数据的列表或数组,其中每个元素代表一个对象的分割点坐标。
def segments2boxes(segments):
# Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
# 初始化一个空列表 boxes ,用于存储转换后的边界框数据。
boxes = []
# 遍历 segments 中的每个分割数据 s 。
for s in segments:
# 将分割数据 s 转置并分解为 x 和 y 坐标数组。 s 是一个 NumPy 数组,其中每一列代表一个分割点的 x 和 y 坐标。
x, y = s.T # segment xy
# 对于每个分割数据,计算包含所有分割点的最小和最大 x 、 y 坐标,形成一个边界框,并将其添加到 boxes 列表中。这个边界框以 (cls, xyxy) 格式表示,其中 cls 是类别标签, xyxy 是边界框的左上角和右下角坐标。
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
# 将 boxes 列表转换为 NumPy 数组,并使用 xyxy2xywh 函数将其从 xyxy 格式转换为 xywh 格式。 xywh 格式表示边界框的中心点 x 坐标、中心点 y 坐标以及边界框的宽度和高度。
return xyxy2xywh(np.array(boxes)) # cls, xywh
# segments2boxes 函数的目的是将分割标签转换为边界框标签,这在处理目标检测任务时非常有用,尤其是当数据集中包含分割标注时。通过这种方式,可以将分割数据转换为更适合目标检测模型训练的格式。
58.def resample_segments(segments, n=1000):
# 这段代码定义了一个名为 resample_segments 的函数,它用于对线段进行上采样,即在保持线段形状不变的情况下增加线段上的点的数量。这通常用于在图像变换中更精确地处理线段。
# 这行定义了一个名为 resample_segments 的函数,它接受两个参数。
# 1.segments :要进行上采样的线段数组。
# 2.n :上采样后线段上的点的数量,默认为1000。
def resample_segments(segments, n=1000):
# Up-sample an (n,2) segment 对 (n,2) 段进行上采样。
# 遍历每个线段 s 及其索引 i 。
for i, s in enumerate(segments):
# 将线段的第一个点添加到线段的末尾,以闭合线段。这是为了确保线段是闭合的,从而在上采样时可以正确地插值。
s = np.concatenate((s, s[0:1, :]), axis=0)
# 生成一个从0到 len(s) - 1 的等间隔的数组 x ,长度为 n ,用于确定上采样点的位置。
x = np.linspace(0, len(s) - 1, n)
# 创建一个从0到 len(s) - 1 的数组 xp ,用于原始线段点的索引。
xp = np.arange(len(s))
# np.interp(x, xp, fp, left=None, right=None, period=None)
# np.interp 是 NumPy 库中的一个函数,用于一维线性插值。给定一组数据点 x 和相应的值 xp ,以及一个新的查询点 x , np.interp 函数会找到 xp 中 x 值所在的区间,并使用线性插值来估计 x 对应的值。
# 参数说明 :
# x :查询点,即你想要插值的点。
# xp :数据点,一个一维数组,包含数据点的横坐标。
# fp : xp 对应的值,一个一维数组,包含数据点的纵坐标。
# left :可选参数,如果 x 中的值小于 xp 中的最小值,则使用这个值作为插值结果。
# right :可选参数,如果 x 中的值大于 xp 中的最大值,则使用这个值作为插值结果。
# period :可选参数,表示周期性,如果指定, xp 将被视为周期性的。
# 返回值 :
# 插值结果,一个与 x 形状相同的数组。
# 对每个线段进行上采样。
# np.interp(x, xp, s[:, i]) 对线段的每个维度(x和y)进行插值,生成新的点。
# np.concatenate(...) 将插值后的x和y坐标连接起来。
# .reshape(2, -1).T 将连接后的数组重塑为 (-1, 2) 的形状,即每个点的x和y坐标。
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
# 返回值。这行代码返回上采样后的线段数组。
return segments
# 这个 resample_segments 函数通过对线段进行上采样,增加了线段上的点的数量,这有助于在图像变换中更精确地处理线段。通过插值,可以在保持线段形状的同时,增加线段的分辨率。
59.def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
# 这段代码定义了一个名为 scale_boxes 的函数,它用于调整边界框(boxes)的大小,使其适应新的图像尺寸。这个函数通常用于目标检测任务中,当图像尺寸改变时,需要对检测到的边界框进行相应的缩放。
# 这行代码定义了 scale_boxes 函数,它接受四个参数。
# 1.img1_shape :新图像的尺寸,格式为 (height, width)。
# 2.boxes :需要缩放的边界框数组,每个边界框通常是 [x1, y1, x2, y2] 格式,其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标。
# 3.img0_shape :原始图像的尺寸,格式为 (height, width)。
# 4.ratio_pad :一个可选参数,如果提供,将用于计算缩放比例和填充。
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
# Rescale boxes (xyxy) from img1_shape to img0_shape 将框(xyxy)从 img1_shape 重新缩放为 img0_shape。
# 如果没有提供 ratio_pad ,则计算缩放比例和填充。
if ratio_pad is None: # calculate from img0_shape
# 计算缩放比例 gain ,取高度和宽度比例的最小值。
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
# 计算水平和垂直方向上的填充量。
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
# 如果提供了 ratio_pad ,则直接使用提供的缩放比例和填充:
else:
# 使用提供的缩放比例。
gain = ratio_pad[0][0]
# 使用提供的填充量。
pad = ratio_pad[1]
# 将边界框的 x 坐标(左右边界)向左移动填充量。
boxes[:, [0, 2]] -= pad[0] # x padding
# 将边界框的 y 坐标(上下边界)向上移动填充量。
boxes[:, [1, 3]] -= pad[1] # y padding
# 将边界框的四个坐标都除以缩放比例 gain ,以重新缩放到目标尺寸。
boxes[:, :4] /= gain
# 调用 clip_boxes 函数,确保边界框不会超出目标图像的边界。
# def clip_boxes(boxes, shape): -> 用于确保边界框(boxes)的坐标不会超出给定图像的尺寸(shape)。
clip_boxes(boxes, img0_shape)
# 返回重新缩放后的边界框数组。
return boxes
# 这个函数的目的是将边界框从原始图像尺寸调整到新图像尺寸,同时保持边界框的相对位置和比例。 clip_boxes 函数的作用是确保边界框的坐标不会超出图像的边界。这个函数在图像处理和目标检测中非常有用,尤其是在进行图像缩放或裁剪后需要调整边界框时。
60.def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
# 这段代码定义了一个名为 scale_segments 的函数,其目的是将边界框坐标(通常是 xyxy 格式)从一个图像尺寸( img1_shape )调整到另一个图像尺寸( img0_shape )。这种调整在图像处理和目标检测任务中很常见,特别是在处理不同分辨率的图像时。
# 定义了一个名为 scale_segments 的函数,它接受五个参数。
# 1.img1_shape :调整后的图像尺寸,格式为 (height, width)。
# 2.segments :要调整的边界框坐标数组,格式为 [x1, y1, x2, y2, ...]。
# 3.img0_shape :原始图像尺寸,格式为 (height, width)。
# 4.ratio_pad :一个可选参数,包含缩放比例和填充信息。如果不提供,将根据 img0_shape 和 img1_shape 计算。
# 5.normalize :一个布尔值,指示是否将坐标归一化到 [0, 1] 范围内,默认为 False 。
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
# Rescale coords (xyxy) from img1_shape to img0_shape 将坐标 (xyxy) 从 img1_shape 重新缩放为 img0_shape 。
# 计算缩放比例和填充。
# 如果 ratio_pad 未提供,计算缩放比例 gain 为 img1_shape 和 img0_shape 之间的最小比例,以及需要的填充 pad 以使 img0_shape 缩放到 img1_shape 。
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
# 如果提供了 ratio_pad ,则直接使用提供的缩放比例和填充。
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
# 调整边界框坐标。
# 调整边界框坐标,首先减去填充,然后除以缩放比例 gain 。
segments[:, 0] -= pad[0] # x padding
segments[:, 1] -= pad[1] # y padding
segments /= gain
# 裁剪边界框坐标。调用 clip_segments 函数裁剪边界框坐标,确保它们不会超出 img0_shape 的范围。
clip_segments(segments, img0_shape)
# 归一化边界框坐标。
# 如果 normalize 为 True ,则将边界框坐标归一化到 [0, 1] 范围内。
if normalize:
segments[:, 0] /= img0_shape[1] # width
segments[:, 1] /= img0_shape[0] # height
# 返回调整后的边界框坐标。
return segments
# scale_segments 函数提供了一种方便的方式来调整边界框坐标,使其适应不同分辨率的图像。这个函数首先计算缩放比例和填充,然后调整边界框坐标,最后裁剪和归一化坐标。这对于在不同尺寸的图像上训练和测试目标检测模型非常有用。
61.def clip_boxes(boxes, shape):
# 这段代码定义了一个名为 clip_boxes 的函数,其作用是将边界框(boxes)的坐标限制在给定图像形状(shape)的范围内,确保边界框不会超出图像的边界。这个函数处理了两种类型的输入:PyTorch张量和NumPy数组。
# 定义了 clip_boxes 函数,包含以下参数。
# 1.boxes :需要裁剪的边界框数组,格式为 [x1, y1, x2, y2] 。
# 2.shape :图像的尺寸,格式为 (高度, 宽度) 。
def clip_boxes(boxes, shape):
# Clip boxes (xyxy) to image shape (height, width)
# 检查 boxes 是否是一个 PyTorch 张量。如果是,函数将对张量的每个元素单独进行操作,这通常更快。
if isinstance(boxes, torch.Tensor): # faster individually
# 如果 boxes 是 PyTorch 张量,使用 clamp_ 方法将所有边界框的 x1 坐标限制在 0 到图像宽度 shape[1] 之间。 clamp_ 方法会就地修改张量。
boxes[:, 0].clamp_(0, shape[1]) # x1
# 将所有边界框的 y1 坐标限制在 0 到图像高度 shape[0] 之间。
boxes[:, 1].clamp_(0, shape[0]) # y1
# 将所有边界框的 x2 坐标限制在 0 到图像宽度 shape[1] 之间。
boxes[:, 2].clamp_(0, shape[1]) # x2
# 将所有边界框的 y2 坐标限制在 0 到图像高度 shape[0] 之间。
boxes[:, 3].clamp_(0, shape[0]) # y2
# 如果 boxes 不是 PyTorch 张量,那么它可能是一个 NumPy 数组。在这种情况下,函数将对数组的元素进行分组操作,这通常更快。
else: # np.array (faster grouped)
# 对于 NumPy 数组,使用 clip 方法将所有边界框的 x1 和 x2 坐标限制在 0 到图像宽度 shape[1] 之间。
boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
# 将所有边界框的 y1 和 y2 坐标限制在 0 到图像高度 shape[0] 之间。
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
# 这个函数的目的是确保边界框的坐标不会超出图像的边界,这对于图像处理和目标检测任务中的数据清洗和预处理非常重要。通过限制边界框的坐标,可以避免在后续的处理中出现坐标超出图像范围的错误。
62.def clip_segments(segments, shape):
# 这段代码定义了一个名为 clip_segments 的函数,其目的是将边界框坐标(通常是 xy1, xy2 格式)裁剪到图像的尺寸范围内。这个函数确保边界框的坐标不会超出图像的边界。
# 定义了一个名为 clip_segments 的函数,它接受两个参数。
# 1.segments :要裁剪的边界框坐标数组,格式为 [x1, y1, x2, y2, ...]。
# 2.shape :图像的尺寸,格式为 (height, width)。
def clip_segments(segments, shape):
# Clip segments (xy1,xy2,...) to image shape (height, width) 将片段(xy1、xy2、...)剪辑为图像形状(高度、宽度)。
# 裁剪边界框坐标。
# 如果 segments 是一个 PyTorch 张量,使用 clamp_ 方法直接在原张量上操作,将 x 坐标限制在 [0, width] 范围内,将 y 坐标限制在 [0, height] 范围内。 clamp_ 方法是原地操作(in-place),意味着它会修改原始张量。
if isinstance(segments, torch.Tensor): # faster individually
# torch.clamp(input, min=None, max=None)
# torch.clamp() 是 PyTorch 库中的一个函数,用于将张量中的元素限制在指定的范围内。如果元素超出了这个范围,它们将被设置为范围的上限或下限。
# 参数 :
# input :要进行裁剪的输入张量。
# min :元素的最小值。默认为 None ,表示不设置下界。
# max :元素的最大值。默认为 None ,表示不设置上界。
# 返回值 :
# 返回一个新的张量,其中的元素被限制在 [min, max] 范围内。
# 注意事项 :
# torch.clamp() 函数返回的是新张量,原始输入张量不会被修改。
# 如果需要在原地修改张量,可以使用 clamped_() 方法,例如 tensor.clamp_(0, 3) 。
# torch.clamp() 可以用于多维张量,并且可以指定不同的 min 和 max 值用于不同的维度。
# min 和 max 参数也可以是标量值,或者与输入张量形状相同的张量,用于对不同元素应用不同的限制。
segments[:, 0].clamp_(0, shape[1]) # x
segments[:, 1].clamp_(0, shape[0]) # y
# 如果 segments 不是 PyTorch 张量(例如,是一个 NumPy 数组),则使用 clip 方法将 x 坐标限制在 [0, width] 范围内,将 y 坐标限制在 [0, height] 范围内。 clip 方法返回一个新的数组,不修改原始数组。
else: # np.array (faster grouped)
# numpy.clip(a, a_min, a_max, out=None, **kwargs)
# np.clip() 是 NumPy 库中的一个函数,它用于将数组中的元素限制在指定的范围内。如果元素超出了这个范围,它们将被设置为范围的上限或下限。
# 参数 :
# a :要进行裁剪的输入数组。
# a_min :元素的最小值。如果 a_min 未指定,则任何元素都不会因为低于此值而被裁剪。
# a_max :元素的最大值。如果 a_max 未指定,则任何元素都不会因为高于此值而被裁剪。
# out :(可选)用于存放结果的输出数组。它必须有与输入数组相同的形状和类型。
# **kwargs :其他关键字参数,用于传递给 ufunc (通用函数)。
# 返回值 :
# 返回一个新数组,其中的元素被限制在 [a_min, a_max] 范围内。
# 注意事项 :
# np.clip() 函数返回的是新数组,原始输入数组不会被修改。
# 如果需要在原地修改数组,可以使用 np.clip() 函数的 out 参数,或者使用 clipped_arr = arr.clip(0, 3) 这样的赋值语句,其中 arr 是 NumPy 数组。
# np.clip() 可以用于多维数组,并且可以指定不同的 a_min 和 a_max 值用于不同的轴。
segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
# clip_segments 函数提供了一种方便的方式来确保边界框坐标不会超出图像的边界。这个函数根据输入数据的类型(PyTorch 张量或 NumPy 数组)选择不同的裁剪方法,以提高效率。这对于图像处理和目标检测任务中处理边界框坐标非常有用。
63.def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=(), max_det=300, nm=0,):
# 这段代码定义了一个名为 non_max_suppression 的函数,它用于在目标检测任务中执行非极大值抑制(Non-Maximum Suppression, NMS),以去除多余的边界框,只保留最佳的检测结果。
# 定义了一个函数 non_max_suppression ,它接受多个参数。
# 1.prediction :模型输出的预测结果,包含边界框、置信度和类别概率。如果模型在验证模式下,输出可能是一个包含推理结果和损失结果的列表或元组,此时函数只取推理结果。
# 2.conf_thres :置信度阈值,用于过滤低置信度的预测结果。
# 3.iou_thres :交并比(IoU)阈值,用于非极大值抑制(NMS)过程中去除重叠过多的边界框。
# 4.classes :指定要返回的类别列表。如果为 None ,则返回所有类别的预测结果。
# 5.agnostic :如果设置为 True ,则NMS过程中不考虑类别,即所有边界框都视为同一类别。
# 6.multi_label :如果设置为 True ,则允许每个边界框有多个类别标签。
# 7.labels :如果提供了标签,则这些标签将用于自动标记预测结果。
# 8.max_det :每个图像允许返回的最大检测数量。
# 9.nm :掩码的数量。 nm 代表每个边界框关联的掩码数量,通常用于实例分割任务。
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nm=0, # number of masks
):
# 对推理结果进行非最大抑制 (NMS) 以拒绝重叠检测。
# 返回 :
# 检测列表,每个图像的 (n,6) 张量 [xyxy, conf, cls]。
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
# 这段代码是 non_max_suppression 函数的开始部分,它处理输入的预测结果,并进行一些初步的设置和检查。
# 检查 prediction 是否是一个列表或元组。在某些情况下,比如模型在验证模式下,输出可能是一个包含推理结果和损失结果的列表或元组。
if isinstance(prediction, (list, tuple)): # YOLO model in validation model, output = (inference_out, loss_out)
# 如果 prediction 是列表或元组,那么只选择第一个元素作为推理输出。这是因为在验证模式下,模型可能会返回一个包含两个元素的元组,第一个元素是推理结果,第二个元素是损失值。
prediction = prediction[0] # select only inference output
# 获取预测结果张量所在的设备(CPU或GPU)。这是为了后续可能需要将张量移动到CPU上进行处理。
device = prediction.device
# 检查设备类型是否包含 'mps',即是否是Apple的Metal Performance Shaders(MPS),这是一种在Apple硬件上优化的深度学习框架。
mps = 'mps' in device.type # Apple MPS
# 如果检测到设备是MPS,由于MPS可能不支持所有的操作,代码将预测结果张量转移到CPU上,以确保后续的NMS操作可以在支持的环境中执行。
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
# 如果设备是MPS,执行这个操作将预测结果张量转移到CPU。
prediction = prediction.cpu()
# 获取批量大小,即预测结果张量的第一个维度的大小,这代表了同时处理的图像数量。
bs = prediction.shape[0] # batch size
# 计算类别的数量。预测结果张量的第二个维度的大小减去掩码的数量 nm 和固定的4个值(通常是边界框的坐标和置信度),得到类别的数量。
nc = prediction.shape[1] - nm - 4 # number of classes
# 计算掩码开始的索引。这是 边界框坐标 、 置信度 和 类别概率 之后的起始位置。
mi = 4 + nc # mask start index
# 对于每个预测结果,选择第4个值到 mi 索引之间的值(即类别概率),找到这些值中的最大值,并检查是否大于置信度阈值 conf_thres 。这样可以得到一个布尔掩码 xc ,用于筛选出置信度超过阈值的预测结果,这些结果将作为候选框进入后续的NMS过程。
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
# 这两行代码使用了 Python 的 assert 语句来确保函数 non_max_suppression 的两个参数 conf_thres (置信度阈值)和 iou_thres (交并比阈值)在有效范围内。 assert 语句用于在代码中设置条件检查,如果条件为假(即 assert 后面的表达式结果为 False ),则会引发一个 AssertionError 。
# Checks
# 检查 conf_thres 是否在0和1之间(包括0和1)。如果 conf_thres 的值小于0或大于1,将会引发一个 AssertionError ,并显示错误消息。
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' # 置信度阈值 {conf_thres} 无效,有效值介于 0.0 和 1.0 之间。
# 检查 iou_thres 是否在0和1之间。如果 iou_thres 的值超出这个范围,将会引发一个 AssertionError ,并显示错误消息。
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0' # IoU {iou_thres} 无效,有效值介于 0.0 至 1.0 之间。
# 这些检查确保了在执行非极大值抑制之前,置信度阈值和IoU阈值是合理的,从而避免了因参数设置不当而导致的潜在错误。如果这些参数超出了预期范围,函数将不会继续执行,而是立即报错,提示用户检查输入参数。
# 这段代码是 non_max_suppression 函数中设置部分,它定义了一些用于控制非极大值抑制(NMS)过程的参数和初始设置。
# Settings
# min_wh = 2 # (pixels) minimum box width and height
# 设置最大宽度和高度限制为7680像素。这个参数用于限制边界框的大小,防止异常值影响NMS的结果。
max_wh = 7680 # (pixels) maximum box width and height
# 设置传递给 torchvision.ops.nms() 的边界框数量上限为30000。这是为了防止一次性处理过多的边界框,可能会导致内存不足或性能问题。
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
# 设置NMS操作的时间限制为2.5秒加上批量大小的0.05倍。这个时间限制用于避免NMS过程耗时过长,影响整体的推理速度。
time_limit = 2.5 + 0.05 * bs # seconds to quit after
# 设置是否需要冗余检测。如果设置为True,即使NMS后的结果数量少于最大检测数,也会保留所有NMS后的结果。
redundant = True # require redundant detections
# 如果类别数量大于1,启用多标签模式。这意味着每个边界框可以有多个类别标签。
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
# 设置是否使用合并NMS。合并NMS是一种NMS的变体,它会合并重叠的边界框而不是直接删除它们。
merge = False # use merge-NMS
# 记录当前时间,用于后续计算NMS操作耗时。
t = time.time()
# 初始化输出列表 output ,它将存储每个图像的NMS结果。对于批量大小 bs 中的每个图像,创建一个形状为 (0, 6 + nm) 的零张量,其中 6 代表边界框的四个坐标、置信度和一个类别标签, nm 是掩码的数量。这些零张量被复制 bs 次,每个图像一个。
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
# 这些设置为NMS过程提供了必要的参数和初始状态,确保了NMS可以按照预期执行,并且可以在合理的时间范围内完成。
# 这段代码是 non_max_suppression 函数中处理每个图像预测结果的部分。
# 开始一个循环,遍历 prediction 张量中的每个元素。 xi 是图像的索引, x 是对应图像的预测结果。这里的 prediction 是一个批量中的所有图像预测结果的集合。
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
# 对每个图像的预测结果进行转置( .T ),并使用之前计算的布尔掩码 xc 来筛选出置信度超过阈值的预测结果。这样, x 就只包含那些置信度超过 conf_thres 的预测框。
x = x.T[xc[xi]] # confidence
# Cat apriori labels if autolabelling
# 码检查是否存在预先定义的标签 labels ,并且当前图像 xi 是否有标签。如果存在,则继续处理。
if labels and len(labels[xi]):
# 获取当前图像的标签。
lb = labels[xi]
# 创建一个新的零张量 v ,其形状为 (标签数量, 类别数量 + 掩码数量 + 5) ,这里的 5 代表边界框的四个坐标和置信度。
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
# 将标签中的边界框坐标复制到 v 的前四列。
v[:, :4] = lb[:, 1:5] # box
# 在 v 中设置类别标签。这里 lb[:, 0] 是类别索引,加上 4 是因为类别概率在第五列开始的位置。
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
# 将筛选后的预测结果 x 和新的标签张量 v 沿着第一个维度(行)拼接起来。
x = torch.cat((x, v), 0)
# If none remain process next image
# 检查处理后的 x 是否为空,即没有任何预测结果。
if not x.shape[0]:
# 如果 x 为空,则跳过当前图像的处理,继续下一个图像。
continue
# 这段代码的主要作用是遍历批量中的每个图像,应用置信度约束,如果有预定义的标签,则将它们添加到预测结果中,最后检查是否有有效的预测结果,如果没有,则跳过当前图像。
# 这段代码继续处理每个图像的预测结果,将其转换为标准的边界框格式,并根据置信度和类别进行筛选。
# Detections matrix nx6 (xyxy, conf, cls)
# 将预测结果 x 按照维度1(列)分割成三个部分 : box (边界框坐标), cls (类别概率), mask (掩码)。 split 方法的参数 (4, nc, nm) 指定了每个部分的大小。
box, cls, mask = x.split((4, nc, nm), 1)
# 将边界框坐标从中心点坐标和宽高 ( x, y, w, h ) 转换为左上角和右下角坐标 ( x1, y1, x2, y2 )。这个转换是通过调用 xywh2xyxy 函数实现的。
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
# 如果启用了多标签模式,即每个边界框可以有多个类别标签。
if multi_label:
# 找出所有置信度超过阈值 conf_thres 的类别索引 i 和类别 j 。
i, j = (cls > conf_thres).nonzero(as_tuple=False).T
# 将 边界框坐标 、 置信度 、 类别索引 和 掩码 进行拼接,形成最终的检测结果矩阵。这里 x[i, 4 + j, None] 获取每个边界框对应类别的置信度。
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
# 如果不启用多标签模式,只选择每个边界框置信度最高的类别。
else: # best class only
# 找出每个边界框置信度最高的类别及其置信度。
conf, j = cls.max(1, keepdim=True)
# 将 边界框坐标 、 最高置信度 、 类别索引 和 掩码 进行拼接,然后根据置信度阈值筛选结果。
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
# Filter by class
# 如果指定了要筛选的类别列表。
if classes is not None:
# 只保留那些类别标签在指定类别列表中的边界框。这里 x[:, 5:6] 取出每个边界框的类别索引,然后与指定的类别列表进行比较。
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# 这段代码的核心作用是将预测结果转换为标准的边界框格式,并根据置信度和类别进行筛选,以准备进行非极大值抑制(NMS)。
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# 这段代码是 non_max_suppression 函数中的一部分,它负责检查每个图像的预测边界框数量,并根据需要进行排序和筛选。
# Check shape
# 获取当前图像的预测边界框数量 n 。
n = x.shape[0] # number of boxes
# 如果 n 为0,即没有预测边界框,那么跳过当前图像的处理,继续处理下一张图像。
if not n: # no boxes
continue
# 如果预测边界框数量 n 超过设置的最大值 max_nms ,则进行如下操作。
elif n > max_nms: # excess boxes
# 对预测结果 x 按照置信度(第5列,索引为4)进行降序排序,并选择置信度最高的前 max_nms 个边界框。这是因为在目标检测中,通常只保留置信度最高的边界框进行后续处理。
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
# 如果预测边界框数量 n 没有超过 max_nms ,则进行如下操作:
else:
# 对预测结果 x 按照置信度进行降序排序,保留所有边界框。这是因为边界框数量没有超过限制,所以不需要筛选。
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
# 这段代码的目的是确保每个图像的预测边界框数量在合理的范围内,并且在进行非极大值抑制(NMS)之前,根据置信度对边界框进行排序。这样可以提高NMS的效率,并且确保只处理最有可能包含目标的边界框。
# 这段代码是 non_max_suppression 函数中执行批量非极大值抑制(NMS)的部分。
# Batched NMS
# prediction 张量通常具有以下结构 :
# 列 0-3: 边界框坐标(例如 x, y, w, h )。
# 列 4: 目标得分(置信度)。
# 列 5 : 类别索引。
# 列 6 及 以后 :类别概率。
# 在目标检测模型的输出张量 prediction 中,第5列及以后的列包含了类别概率。然而,在计算类别偏移量 c 时,我们只使用第5列,即类别索引,而不是后续的类别概率列。这是因为类别偏移量 c 的目的是为了在NMS过程中区分不同类别的边界框,而不是基于类别概率。
# 以下是为什么只在第5列进行类别偏移量 c 计算的原因 :
# 类别索引作为标识 :
# 第5列通常包含了每个边界框预测的类别索引,这是一个离散的值,用于标识每个边界框最可能属于的类别。这个索引是一个整数,代表了类别的ID。
# 类别概率与NMS无关 :
# 后续的列包含了每个边界框属于各个类别的概率。这些概率值用于确定边界框属于每个类别的置信度,但在NMS过程中,我们关心的是边界框的空间位置和它们所属的类别,而不是具体的概率值。
# 避免类别间重叠 :
# 在NMS中,我们通过添加类别偏移量来确保属于不同类别的边界框在坐标空间中不会重叠。这个偏移量是基于类别索引计算的,因为我们需要一个固定且唯一的值来代表每个类别。
# 简化计算 :
# 使用类别索引来计算偏移量简化了计算过程。我们不需要考虑类别概率的具体数值,只需要一个标识符来区分不同的类别。
# 实现效率 :
# 在实际实现中,使用类别索引作为偏移量可以提高计算效率。我们不需要对每个类别的概率进行迭代或加权,只需要一个简单的乘法操作即可得到偏移量。
# 兼容性 :
# 这种方法与PyTorch等深度学习框架中NMS的实现兼容,这些框架通常期望类别信息以索引的形式存在,而不是以概率的形式。
# 综上所述,类别偏移量 c 的计算只涉及第5列的类别索引,是因为这个索引为每个边界框提供了一个唯一的类别标识,这对于在NMS过程中区分不同类别的边界框是必要的。而类别概率并不直接影响NMS的执行,因此不用于计算偏移量。
# 计算类别偏移量 c 。如果 agnostic 为 True ,则所有类别的偏移量为0;否则,使用 max_wh 作为偏移量。这个偏移量用于在NMS过程中考虑类别信息。
# x[:, 5:6] 代表了每个边界框的类别索引。
# 只在 x 的第5列进行类别偏移量的计算,是因为第5列通常代表了边界框的类别索引,而这个索引是我们在多类别检测场景中执行NMS时需要考虑的关键信息。
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
# 计算调整后的边界框 boxes 和置信度 scores 。边界框通过加上类别偏移量 c 来调整, scores 是预测结果中的置信度列。
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
# 使用 PyTorch 的 torchvision.ops.nms 函数执行非极大值抑制,根据 IoU 阈值 iou_thres 去除重叠的边界框,并返回保留的边界框索引 i 。
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
# 如果保留的边界框数量超过最大检测数 max_det ,则只保留前 max_det 个。
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
# 如果启用了合并NMS并且边界框数量在一定范围内,则执行合并NMS。合并NMS是一种NMS的变体,它会合并而不是直接删除重叠的边界框。
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# 计算保留的边界框和所有边界框之间的 IoU 矩阵,并与 IoU 阈值进行比较,得到一个布尔矩阵。
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# 根据 IoU 矩阵和置信度计算权重,权重高的边界框在合并时有更大的影响。
weights = iou * scores[None] # box weights
# 使用 权重矩阵 和 所有边界框的坐标 执行矩阵乘法,计算合并后的边界框坐标,并进行归一化。
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# 如果需要冗余检测,即保留多个重叠的边界框。
if redundant:
# 只保留那些与多个边界框重叠的边界框索引。
i = i[iou.sum(1) > 1] # require redundancy
# 这段代码的核心作用是在每个图像中执行NMS,去除多余的边界框,并在启用合并NMS的情况下合并重叠的边界框。这样可以确保每个图像中只保留最佳的边界框,减少冗余并提高检测的准确性。
# 这段代码是 non_max_suppression 函数的最后部分,它负责将处理后的检测结果保存到输出列表中,并检查是否超出了时间限制。
# 将经过NMS处理后保留的边界框索引 i 对应的检测结果 x[i] 赋值给输出列表 output 中的第 xi 个元素,其中 xi 是当前处理的图像在批量中的索引。
output[xi] = x[i]
# 这个条件判断用于检查是否在Apple的Metal Performance Shaders(MPS)设备上执行代码。
if mps:
# 如果在MPS设备上,将输出张量 output[xi] 转移到原始设备(CPU或GPU)。这是因为MPS设备可能不支持所有的PyTorch操作,所以需要在CPU上执行NMS,然后将结果转移回原始设备。
output[xi] = output[xi].to(device)
# 这行代码检查自NMS开始以来的时间是否超过了设定的时间限制 time_limit 。
if (time.time() - t) > time_limit:
# 如果超过了时间限制,使用 LOGGER 发出警告,提示NMS的时间限制被超过了。
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded') # 警告 ⚠️ 已超出 NMS 时间限制 {time_limit:.3f}s。
# 如果超过了时间限制,跳出循环,不再处理剩余的图像。
break # time limit exceeded
# 函数返回最终的输出列表 output ,其中包含了批量中每个图像经过NMS处理后的检测结果。
return output
# 这段代码确保了NMS处理的结果被正确保存,并且在性能和时间上进行了优化。如果在MPS设备上,它会将结果转移回原始设备。此外,它还通过检查时间限制来防止NMS过程耗时过长,影响整体的推理效率。
# 这个函数实现了NMS的核心逻辑,包括置信度筛选、类别筛选、IoU计算和NMS操作。它还支持多标签模式、合并NMS和时间限制等功能。
64.def strip_optimizer(f='best.pt', s=''):
# 这段代码定义了一个名为 strip_optimizer 的函数,其目的是从训练好的模型文件中移除优化器(optimizer)和其他一些不需要的键,以便在不继续训练的情况下使用模型进行推理。这通常用于模型的部署阶段,因为优化器在推理时是不需要的。
# 定义了一个名为 strip_optimizer 的函数,它接受两个参数。
# 1.f :要处理的模型文件名,默认为 'best.pt' 。
# 2.s :处理后的模型文件名,默认为空字符串,如果不提供,则覆盖原文件。
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
# 使用 torch.load 加载模型文件 f ,并将其映射到 CPU 设备。
x = torch.load(f, map_location=torch.device('cpu'))
# 如果模型字典 x 中包含 'ema' 键,表示模型使用了指数移动平均(EMA)技术,将模型替换为 EMA 模型。
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
# 遍历一个包含不需要的键的元组,并将其设置为 None ,从而在保存时这些键将不会被包含。
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
x[k] = None
# 将模型字典中的 'epoch' 键设置为 -1 ,表示训练已完成。
x['epoch'] = -1
# 将模型转换为半精度(FP16)。
x['model'].half() # to FP16
# 遍历模型的所有参数,并禁用梯度计算。
for p in x['model'].parameters():
p.requires_grad = False
# 保存处理后的模型字典 x 到文件 s 或原文件 f 。
torch.save(x, s or f)
# os.path.getsize(path)
# os.path.getsize() 是 Python 标准库 os.path 模块中的一个函数,用于获取文件的大小,单位为字节。
# 参数 :
# path :文件的路径,可以是相对路径或绝对路径。
# 返回值 :
# 返回指定文件的大小,以字节为单位。如果文件不存在,会抛出 FileNotFoundError 异常。
# 注意事项 :
# os.path.getsize() 返回的是文件的实际大小,如果文件很大,返回的值可能非常大。
# 这个函数只能用于获取文件的大小,不能用于目录。
# 如果需要以更友好的方式显示文件大小(例如,自动转换为 KB、MB 或 GB),可能需要自己编写额外的代码来处理单位转换。
# 计算处理后的模型文件大小。
mb = os.path.getsize(s or f) / 1E6 # filesize
# 使用 LOGGER 记录一条包含原文件名、处理后的文件名(如果提供了 s )和文件大小的信息日志。
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") # 优化器从 {f} 中剥离,{f' 保存为 {s},' if s else ''} {mb:.1f}MB 。
# strip_optimizer 函数用于从训练好的模型中移除优化器和其他不需要的键,以便在推理时使用。这个函数还替换模型为 EMA 模型(如果存在),禁用梯度计算,并将模型转换为半精度以节省内存。最后,函数保存处理后的模型并记录相关信息。这个函数对于模型的部署和推理阶段非常有用。
# 在机器学习和深度学习中,优化器(Optimizer)是用于调整模型参数以最小化损失函数的算法。以下是一些常见的优化器 :
# SGD(随机梯度下降) :最基础的优化器之一,通过随机采样梯度来更新模型参数。
# Momentum SGD :在SGD的基础上增加了动量项,帮助加速SGD在相关方向上的进展,并抑制振荡。
# Nesterov Accelerated Gradient (NAG) :是Momentum SGD的一个变种,动量项考虑了梯度的方向。
# Adagrad :适用于处理稀疏数据,通过累积历史梯度的平方和来调整每个参数的学习率。
# RMSprop :类似于Adagrad,但是梯度的累积方式不同,使用移动平均来计算平方根。
# Adam (Adaptive Moment Estimation) :结合了Momentum和RMSprop的特点,是目前最流行的优化器之一。
# Adadelta :类似于RMSprop,但是不依赖于初始学习率。
# AdamW :是Adam的变种,与权重衰减一起使用时更为有效。
# Nadam (Nesterov-accelerated Adaptive Moment Estimation) :结合了NAG和Adam的特点。
# 在模型中优化器的作用包括 :
# 参数更新 :优化器根据损失函数的梯度更新模型的参数,这是训练过程中最关键的步骤之一。
# 调整学习率 :不同的优化器以不同的方式调整学习率,以确保模型能够有效地收敛。
# 加速收敛 :某些优化器(如Momentum和Adam)能够加速SGD的收敛速度。
# 处理不同数据特性 :不同的优化器适用于不同类型的数据和优化问题,选择合适的优化器可以提高模型的性能。
# 稳定性 :优化器可以帮助模型在训练过程中保持稳定,尤其是在面对复杂的损失景观时。
# 适应性 :一些优化器能够自适应地调整每个参数的学习率,这对于非均匀或病态问题特别有用。
# 选择合适的优化器对于模型的训练效果和收敛速度至关重要。在实际应用中,可能需要根据具体问题和数据集的特性来选择或调整优化器。
65.def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
# 这段代码定义了一个名为 print_mutation 的函数,其目的是在神经网络架构搜索或超参数优化过程中记录和打印结果。这个函数将结果保存到 CSV 和 YAML 文件中,并在控制台打印当前的进化代数和结果。
# 定义了一个名为 print_mutation 的函数,它接受六个参数。
# 1.keys :结果字典的键列表。
# 2.results :与键对应的结果值列表。
# 3.hyp :超参数字典。
# 4.save_dir :保存文件的目录。
# 5.bucket :Google Cloud Storage 存储桶名称,用于可选的上传和下载操作。
# 6.prefix :日志消息的前缀,默认为 'evolve: ' 。
def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
# 这段代码是 print_mutation 函数的一部分,它执行了几个步骤来准备数据,以便将结果记录到 CSV 和 YAML 文件中。
# 使用路径拼接操作符 / 来定义两个文件的路径。 save_dir 是一个 Path 对象,表示保存文件的目录。 evolve_csv 和 evolve_yaml 分别是 CSV 和 YAML 文件的完整路径。
evolve_csv = save_dir / 'evolve.csv'
evolve_yaml = save_dir / 'hyp_evolve.yaml'
# 将传入的 keys 列表和超参数字典 hyp 的键合并,形成一个包含所有键的新元组。 keys 列表可能包含结果的键,而 hyp.keys() 返回超参数的键。
keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
# 使用生成器表达式遍历 keys 元组,并使用 strip() 方法移除每个键名前后的空白字符,然后再次转换为元组。
keys = tuple(x.strip() for x in keys)
# 将 results 列表和超参数字典 hyp 的值合并,形成一个包含所有值的新元组。
vals = results + tuple(hyp.values())
# 计算合并后的键元组 keys 的长度,存储在变量 n 中。这个值将用于后续的格式化输出。
n = len(keys)
# 这些代码行准备了要记录到 CSV 和 YAML 文件中的数据。它们定义了文件路径,合并并清理了键,准备了对应的值,并计算了键的数量。这些步骤为后续的文件写入和数据记录奠定了基础。
# # 这段代码是 print_mutation 函数中处理下载操作的部分,用于从 Google Cloud Storage (GCS) 下载文件。
# Download (optional)
# 检查是否提供了 bucket 参数。如果提供了,表示用户希望进行与 Google Cloud Storage 相关的操作。
if bucket:
# 构建一个指向 Google Cloud Storage 中特定存储桶 ( bucket ) 的 evolve.csv 文件的 URL。
url = f'gs://{bucket}/evolve.csv'
# 调用 gsutil_getsize 函数来获取远程 evolve.csv 文件的大小。这个函数它使用 gsutil 命令行工具来获取文件大小。
# 然后,代码检查这个远程文件的大小是否大于本地文件的大小(如果本地文件存在)。如果本地文件不存在,则 evolve_csv.stat().st_size 将不会执行,因为 exists() 方法将返回 False ,因此比较的是远程文件大小和 0 。
# def gsutil_getsize(url=''): -> 获取通过 Google Cloud Storage 的 gsutil 工具指定 URL 的文件大小(以字节为单位)。 -> return eval(s.split(' ')[0]) if len(s) else 0
if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
# 如果远程文件比本地文件大,或者本地文件不存在,则执行 os.system 调用,运行 gsutil cp 命令将远程文件复制到本地 save_dir 目录。这一步是可选的,仅当远程文件更新时才会执行。
os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
# 这段代码实现了一个可选的下载步骤,用于从 Google Cloud Storage 下载最新的 evolve.csv 文件到本地目录。这在分布式训练或超参数优化场景中很有用,因为它允许用户同步在不同训练周期中生成的数据。如果远程文件比本地文件新(即文件大小更大),则会自动下载更新的文件,确保本地数据是最新的。
# 这段代码负责将信息记录到 evolve.csv 文件中,包括可选的表头和当前的值。
# Log to evolve.csv
# 检查 evolve.csv 文件是否存在。如果文件已存在,则 s 被设置为空字符串,不添加表头。如果文件不存在, s 将包含由 keys 生成的 CSV 表头。 '%20s,' * n 创建了一个格式化字符串,其中 n 是键的数量, %20s 表示宽度为20的字符串字段。这个字符串与 keys 元组一起使用,生成一个包含所有键的 CSV 行,并去掉了末尾的逗号。
s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
# 以追加模式 ( 'a' ) 打开 evolve_csv 文件,允许在文件末尾添加内容。
with open(evolve_csv, 'a') as f:
# 将格式化后的表头 s (如果文件不存在时生成)和格式化后的值 vals 写入文件。 '%20.5g,' * n 创建了一个格式化字符串,用于将每个值格式化为最多5位小数的浮点数,并占据宽度为20的字段。这个字符串与 vals 元组一起使用,生成一个包含所有值的 CSV 行,并去掉了末尾的逗号,然后在末尾添加了换行符。
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
# 这段代码将结果记录到 evolve.csv 文件中,如果文件不存在,则添加表头;如果文件已存在,则只添加当前的值。这种记录方式使得每次函数调用的结果都可以追加到文件中,便于跟踪和分析超参数优化过程。通过格式化字符串,确保了 CSV 文件中的数据具有良好的可读性和一致性。
# 这段代码负责将从 evolve.csv 文件中读取的数据保存到 evolve.yaml 文件中,并且以一种更易读的格式记录最佳代的信息。
# Save yaml
# 以写入模式打开 evolve_yaml 文件。如果文件已存在,则覆盖;如果不存在,则创建。
with open(evolve_yaml, 'w') as f:
# 使用 pandas 库读取 evolve.csv 文件到 DataFrame 对象 data 中。
data = pd.read_csv(evolve_csv)
# 使用 rename 方法和 lambda 函数去除 DataFrame 中所有列名的前后空白字符。
data = data.rename(columns=lambda x: x.strip()) # strip keys
# 找到 data 中前四列(通常包含适应度相关的数据)的最大值的索引 i ,这代表了最佳代的索引。
# def fitness(x): -> 计算模型的适应度,作为一系列度量指标的加权组合。返回值。函数返回一个包含每个模型适应度得分的数组。 -> return (x[:, :4] * w).sum(1)
i = np.argmax(fitness(data.values[:, :4])) #
# 计算 DataFrame data 的行数,即进化的总代数。
generations = len(data)
# 将一些元数据写入 YAML 文件的开头,包括最佳代的信息、最后一代的信息以及前七个键和它们对应的值的格式化字符串。
f.write('# YOLO Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
'\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n') # YOLO 超参数演化结果\n'最佳代:{i}\n' 最后一代:{generations - 1}\n' + '# ' ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n' 。
# 使用 pandas 的 loc 方法选取最佳代(索引 i )从第八列到最后的数据,并将其转换为字典。然后使用 yaml.safe_dump 将字典以 YAML 格式安全地写入文件, sort_keys=False 确保键的顺序与原始 DataFrame 保持一致。
yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
# 这段代码将 evolve.csv 中的数据转换为 YAML 格式,并记录了进化过程中的最佳代信息。通过这种方式,可以方便地查看和比较不同代的超参数和性能。YAML 文件提供了一种易于阅读和解析的数据格式,适合于配置文件和数据记录。
# 这段代码负责将进化代数和当前结果打印到控制台,并在提供了存储桶名称的情况下,将 evolve.csv 和 evolve.yaml 文件上传到 Google Cloud Storage。
# Print to screen
# 使用 LOGGER.info 方法将信息打印到控制台。 prefix 是日志消息的前缀, generations 是完成的进化代数, keys 是结果和超参数的键, vals 是对应的值。
# f'{generations} generations finished, current result:\n' :格式化字符串,显示完成的代数和当前结果。
# prefix + ', '.join(f'{x.strip():>20s}' for x in keys) :将 keys 中的每个键名去除前后空格,并格式化为右对齐的字符串,每个键名占用20个字符的宽度。
# prefix + ', '.join(f'{x:20.5g}' for x in vals) :将 vals 中的每个值格式化为最多5位小数的浮点数,占用20个字符的宽度。
LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
for x in vals) + '\n\n')
# 如果提供了 bucket 参数,使用 os.system 调用 shell 命令,将 evolve.csv 和 evolve.yaml 文件上传到指定的 Google Cloud Storage 存储桶。
if bucket:
# gsutil cp : gsutil 工具的复制命令,用于将文件从本地复制到 Google Cloud Storage。
# {evolve_csv} {evolve_yaml} :本地的 CSV 和 YAML 文件路径。
# gs://{bucket} :目标 Google Cloud Storage 存储桶的路径。
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
# 这段代码实现了两个功能 :将进化算法的结果打印到控制台,以及在提供了存储桶名称的情况下,将结果文件上传到云端。这样,用户可以方便地查看进化过程的进度,并在云端保存结果,便于后续的分析和记录。
# print_mutation 函数用于记录和打印超参数优化过程中的结果,它将结果保存到本地文件,并可选地上传到 Google Cloud Storage。这个函数对于跟踪优化过程和结果非常有用,尤其是在进行大规模的超参数搜索时。
66.def apply_classifier(x, model, img, im0):
# 这段代码定义了一个名为 apply_classifier 的函数,它用于对 YOLO 目标检测模型的输出应用第二阶段分类器。这种技术常用于提高模型对目标类别的分类精度。
# 定义了一个名为 apply_classifier 的函数,它接受四个参数。
# 1.x :YOLO 模型的输出,通常包含检测到的边界框、置信度和类别概率。
# 2.model :用于第二阶段分类的分类器模型。
# 3.img :YOLO 模型处理的原始图像的尺寸。
# 4.im0 :预处理前的原始图像或图像列表。
def apply_classifier(x, model, img, im0):
# Apply a second stage classifier to YOLO outputs 将第二阶段分类器应用于 YOLO 输出。
# Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
# 确保 im0 是一个列表,无论是单个图像还是多个图像。
im0 = [im0] if isinstance(im0, np.ndarray) else im0
# 遍历 x 中的每个元素,其中 i 是索引, d 是包含当前图像检测结果的张量。
for i, d in enumerate(x): # per image
# 如果检测结果不为空,则克隆一份以避免修改原始数据。
if d is not None and len(d):
d = d.clone()
# 这段代码负责对检测到的目标进行重塑和填充操作,并将边界框从图像尺寸调整回原始图像尺寸。
# Reshape and pad cutouts
# 将边界框从 xyxy 格式(即左上角和右下角的坐标)转换为 xywh 格式(即中心点坐标和宽高)。
b = xyxy2xywh(d[:, :4]) # boxes
# 将每个边界框调整为正方形。它通过取每个边界框的宽度和高度的最大值来实现,然后将这个最大值赋值给宽度和高度,从而将矩形调整为正方形。
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
# 对正方形边界框进行填充。首先将宽度和高度扩大 1.3 倍,然后再加上 30 像素的填充。
b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
# 将 xywh 格式的边界框转换回 xyxy 格式,并确保坐标为整数类型。
d[:, :4] = xywh2xyxy(b).long()
# Rescale boxes from img_size to im0 size
# 调整边界框的尺寸,使其从 YOLO 模型处理的图像尺寸( img )调整回原始图像尺寸( im0[i] )。 scale_boxes 函数接受目标尺寸、边界框坐标和原始尺寸作为参数,并按比例调整边界框坐标。
# def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): -> 用于调整边界框(boxes)的大小,使其适应新的图像尺寸。返回重新缩放后的边界框数组。 -> return boxes
scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
# 这段代码通过对检测到的目标进行重塑、填充和尺寸调整,确保裁剪区域适合进一步的处理,例如送入分类器进行分类。这个过程对于提高目标检测的精度和鲁棒性非常重要,特别是在需要对检测结果进行细粒度分析的场景中。
# 这段代码负责从检测结果中提取类别信息,并为每个检测到的目标裁剪图像区域,调整大小并进行预处理,以便用于分类器模型的输入。
# Classes
# 从检测结果 d 中提取第六列的数据,即类别索引,并将其转换为长整型( long )。这通常表示模型预测的类别。
pred_cls1 = d[:, 5].long()
# 初始化一个空列表 ims ,用于存储每个裁剪和预处理后的图像区域。
ims = []
# 遍历检测结果 d ,对于每个检测到的目标,根据其边界框坐标( xyxy 格式)从原始图像 im0[i] 中裁剪出图像区域。
for a in d:
cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
# 将裁剪出的图像区域调整大小为 224x224 像素,这是许多预训练分类器模型的标准输入尺寸。
im = cv2.resize(cutout, (224, 224)) # BGR
# 将图像从 BGR 格式转换为 RGB 格式,并重新排列维度以匹配 PyTorch 张量的格式(通道数在前)。
im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
# 将图像数据类型从 uint8 转换为 float32,并确保内存是连续的,这对于 PyTorch 张量是必要的。
im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
# 将图像像素值从 [0, 255] 范围归一化到 [0.0, 1.0] 范围。
im /= 255 # 0 - 255 to 0.0 - 1.0
# 将预处理后的图像添加到列表 ims 中。
ims.append(im)
# 这段代码通过对每个检测到的目标进行裁剪、调整大小、颜色空间转换、数据类型转换和归一化,准备用于分类器模型的输入图像。这些步骤是将目标检测模型的输出转换为适合进一步分类处理的标准格式的关键步骤。
# 这段代码继续处理 apply_classifier 函数中的操作,用于应用第二阶段分类器并对结果进行筛选。
# 将预处理后的图像列表 ims 转换为 PyTorch 张量,并发送到检测结果 d 所在的设备(例如 GPU)。然后,使用分类器模型 model 对这些图像进行预测。 argmax(1) 函数沿维度 1(类别维度)返回每个预测的最高概率索引,即分类器预测的类别。
pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
# 筛选出那些 YOLO 模型预测类别与分类器模型预测类别相匹配的检测结果。 pred_cls1 是 YOLO 模型预测的类别, pred_cls2 是分类器模型预测的类别。只有当两者相等时,对应的检测结果才会被保留。
x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
# 函数返回更新后的检测结果 x ,其中只包含那些 YOLO 模型和分类器模型预测类别相匹配的检测。
return x
# 这段代码实现了第二阶段分类器的应用和结果筛选。通过这种方式,可以提高目标检测的精度,尤其是在需要对检测结果进行更精细分类的场景中。最终,只有那些经过分类器确认的检测结果才会被保留,从而提高了整个系统的准确性和可靠性。
# apply_classifier 函数通过应用第二阶段分类器来提高目标检测的精度。它处理 YOLO 模型的输出,裁剪和调整图像区域,使用分类器进行预测,并筛选出匹配的检测结果。这种方法特别适用于需要高分类精度的场景。
67.def increment_path(path, exist_ok=False, sep='', mkdir=False):
# 这段代码定义了一个名为 increment_path 的函数,它的作用是为文件或目录生成一个新的路径名,如果原始路径已经存在,则通过在路径后面添加一个数字(默认从2开始递增)来创建一个新的路径。
# 定义了一个函数 increment_path ,它接受三个参数。
# 1.path :要检查或修改的文件或目录的路径。
# 2.exist_ok :一个布尔值,如果为 True ,则即使路径已存在也不会递增路径名。
# 3.sep :用于分隔数字的字符串,默认为空字符串。
# 4.mkdir :一个布尔值,如果为 True ,则会创建对应的目录。
def increment_path(path, exist_ok=False, sep='', mkdir=False):
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
# 将传入的 path 参数转换为 Path 对象,这样可以跨操作系统使用,因为 Path 对象会根据操作系统处理路径分隔符。
path = Path(path) # os-agnostic
# 检查路径是否存在,如果存在且 exist_ok 参数为 False ,则进入下面的代码块。
if path.exists() and not exist_ok:
# 如果 path 是一个文件,那么将文件的后缀名保存到 suffix 变量中,并将 path 对象的后缀名去掉,以便后续添加数字。
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
# Method 1
# 使用一个 for 循环从2开始到9999,尝试生成新的路径名。
for n in range(2, 9999):
# 使用格式化字符串构造新的路径名,其中 sep 是分隔符, n 是递增的数字, suffix 是文件的后缀名(如果 path 是文件的话)。
p = f'{path}{sep}{n}{suffix}' # increment path
# 检查新构造的路径 p 是否不存在。
if not os.path.exists(p): #
# 如果找到了一个不存在的路径,就跳出循环。
break
# 将找到的新路径赋值给 path 变量。
path = Path(p)
# Method 2 (deprecated)
# dirs = glob.glob(f"{path}{sep}*") # similar paths
# matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
# i = [int(m.groups()[0]) for m in matches if m] # indices
# n = max(i) + 1 if i else 2 # increment number
# path = Path(f"{path}{sep}{n}{suffix}") # increment path
# 如果 mkdir 参数为 True ,则执行下面的代码。
if mkdir:
# 创建路径对应的目录,如果目录的父目录不存在也会被创建,如果目录已存在则不会抛出异常。
path.mkdir(parents=True, exist_ok=True) # make directory
# 函数返回最终的路径 path 。
return path
# 这个函数可以用于确保在创建新的文件或目录时不会覆盖已有的文件或目录,通过在原有路径后面添加递增的数字来避免冲突。如果设置了 mkdir 参数,它还会创建这个新的路径对应的目录。
68.def imread(path, flags=cv2.IMREAD_COLOR):
# OpenCV Chinese-friendly functions OpenCV中文友好函数------------------------------------------------------------------------------------
# 这段代码提供了一组函数,用于重写 OpenCV 的 imread 、 imwrite 和 imshow 函数,以便更好地集成到 Python 环境中。
# 定义 imshow_ 。创建了 cv2.imshow 的一个副本,命名为 imshow_ 。这样做是为了避免在后续重定义 cv2.imshow 时引起的递归调用错误。
imshow_ = cv2.imshow # copy to avoid recursion errors
# 这段代码定义了一个名为 imread 的函数,它是对 OpenCV 的 cv2.imread 函数的一个封装。这个函数用于从指定路径读取图像数据,并根据提供的参数以不同的模式解码图像。
# 定义了一个名为 imread 的函数,它接受两个参数。
# 1.path :图像文件的路径。
# 2.flags :一个可选参数,指定读取图像的模式,默认值为 cv2.IMREAD_COLOR ,表示以彩色模式读取图像。
def imread(path, flags=cv2.IMREAD_COLOR):
# numpy.fromfile(file, dtype=float, count=-1, sep='', offset=0)
# np.fromfile 是 NumPy 库中的一个函数,它用于从二进制文件中读取数据,并将其作为一维数组返回。这个函数特别适合于读取那些以二进制形式存储的数值数据,例如图像文件。
# 参数 :
# file :文件路径或者一个类似文件的对象。可以是字符串路径或者一个具有 read 方法的对象。
# dtype :数组元素的数据类型,默认为 float 。也可以指定为其他类型,如 np.int32 、 np.uint8 等。
# count :要读取的元素数量。如果设置为 -1 (默认值),则读取整个文件。
# sep :仅当文件是文本文件时使用,指定行与行之间的分隔符。
# offset :从文件开头开始读取之前要跳过的字节数。
# 返回值 :
# 返回一个一维 NumPy 数组,包含从文件中读取的数据。
# 注意事项 :
# np.fromfile 默认以二进制模式读取文件,因此不需要以 'rb' 模式打开文件。
# 如果文件中的数据类型不是 float ,需要正确指定 dtype 参数,以确保数据被正确解释。
# 如果文件很大,读取整个文件可能会消耗大量内存。在这种情况下,可以考虑分块读取文件。
# np.fromfile 不适用于文本文件。对于文本文件,可以使用 np.loadtxt 或 np.genfromtxt 等函数。
# cv2.imdecode(buf, flags)
# cv2.imdecode() 是 OpenCV 库中的一个函数,它用于从图像内存缓冲区中解码图像数据。这个函数通常与 cv2.imencode() 搭配使用,用于在内存中处理图像数据,而不是直接读写文件。
# 参数 :
# buf :一个包含图像数据的内存缓冲区,通常是一个 NumPy 数组。
# flags :一个标志,指定解码图像的模式,例如 cv2.IMREAD_COLOR 、 cv2.IMREAD_GRAYSCALE 或 cv2.IMREAD_UNCHANGED 。
# 返回值 :
# 返回一个包含解码图像的 NumPy 数组,如果解码失败,则返回 None 。
# 注意事项 :
# cv2.imdecode() 接受的 buf 参数应该是一个一维 NumPy 数组,其中包含了图像的原始字节数据。
# 函数的行为取决于 flags 参数,它决定了图像的解码方式和输出格式。例如, cv2.IMREAD_COLOR 表示以彩色模式读取图像, cv2.IMREAD_GRAYSCALE 表示以灰度模式读取图像。
# 如果 buf 参数为空或者图像数据损坏, cv2.imdecode() 将返回 None 。
# cv2.imdecode() 是处理图像数据流或从网络接收图像时的有用工具,因为它允许在不直接操作文件的情况下对图像进行解码。
# 这行代码执行以下操作 :
# np.fromfile(path, np.uint8) :使用 NumPy 的 fromfile 函数从 path 指定的文件路径读取图像数据,并将数据类型转换为 np.uint8 ,这是图像数据的常用数据类型。
# cv2.imdecode :使用 OpenCV 的 imdecode 函数将读取的图像数据解码成图像矩阵。 flags 参数控制解码模式,例如 cv2.IMREAD_COLOR 表示彩色图像, cv2.IMREAD_GRAYSCALE 表示灰度图像, cv2.IMREAD_UNCHANGED 表示不修改图像数据(包括Alpha通道)。
# 返回值。函数返回解码后的图像矩阵,这是一个 NumPy 数组,可以用于进一步的图像处理或分析。
return cv2.imdecode(np.fromfile(path, np.uint8), flags)
# imread 函数提供了一个方便的方式来读取图像文件,并根据需要以不同的模式解码图像。这个函数结合了 NumPy 和 OpenCV 的功能,使得图像读取和处理变得更加灵活和高效。
69.def imwrite(path, im):
# 这段代码定义了一个名为 imwrite 的函数,其目的是将图像数据 im 写入到指定的文件路径 path 。这个函数尝试使用 OpenCV 的 cv2.imencode 函数来编码图像,并将其保存到磁盘。
# 定义了一个名为 imwrite 的函数,它接受两个参数。
# 1.path :要保存图像的文件路径。
# 2.im :要保存的图像数据,通常是一个 NumPy 数组。
def imwrite(path, im):
# 开始了一个 try 块,用于尝试执行后面的代码,并捕获可能发生的任何异常。
try:
# cv2.imencode(ext, img, params=None)
# cv2.imencode() 是 OpenCV 库中的一个函数,用于将图像编码为特定格式的内存缓冲区。这个函数常用于准备图像数据以便存储或传输。
# 参数 :
# ext :一个字符串,指定图像的文件扩展名,例如 .jpg 、 .png 等。这个扩展名决定了编码的格式。
# img :要编码的图像,必须是一个有效的图像矩阵。
# params :一个元组或列表,包含特定编码格式的参数。例如,对于 JPEG 图像,可以包含压缩质量( [cv2.IMWRITE_JPEG_QUALITY, 90] )。
# 返回值 :
# 返回一个元组,其中包含两个元素 :
# 一个布尔值,表示编码是否成功。
# 编码后的图像数据,如果成功,它是一个内存缓冲区(即一个字节数组);如果失败,则为 None 。
# 注意事项 :
# cv2.imencode() 需要一个有效的图像矩阵作为输入。如果图像未正确加载或为空,编码将失败。
# 编码参数 params 是可选的,但对于一些格式(如 JPEG)来说,可能需要提供额外的参数来控制编码过程。
# 编码后的图像数据是一个内存缓冲区,可以直接写入文件或通过网络发送,而不需要先保存到磁盘。
# cv2.imencode() 返回的布尔值 ret 可以用来检查编码操作是否成功。如果失败,应进行错误处理。
# numpy.ndarray.tofile(fname, sep="", format="%s")
# np.ndarray.tofile() 这是一个 NumPy 数组的方法,用于将数组保存到二进制文件中。
# 参数 :
# fname :要写入的文件名。
# sep :(可选)用于分隔数组元素的字符串,默认为空字符串。
# format :(可选)用于格式化每个元素的字符串,默认为 "%s" 。
# 功能 :
# tofile() 方法将 NumPy 数组的内容写入到指定的文件中。该方法按照数组在内存中的布局(通常是行优先顺序)来写入数据。文件被写入的数据是二进制格式的,因此只能使用相同的数据类型和数组形状来读取。
# 返回值 :tofile() 方法不返回任何值。
# 注意事项 :
# tofile() 方法只能用于保存二进制数据,因此读取文件时也必须使用相同的数据类型和形状。
# 如果需要保存文本格式的数组,可以使用 numpy.savetxt() 函数。
# 文件被写入的数据不包含任何形状或类型的元数据,因此读取时需要确保使用正确的形状和数据类型。
# 这行代码执行以下操作 :
# Path(path).suffix :获取文件路径 path 的文件扩展名。
# cv2.imencode :使用 OpenCV 的 imencode 函数根据文件扩展名编码图像数据 im 。
# [1] : imencode 函数返回一个元组,其中第一个元素是成功标志,第二个元素是编码后的图像缓冲区。这里选择第二个元素,即编码后的图像数据。
# tofile(path) :将编码后的图像数据写入到指定的文件路径 path 。
cv2.imencode(Path(path).suffix, im)[1].tofile(path)
# 如果图像成功写入文件,则函数返回 True 。
return True
# 如果在尝试写入图像的过程中发生任何异常, except 块将捕获这个异常。
except Exception:
# 返回 False 。
return False
# imwrite 函数提供了一个简单的方式来保存图像文件。它处理了文件扩展名的自动检测和图像的编码过程,使得用户可以方便地将图像数据保存为不同的文件格式。通过返回值,用户可以知道操作是否成功,这在自动化脚本和错误处理中非常有用。
70.def imshow(path, im):
# 这段代码定义了一个名为 imshow 的函数,它是一个自定义的函数,用于显示图像。这个函数是对 OpenCV 的 cv2.imshow 函数的一个封装,它允许在不同的环境中(特别是 Jupyter 笔记本)正确显示图像。
# 定义了一个名为 imshow 的函数,它接受两个参数。
# 1.path :用于显示图像的窗口名称。
# 2.im :要显示的图像数据,通常是一个 NumPy 数组。
def imshow(path, im):
# 这行代码执行以下操作 :
# path.encode('unicode_escape') :将 path 字符串编码为 Unicode 转义序列。这在处理包含特殊字符的路径时非常有用,例如在 Jupyter 笔记本中显示图像时。
# .decode() :将编码后的 Unicode 转义序列解码回字符串。
# imshow_ :调用之前保存的 cv2.imshow 函数的副本,将解码后的路径和图像数据传递给它,以显示图像。
# 使用 Unicode 转义序列可能是为了确保窗口名称中的特殊字符(如空格、中文字符等)能够在不同的操作系统和环境中被正确处理。这样,无论用户的系统环境如何,图像都能在预期的窗口中显示,而不会受到路径中特殊字符的影响。
imshow_(path.encode('unicode_escape').decode(), im)
# imshow 函数提供了一个方便的方式来显示图像,特别是在需要处理包含特殊字符的路径或在不同环境中显示图像时。这个函数结合了 Unicode 转义序列的处理和 OpenCV 的图像显示功能,使得图像显示更加灵活和可靠。
# 将字符串先编码为 Unicode 转义序列再解码回字符串的过程,通常用于确保字符串中的特殊字符能够被正确处理,尤其是在不同操作系统和编程环境中。以下是这一过程的几个主要原因 :
# 处理特殊字符 :
# 有些字符在不同的系统或编程环境中可能被视为控制字符或有特殊含义,例如在文件路径中,某些字符可能需要转义。
# Unicode 转义序列提供了一种标准化的方式来表示这些特殊字符,确保它们在不同环境中都能被正确理解和处理。
# 兼容性 :
# 在跨平台应用中,特别是在涉及文件路径或其他系统资源的字符串时,使用 Unicode 转义序列可以提高代码的兼容性。
# 例如,在 Windows 系统中,文件路径分隔符是反斜杠 \ ,而在 Unix/Linux 系统中是正斜杠 / 。Unicode 转义序列可以帮助统一这些差异。
# 国际化和本地化 :
# 对于支持国际化的应用,使用 Unicode 转义序列可以确保字符串中的非英文字符(如汉字、日文、阿拉伯文等)能够被正确处理和显示。
# 安全性 :
# 在处理来自用户输入或其他不可信来源的字符串时,转义特殊字符可以防止注入攻击或其他安全问题。
# 字符串的可读性和维护性 :
# 使用 Unicode 转义序列可以使字符串在代码中更易于阅读和维护,尤其是在字符串中包含需要特殊处理的字符时。
# 特定环境的要求 :
# 某些环境或库可能要求字符串以特定的格式传递,例如在 Jupyter 笔记本中显示图像时,窗口名称可能需要特定的编码格式。
# 重定义 OpenCV 函数。将 OpenCV 模块中的 imread 、 imwrite 和 imshow 函数重定义为上面定义的自定义函数。这样做可以让 OpenCV 使用这些自定义的实现。
cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
# 这段代码通过重定义 OpenCV 的基本图像处理函数,提供了更灵活的图像读取、写入和显示功能。这些自定义函数可以更好地处理文件路径、异常情况,并与 Python 的其他库(如 NumPy 和 pathlib)集成。通过这种方式,用户可以在使用 OpenCV 时获得更一致和便捷的体验。
# Variables ------------------------------------------------------------------------------------------------------------