YOLOv7-0.1部分代码阅读笔记-general.py
general.py
utils\general.py
目录
general.py
1.所需的库和模块
2.def set_logging(rank=-1):
3.def init_seeds(seed=0):
4.def get_latest_run(search_dir='.'):
5.def isdocker():
6.def emojis(str=''):
7.def check_online():
8.def check_git_status():
9.def check_requirements(requirements='requirements.txt', exclude=()):
10.def check_img_size(img_size, s=32):
11.def check_imshow():
12.def check_file(file):
13.def check_dataset(dict):
14.def make_divisible(x, divisor):
15.def clean_str(s):
16.def one_cycle(y1=0.0, y2=1.0, steps=100):
17.def colorstr(*input):
18.def labels_to_class_weights(labels, nc=80):
19.def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
20.def coco80_to_coco91_class():
21.def xyxy2xywh(x):
22.def xywh2xyxy(x):
23.def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
24.def xyn2xy(x, w=640, h=640, padw=0, padh=0):
25.def segment2box(segment, width=640, height=640):
26.def segments2boxes(segments):
27.def resample_segments(segments, n=1000):
28.def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
29.def clip_coords(boxes, img_shape):
30.def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
31.def bbox_alpha_iou(box1, box2, x1y1x2y2=False, GIoU=False, DIoU=False, CIoU=False, alpha=2, eps=1e-9):
32.def box_iou(box1, box2):
33.def wh_iou(wh1, wh2):
34.def box_giou(box1, box2):
35.def box_ciou(box1, box2, eps: float = 1e-7):
36.def box_diou(box1, box2, eps: float = 1e-7):
37.def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=()):
38.def strip_optimizer(f='best.pt', s=''):
39.def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
40.def apply_classifier(x, model, img, im0):
41.def increment_path(path, exist_ok=True, sep=''):
1.所需的库和模块
# YOLOR general utils
import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import yaml
from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds
# torch.set_printoptions()
# torch.set_printoptions 是 PyTorch 提供的一个函数,用于设置打印张量时的选项。这个函数允许你自定义张量的打印输出,以便在调试和展示结果时更加方便和清晰。
# 常用参数及其定义 :
# linewidth :设置每行打印的最大字符数。默认值通常是80。
# precision :设置浮点数打印时的小数点后位数。默认值通常是4。
# profile :设置打印配置文件。可以是 'short' 、 'full' 或 'long' 。 'short' 会打印更少的信息, 'full' 会打印完整的张量, 'long' 会打印更详细的信息。
# threshold :设置当张量的元素个数超过这个值时,张量会被压缩打印。默认值通常是1000。
# edgeitems :设置在打印多维张量时,每维度打印的元素个数。默认值通常是3。
# sign :设置是否打印浮点数的符号。可以是 '+' 、 '-' 或 False 。默认值通常是 False 。
# nan_inf :设置如何打印 NaN 和 inf 值。可以是 'quiet' (不打印)、 'warn' (打印警告)或其他字符串。默认值通常是 'warn' 。
# 设置了一些库的全局配置,以优化性能和输出格式。
# Settings
# 这是 PyTorch 的设置,用于控制打印输出的格式。
# linewidth=320 :设置输出的行宽为320个字符。
# precision=5 :设置浮点数的打印精度为5位小数。
# profile='long' :指定打印配置文件, 'long' 表示使用长格式,这通常包括更多的信息。
torch.set_printoptions(linewidth=320, precision=5, profile='long')
# np.set_printoptions()
# np.set_printoptions 是 NumPy 库中用于设置数组打印选项的函数。通过这个函数,你可以控制 NumPy 数组在控制台中的显示方式。
# precision :指定浮点数打印时的小数点后的位数,默认值为8。
# threshold :当数组元素个数超过这个值时,数组会被压缩显示,默认值为1000。
# edgeitems :在压缩显示时,每维度显示的元素个数,默认值为3。
# linewidth :设置输出行的最大宽度,默认值通常与控制台宽度有关。
# suppress :如果为 True ,则浮点数打印时不显示尾随的零,默认值为 False 。
# nanstr :设置如何显示 NaN 值,默认为 'nan' 。
# infstr :设置如何显示 inf 值,默认为 'inf' 。
# formatter :允许你为数据类型指定一个格式化函数或字典,用于自定义打印格式。 formatter 参数是一个高级特性,允许你为特定的数据类型定义自定义的打印格式。
# 这是 NumPy 的设置,用于控制数组的打印输出。
# linewidth=320 :设置输出的行宽为320个字符。
# formatter={'float_kind': '{:11.5g}'.format} :指定浮点数的打印格式。 '{:11.5g}' 是一个格式化字符串, g 表示以通用格式打印(即自动选择固定点或科学记数法), 11.5g 表示总共11个字符宽,其中小数点后5位。
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
# 这是 Pandas 的设置,用于控制显示的最大列数。
# max_columns = 10 :意味着在打印 DataFrame 时,最多显示10列。
pd.options.display.max_columns = 10
# 这是 OpenCV 的设置,用于控制多线程。
# setNumThreads(0) :禁止 OpenCV 使用多线程,这是因为 OpenCV 的多线程可能与 PyTorch 的 DataLoader 不兼容,可能会导致问题。
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
# 这是设置环境变量,用于控制 NumExpr 库的最大线程数。
# os.environ :是一个代表当前环境变量的字典。
# NUMEXPR_MAX_THREADS :是 NumExpr 的环境变量,用于设置最大线程数。
# os.cpu_count() :返回当前机器的CPU核心数。
# min(os.cpu_count(), 8) :确保线程数不会超过8,即使机器的核心数超过8。
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
2.def set_logging(rank=-1):
# 这段代码定义了一个名为 set_logging 的函数,它用于配置 Python 的日志记录系统。
# 1.rank :默认值为 -1 ,通常用于分布式训练环境中,以区分不同的进程。
def set_logging(rank=-1):
# 调用 logging 模块的 basicConfig 函数来设置日志记录的基本配置。这个函数只在日志系统尚未配置时生效,即如果程序中已经调用过 basicConfig ,则后续调用将不会有任何效果。
logging.basicConfig(
# 设置日志消息的格式。这里的格式字符串只包含日志消息本身,不包括时间戳、日志级别或其他信息。
format="%(message)s",
# 根据 rank 的值设置日志记录的级别。
# 如果 rank 为 -1 或 0 ,则设置日志级别为 INFO ,这意味着所有级别为 INFO 及以上( WARNING 、 ERROR 、 CRITICAL )的日志消息都会被输出。
# 如果 rank 为其他值,则设置日志级别为 WARNING ,这意味着只有级别为 WARNING 及以上的日志消息会被输出。
level=logging.INFO if rank in [-1, 0] else logging.WARN)
# 这个 set_logging 函数的目的是在程序开始时设置合适的日志记录级别。在分布式训练环境中, rank 参数可以用来区分主进程和其他进程,以便在非主进程中减少日志输出,避免日志信息的冗余。
# 请注意, logging 模块是 Python 标准库的一部分,用于提供灵活的日志记录系统。通过配置日志记录,你可以控制程序运行时的输出信息,方便调试和监控程序的行为。
3.def init_seeds(seed=0):
# 这段代码定义了一个名为 init_seeds 的函数,其目的是初始化随机数生成器(RNG)的种子。
# 这个函数的作用是确保代码的随机性是可重复的,即每次使用相同的种子时,生成的随机数序列是相同的。这对于调试和复现实验结果非常有用。
# 这个函数接受一个参数。
# 1.seed :默认值为0。
def init_seeds(seed=0):
# random.seed(a=None, version=2)
# random.seed() 是 Python 标准库 random 模块中的一个函数,用于初始化随机数生成器的种子。这个函数确保了随机数生成器产生的随机数序列是可重复的,即在相同的种子下,每次运行程序时产生的随机数序列都是相同的。
# 参数 :
# a :种子值,可以是任何 hashable(可哈希)对象。如果为 None ,则使用当前时间作为种子。
# version :随机数生成器的版本,可以是 1 或 2。默认为 2。版本 2 在 Python 3.2.3 和 3.3.0 中引入,提供了更好的随机性。
# 作用 :
# 当你提供一个特定的种子值时, random.seed() 会重置随机数生成器的状态,使得随后的随机数生成可以预测。
# 如果不提供种子值(或为 None ),则随机数生成器将使用一个不可预测的值(通常是当前时间)作为种子,这使得每次程序运行时产生的随机数序列都是不同的。
# 初始化随机数生成器 (RNG) 种子。
# Initialize random number generator (RNG) seeds
# 这是 Python 标准库 random 模块的函数,用于初始化 Python 内置随机数生成器的种子。
random.seed(seed)
# np.random.seed( seed=None, *, device='cpu' )
# np.random.seed() 是 NumPy 库中用于设置随机数生成器种子的函数。这个函数确保了 NumPy 生成的随机数序列是可重复的,即在相同的种子下,每次运行程序时产生的随机数序列都是相同的。
# 参数 :
# seed :种子值,可以是任何 hashable(可哈希)对象。如果为 None ,则使用一个随机种子。
# device :指定生成随机数的设备,可以是 'cpu' 或 'gpu' 。默认为 'cpu' 。
# 作用 :
# 当你提供一个特定的种子值时, np.random.seed() 会重置 NumPy 的随机数生成器的状态,使得随后的随机数生成可以预测。
# 如果不提供种子值(或为 None ),则随机数生成器将使用一个不可预测的值(通常是当前时间)作为种子,这使得每次程序运行时产生的随机数序列都是不同的。
# 请注意, np.random.seed() 只影响 NumPy 的随机数生成器,不影响 Python 内置的 random 模块或 PyTorch 的随机数生成器。如果你需要确保多个库的随机性是一致的,你需要分别对这些库设置相同的种子。
# 这是 NumPy 库的函数,用于初始化 NumPy 的随机数生成器的种子。
np.random.seed(seed)
# def init_torch_seeds(seed=0):
# -> 它用于初始化 PyTorch 的随机数生成器种子,以确保代码的随机性是可重复的。这个函数的设计考虑了速度和可重复性之间的权衡。如果你需要确保结果的可重复性,可以将种子值设置为0;如果你更关心性能,可以设置一个非零的种子值。
init_torch_seeds(seed)
4.def get_latest_run(search_dir='.'):
# 这段代码定义了一个名为 get_latest_run 的函数,其目的是在指定的目录(默认为当前目录)及其子目录中查找最新的 last.pt 文件。这个文件通常用于存储 PyTorch 模型的权重,以便在训练过程中断后可以从中断的地方恢复训练。
# search_dir :要搜索的目录路径,默认为当前目录( . )。
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
# 使用 glob 模块搜索指定目录及其所有子目录中的所有以 last 开头并以 .pt 结尾的文件。 recursive=True 参数使得搜索是递归的,即包括所有子目录。
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
# 如果 last_list 不为空,使用 max 函数找到列表中最“新”的文件。这里的“新”是基于文件的创建时间( os.path.getctime ),即返回创建时间最晚的文件路径。
# if last_list else '' :如果 last_list 为空(即没有找到任何文件),函数返回一个空字符串。
return max(last_list, key=os.path.getctime) if last_list else ''
# 这个函数对于自动化训练流程和恢复训练非常有用,特别是在需要从上次训练中断的地方继续训练时。
5.def isdocker():
# 这段代码定义了一个名为 isdocker 的函数,其目的是检测当前的运行环境是否是一个 Docker 容器。
def isdocker():
# 环境是Docker容器吗。
# Is environment a Docker container
# Path('/workspace').exists() :这个表达式检查 /workspace 路径是否存在。在 Docker 容器中, /workspace 目录通常被用作工作目录,因此,如果这个目录存在,很可能表明代码正在 Docker 容器中运行。
# Path('/.dockerenv').exists() :这是另一个可选的检查方式,用于检测 /.dockerenv 文件是否存在。Docker 在容器中创建这个文件,以表明它是一个 Docker 环境。如果这个文件存在,那么代码也很可能正在 Docker 容器中运行。
# 返回值。
# 如果检测到 Docker 环境(即 /workspace 或 /.dockerenv 存在),函数返回 True 。
# 如果没有检测到 Docker 环境,函数返回 False 。
return Path('/workspace').exists() # or Path('/.dockerenv').exists()
# 这个函数对于根据不同的运行环境调整代码行为非常有用,例如,你可能需要在 Docker 容器中使用不同的配置或资源路径。
6.def emojis(str=''):
# 这段代码定义了一个名为 emojis 的函数,其目的是确保传入的字符串在不同操作系统平台上能够安全地显示,特别是在处理 emoji 或其他非 ASCII 字符时。
# str :一个字符串参数,默认为空字符串 '' 。
def emojis(str=''):
# 返回与平台相关的表情符号安全版本的字符串。
# platform.system() :这个函数调用用于获取当前操作系统的名称。
# if platform.system() == 'Windows' :检查当前操作系统是否为 Windows。 如果是 Windows 平台,执行以下操作 。
# str.encode().decode('ascii', 'ignore') :首先将字符串编码为字节串,然后使用 ASCII 编码解码,非 ASCII 字符(包括 emoji)被忽略,从而确保字符串中只包含 ASCII 字符,这在 Windows 系统上是安全的。
# 如果不是 Windows 平台,直接返回原始字符串 str ,因为在大多数其他操作系统(如 macOS 和 Linux)上,通常可以安全地显示 emoji 和其他非 ASCII 字符。
# Return platform-dependent emoji-safe version of string
# 返回一个在当前操作系统平台上能够安全显示的字符串。
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
# 这个函数对于编写跨平台的应用程序特别有用,因为它可以自动处理不同操作系统对 emoji 和非 ASCII 字符的支持差异。
7.def check_online():
# 这段代码定义了一个名为 check_online 的函数,其目的是检查当前环境是否能够连接到互联网。
def check_online():
# 检查互联网连接。
# Check internet connectivity
import socket
# try...except :这是一个异常处理结构,用于捕获在尝试连接时可能发生的任何 OSError 异常,例如网络不可达、超时等。
try:
# socket.create_connection(("1.1.1.1", 443), 5) :这个表达式尝试创建一个到指定主机和端口的连接。这里使用的是 1.1.1.1 ,这是一个公共的 DNS 服务器地址,端口 443 是 HTTPS 服务的标准端口。函数尝试在 5 秒内建立连接。
socket.create_connection(("1.1.1.1", 443), 5) # check host accesability
# 返回值。
# 如果能够成功建立连接,函数返回 True ,表示当前环境可以连接到互联网。
return True
except OSError:
# 如果在尝试连接过程中捕获到 OSError 异常,函数返回 False ,表示当前环境无法连接到互联网。
return False
# 这个函数对于需要在程序运行前检查网络连接的场景非常有用,比如在执行需要网络请求的操作之前确保网络可用,或者在网络请求失败时提供用户反馈。
8.def check_git_status():
# 这段代码定义了一个名为 check_git_status 的函数,其目的是检查当前的 Git 仓库是否是最新的。如果代码落后于远程仓库,它会提醒用户执行 git pull 来更新代码。
def check_git_status():
# Recommend 'git pull' if code is out of date
print(colorstr('github: '), end='')
try:
assert Path('.git').exists(), 'skipping check (not a git repository)'
assert not isdocker(), 'skipping check (Docker image)'
assert check_online(), 'skipping check (offline)'
cmd = 'git fetch && git config --get remote.origin.url'
url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git') # github repo url
branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
if n > 0:
s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
f"Use 'git pull' to update or 'git clone {url}' to download latest."
else:
s = f'up to date with {url} ✅'
print(emojis(s)) # emoji-safe
except Exception as e:
print(e)
9.def check_requirements(requirements='requirements.txt', exclude=()):
# 这段代码定义了一个名为 check_requirements 的函数,其目的是检查当前环境中安装的依赖是否满足特定的要求,并在必要时自动更新这些依赖。这个函数可以接收一个包含依赖的文本文件或依赖列表,并排除一些不需要检查的包。
# 1.requirements :可以是一个包含依赖的文本文件路径(默认为 'requirements.txt' ),或者是一个包含包名和版本的列表或元组。
# 2.exclude :一个元组,包含不需要检查的包名。
# 返回值 :无返回值,函数直接打印结果或执行操作。
def check_requirements(requirements='requirements.txt', exclude=()):
# 检查已安装的依赖项是否满足要求(传递 *.txt 文件或包列表)。
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
# 导入 pkg_resources 模块来处理包资源。
import pkg_resources as pkg
# 设置提示前缀,使用 colorstr 函数(未在代码中定义,可能是外部定义的函数)来设置颜色和样式。
# def colorstr(*input):
# -> 它用于给字符串添加 ANSI 转义代码,以便在支持 ANSI 颜色代码的终端中以彩色输出。返回一个字符串,它由指定颜色和样式的 ANSI 代码、待着色的字符串和重置颜色的 ANSI 代码( colors['end'] )组成。
# -> return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
prefix = colorstr('red', 'bold', 'requirements:')
# 检查 requirements 参数的类型。
# 如果是字符串或 Path 对象,假设它是一个文件路径,尝试打开并解析文件中的依赖。
if isinstance(requirements, (str, Path)): # requirements.txt file requirements.txt 文件。
file = Path(requirements)
if not file.exists():
print(f"{prefix} {file.resolve()} not found, check failed.") # 未找到 {prefix} {file.resolve()},检查失败。
return
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
# 如果不是字符串或 Path 对象,假设它是一个列表或元组,直接使用它。
else: # list or tuple of packages 包列表或元组。
# 解析文件中的依赖或使用提供的依赖列表,并排除不需要检查的包。
requirements = [x for x in requirements if x not in exclude]
# 遍历依赖列表,尝试使用 pkg.require 来检查每个依赖是否已安装且版本符合要求。
n = 0 # number of packages updates 软件包更新数量。
for r in requirements:
try:
pkg.require(r)
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met 如果未满足要求,则 DistributionNotFound 或 VersionConflict。
n += 1
print(f"{prefix} {e.req} not found and is required by YOLOR, attempting auto-update...") # 未找到 {prefix} {e.req} 且 YOLOR 需要它,正在尝试自动更新...
# 如果依赖未满足,捕获异常,并尝试使用 pip 自动安装缺失的包。
print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
# 如果有包被更新,打印提示信息,并建议用户重启运行时环境或重新运行命令以使更新生效。
if n: # if packages updated 如果软件包已更新。
source = file.resolve() if 'file' in locals() else requirements
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" # 根据 {source} 更新了 {prefix} {n} 包{'s' * (n > 1)}。{prefix} ⚠️ {colorstr('bold', '重新启动运行时或重新运行命令以使更新生效')
# def emojis(str=''):
# -> 确保传入的字符串在不同操作系统平台上能够安全地显示,特别是在处理 emoji 或其他非 ASCII 字符时。返回一个在当前操作系统平台上能够安全显示的字符串。
# -> return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
print(emojis(s)) # emoji-safe
# 这个函数对于确保 Python 项目运行环境满足特定依赖要求非常有用,尤其是在自动化脚本和部署过程中。
10.def check_img_size(img_size, s=32):
# 这段代码定义了一个名为 check_img_size 的函数,其目的是验证给定的图像尺寸 img_size 是否是步长 s 的倍数。如果不是,函数将调整 img_size 到最接近的、大于或等于 img_size 的 s 的倍数,并打印一条警告信息。
# 参数.
# 1.img_size :要检查的图像尺寸。
# 2.s :步长,默认值为32。
def check_img_size(img_size, s=32):
# 验证 img_size 是 stride 的倍数。
# Verify img_size is a multiple of stride s
# 调用 make_divisible 函数,将 img_size 调整到最接近的、大于或等于 img_size 的 s 的倍数。
# def make_divisible(x, divisor): -> 它用于确保一个数值 x 能够被 divisor 整除。将向上取整后的结果乘以 divisor ,得到一个能够被 divisor 整除的数。 -> return math.ceil(x / divisor) * divisor
new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
# 检查调整后的 new_size 是否与原始的 img_size 不同。
if new_size != img_size:
# 如果 new_size 与 img_size 不同,打印一条警告信息,告知用户图像尺寸已被更新。
print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size)) # 警告:--img-size %g 必须是最大步幅 %g 的倍数,更新为 %g。
# 返回调整后的图像尺寸 new_size 。
return new_size
11.def check_imshow():
# 这段代码定义了一个名为 check_imshow 的函数,其目的是检查当前环境是否支持图像显示,特别是使用 OpenCV 的 cv2.imshow() 函数和 PIL(Python Imaging Library)的 Image.show() 函数。
def check_imshow():
# 检查环境是否支持图像显示。
# Check if environment supports image displays
try:
# 检查 Docker 环境。首先检查当前环境是否是 Docker 容器。如果是,则直接返回错误信息,因为 cv2.imshow() 在 Docker 环境中通常不可用。
assert not isdocker(), 'cv2.imshow() is disabled in Docker environments' # Docker 环境中已禁用 cv2.imshow()。
# 尝试显示图像。使用 cv2.imshow() 尝试创建一个窗口并显示一个空白图像(一个尺寸为 1x1x3 的黑色图像)。
cv2.imshow('test', np.zeros((1, 1, 3)))
# 等待和关闭窗口。调用 cv2.waitKey(1) 等待1毫秒,以便图像窗口可以更新,然后调用 cv2.destroyAllWindows() 关闭所有 OpenCV 窗口,再等待1毫秒确保窗口关闭。
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
# 返回值。
# 如果环境支持图像显示,函数返回 True 。
return True
except Exception as e:
# 异常处理:如果在尝试显示图像的过程中发生任何异常,捕获异常并打印警告信息,然后返回 False 。
print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}') # 警告:环境不支持 cv2.imshow() 或 PIL Image.show() 图像显示。
# 如果环境不支持图像显示,函数返回 False 。
return False
# 这个函数对于确定是否可以在当前环境中显示图像非常有用,特别是在自动化测试或非图形用户界面(GUI)环境中,图像显示可能不可用或不期望。通过这个函数,可以避免在不支持图像显示的环境中尝试显示图像,从而避免程序崩溃或错误。
12.def check_file(file):
# 这段代码定义了一个名为 check_file 的函数,其目的是在给定文件路径不存在时,在当前目录及其子目录中搜索文件。如果找到了文件,函数将返回文件的路径;如果未找到或有多个匹配的文件,函数将抛出异常。
# 1.file :要检查的文件名或文件路径
def check_file(file):
# Path.is_file()
# .is_file() 是 Python pathlib 模块中的 Path 类的一个方法,用于检查路径是否指向一个文件。
# 返回值 :
# 如果路径存在且指向一个文件,则返回 True 。
# 如果路径不存在或指向的不是一个文件(例如,它是一个目录),则返回 False 。
# 这个方法在处理文件和目录时非常有用,因为它允许你编写更清晰的代码来检查路径类型,而不是使用低级的操作系统调用或错误处理。
# 如果未找到则搜索文件。
# Search for file if not found
# 检查文件是否存在。使用 Path(file).is_file() 检查给定的文件路径是否存在,如果存在或者文件名为空字符串,则直接返回文件路径。
if Path(file).is_file() or file == '':
return file
else:
# 搜索文件。如果文件不存在,使用 glob.glob('./**/' + file, recursive=True) 在当前目录及其所有子目录中搜索匹配的文件。 recursive=True 参数使得搜索是递归的。
files = glob.glob('./**/' + file, recursive=True) # find file
# 断言文件存在。使用 assert len(files) 确保至少找到了一个文件,如果没有找到文件,则抛出 File Not Found 的异常。
assert len(files), f'File Not Found: {file}' # assert file was found 未找到文件:{file}。
# 断言文件唯一。使用 assert len(files) == 1 确保只找到了一个匹配的文件,如果找到多个文件,则抛出异常,提示用户指定确切的路径。
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique 多个文件与“{file}”匹配,请指定确切路径:{files}。
# 返回文件路径。如果文件存在且唯一,返回找到的文件路径。
return files[0] # return file
# 这个函数对于确保文件存在且路径唯一非常有用,特别是在处理需要确切文件路径的脚本或程序时。通过这个函数,可以避免因文件路径错误或不明确而导致的错误。
13.def check_dataset(dict):
# 这段代码是一个Python函数,名为 check_dataset ,它的目的是检查本地是否存在指定的数据集,如果不存在,则尝试下载。
# 1.dict :一个字典,包含以下键值对 :
# 'val' :一个或多个数据集文件的路径,可以是单个路径字符串或路径字符串列表。
# 'download' :当本地找不到数据集时,用于下载数据集的URL或脚本。如果是一个URL,它应该是一个以 .zip 结尾的压缩文件链接;如果是一个脚本,它应该是一个bash命令或脚本路径。
def check_dataset(dict):
# 如果本地找不到数据集,请下载。
# Download dataset if not found locally
# 从传入的字典中获取键 'val' 和 'download' 对应的值,并分别赋给变量 val 和 s 。
val, s = dict.get('val'), dict.get('download')
# 检查 val 是否存在且不为空。
if val and len(val):
# Path.resolve(strict=False)
# resolve() 函数是 Python pathlib 模块中 Path 类的一个方法,用于将路径对象转换为绝对路径,并解析路径中的所有符号链接(symlinks),返回一个新的路径对象。
# 参数 :
# strict :布尔值,默认为 False 。如果设置为 True ,则在路径不存在时引发 FileNotFoundError 。如果为 False ,则尽可能解析路径并附加任何剩余部分,而不检查它是否存在。
# 返回值 :
# 返回一个新的 Path 对象,该对象表示原始路径的绝对路径,并且已经解析了所有符号链接。
# 功能 :
# resolve() 方法将路径对象转换为绝对路径,并解析路径中的所有符号链接。
# 它还会处理路径中的所有特殊符号,例如 . (当前目录)和 .. (上级目录)。
# 如果路径不存在且 strict 参数为 True ,则引发 FileNotFoundError 。
# 如果 strict 参数为 False ,则尽可能解析路径并附加任何剩余部分,而不检查它是否存在。
# 注意事项 :
# resolve() 方法总是返回一个绝对路径,无论传入的路径是相对路径还是绝对路径。
# 如果路径是相对路径,则会根据当前工作目录将其解析为绝对路径。
# resolve() 方法会对路径进行标准化处理,消除冗余的分隔符、处理上级目录符号(..)和当前目录符号(.),以保证返回的路径是规范化的。
# 如果 val 是列表,则对每个元素应用 Path.resolve() 方法来获取绝对路径;如果不是列表,则将 val 放入列表中再进行处理。
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
# os.path.exists(path)
# exists() 函数是 Python os.path 模块中的一个函数,用于检查指定路径的文件或目录是否存在。
# 参数 :
# path :要检查的文件或目录的路径。
# 返回值 :
# 返回 True :如果指定的路径存在。
# 返回 False :如果指定的路径不存在。
# 功能 :
# exists() 函数用于检查给定的路径是否指向一个存在的文件或目录。
# 它不会区分路径是指向文件还是目录,只要路径存在,它就返回 True 。
# 注意事项 :
# exists() 函数不会抛出异常,即使路径不存在,它也会安静地返回 False 。
# 如果你需要区分路径是文件还是目录,可以使用 os.path.isfile() 和 os.path.isdir() 函数。
# 这个函数不适用于检查路径是否可访问(例如,是否有读取或写入权限),它只检查路径是否存在。
# 检查 val 列表中的所有路径是否存在。
if not all(x.exists() for x in val):
# 如果有路径不存在,则打印警告信息,列出不存在的路径。
print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()]) # 警告:未找到数据集,不存在路径:%s。
# 检查 s 是否存在且不为空。
if s and len(s): # download script
# 如果 s 存在且不为空,则打印下载信息。
print('Downloading %s ...' % s) # 正在下载 %s...
# str.startswith(prefix[, start[, end]])
# startswith() 是 Python 中字符串( str )对象的一个方法,用于检查字符串是否以指定的前缀开始。
# 参数 :
# prefix :要检查的前缀字符串。
# start (可选) :开始检查的起始位置,默认为 0。
# end (可选) :结束检查的位置,默认为字符串的长度。这意味着如果未指定 end ,则检查整个字符串。
# 返回值 :
# 返回 True :如果字符串以指定的 prefix 开始。
# 返回 False :如果字符串不以指定的 prefix 开始。
# 功能 :
# startswith() 方法用于确定字符串是否以特定的子字符串 prefix 开头。
# 可以指定搜索的开始和结束位置,这允许在字符串的一部分上执行检查。
# 注意事项 :
# 如果 prefix 参数是一个元组, startswith() 方法将检查字符串是否以元组中的任何一个前缀开始,并在任何一个匹配时返回 True 。
# 如果 start 或 end 参数超出字符串的范围,将会引发 ValueError 。
# 该方法对大小写敏感,如果要进行不区分大小写的比较,可以在调用 startswith() 之前将字符串和前缀都转换为全小写或全大写。
# str.endswith(suffix[, start[, end]])
# endswith() 是 Python 中字符串( str )对象的一个方法,用于检查字符串是否以指定的后缀结束。
# 参数 :
# suffix :要检查的后缀字符串或后缀元组。
# start (可选) :开始检查的起始位置,默认为 0,表示从字符串的开头开始检查。
# end (可选) :结束检查的位置,默认为字符串的长度,表示检查到字符串的末尾。
# 返回值 :
# 返回 True :如果字符串以指定的 suffix 结束。
# 返回 False :如果字符串不以指定的 suffix 结束。
# 功能 :
# endswith() 方法用于确定字符串是否以特定的子字符串 suffix 结尾。
# 可以指定搜索的开始和结束位置,这允许在字符串的一部分上执行检查。
# 如果 suffix 是一个元组,则字符串必须以元组中的任何一个后缀结束。
# 注意事项 :
# 如果 suffix 参数是一个元组, endswith() 方法将检查字符串是否以元组中的任何一个后缀结束,并在任何一个匹配时返回 True 。
# 如果 start 或 end 参数超出字符串的范围,将会引发 ValueError 。
# 该方法对大小写敏感,如果要进行不区分大小写的比较,可以在调用 endswith() 之前将字符串和后缀都转换为全小写或全大写。
# 检查 s 是否是一个以 http 开头且以 .zip 结尾的URL。
if s.startswith('http') and s.endswith('.zip'): # URL
# 获取URL的文件名。
f = Path(s).name # filename
# torch.hub.download_url_to_file(url, filename, hash_prefix=None)
# download_url_to_file() 函数是 PyTorch 的一个实用函数,用于从指定的 URL 下载文件并保存到本地文件系统。
# 参数 :
# url :要下载文件的 URL 地址。
# filename :下载后的文件应该保存的本地路径和文件名。
# hash_prefix (可选) :文件内容的预期哈希前缀。如果提供,函数会在下载完成后验证文件的哈希值以确保下载的完整性和正确性。
# 返回值 :
# 无返回值,但函数会将文件下载到指定的 filename 路径。
# 功能 :
# download_url_to_file() 方法用于从互联网上的指定 URL 下载文件。
# 它会自动处理 HTTP 请求和文件写入操作。
# 如果提供了 hash_prefix 参数,函数会在下载完成后计算文件的 SHA256 哈希值,并将其与提供的前缀进行比较,以确保文件未被篡改。
# 注意事项 :
# 确保在调用此函数之前已经安装了 PyTorch,并且你的网络环境可以访问指定的 URL。
# 如果下载的文件很大,这个函数可能会花费一些时间来完成下载和(如果提供了)哈希校验。
# 如果下载过程中出现问题,比如网络连接问题或文件哈希不匹配,函数可能会抛出异常。
# 使用PyTorch的 torch.hub.download_url_to_file 函数下载文件。
torch.hub.download_url_to_file(s, f)
# 解压下载的文件,并删除压缩文件。
r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
# 如果 s 不是一个URL,则假设它是一个bash脚本。
else: # bash script
# 执行bash脚本。
r = os.system(s)
# 根据返回值 r 打印下载是否成功。
print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
# 如果 s 不存在或为空。
else:
# 抛出异常,提示数据集未找到。
raise Exception('Dataset not found.')
# 这段代码的目的是自动化地检查数据集是否存在,如果不存在则尝试下载。它使用了 torch.hub.download_url_to_file 来下载文件,并使用 os.system 来执行解压命令或bash脚本。
14.def make_divisible(x, divisor):
# 这段代码定义了一个名为 make_divisible 的函数,它用于确保一个数值 x 能够被 divisor 整除。
# 这在深度学习中很有用,特别是在确定网络层的输出尺寸时,有时我们需要确保输出尺寸能够被某个特定的数(如32、64等)整除,以满足特定的硬件要求或提高性能。
# 1.x :是需要被调整的数值。
# 2.divisor :是 x 需要能够被整除的数。
def make_divisible(x, divisor):
# Returns x evenly divisible by divisor
# math.ceil(x / divisor) :使用 math.ceil 函数对 x 除以 divisor 的结果进行向上取整,得到最接近或大于该商的整数。
# * divisor :将向上取整后的结果乘以 divisor ,得到一个能够被 divisor 整除的数。
return math.ceil(x / divisor) * divisor
15.def clean_str(s):
# 这段代码定义了一个名为 clean_str 的函数,其目的是清理输入的字符串 s ,将其中的特殊字符替换为下划线 _ 。这个函数使用了 Python 的 re 模块,即正则表达式模块,来实现模式匹配和替换。
# 1.s :需要被清理的原始字符串。
def clean_str(s):
# Cleans a string by replacing special characters with underscore _ 通过使用下划线 _ 替换特殊字符来清理字符串。
# pattern :这是一个正则表达式模式,用于匹配所有需要被替换的特殊字符。
# 在这个例子中,它匹配了包括竖线、at 符号、井号、感叹号、货币符号、百分号、& 符号、括号、等于符号、问号、倒挂的问号、脱字符号、星号、分号、冒号、重音符号、尖括号、加号在内的一系列特殊字符。
# repl :这是替换字符串,用于指定被匹配到的字符应该被替换成什么。在这个例子中,所有匹配到的特殊字符都会被替换为下划线 _ 。
# string :这是需要被搜索和替换的原始字符串。
# 返回一个新的字符串,其中所有指定的特殊字符都被替换为了下划线 _ 。
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
# 这对于数据清洗、标准化文件名或消除不希望的字符等场景非常有用。
16.def one_cycle(y1=0.0, y2=1.0, steps=100):
# 这段代码定义了一个名为 one_cycle 的函数,它生成一个 lambda 函数,该 lambda 函数实现了一个正弦波形的上升和下降周期,从 y1 到 y2 。这个周期通常用于机器学习中的学习率调度,以模拟训练过程中学习率的动态变化。
# 参数 :
# 1.y1 :周期开始时的值,默认为 0.0。
# 2.y2 :周期结束时的值,默认为 1.0。
# 3.steps :周期中的总步数,默认为 100。
def one_cycle(y1=0.0, y2=1.0, steps=100):
# 从 y1 到 y2 的正弦斜坡的 lambda 函数。
# lambda function for sinusoidal ramp from y1 to y2
# 返回值。
# 返回一个 lambda 函数,该函数接受一个参数 x (表示周期中的步数),并返回一个浮点数,表示在该步数时的值。
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
# 功能 :
# one_cycle 函数利用余弦函数创建一个从 y1 到 y2 的正弦波形上升和下降周期。
# 这个周期的形状是一个余弦曲线的一半,从 1 开始下降到 0,然后通过公式转换为从 y1 到 y2 的线性变化。
# 公式解释 :
# math.cos(x * math.pi / steps) :计算余弦值,其中 x 是当前步数, steps 是总步数, math.pi 是圆周率π,用于将步数映射到 [0, π] 区间。
# (1 - math.cos(x * math.pi / steps)) / 2 :将余弦值从 [1, -1] 映射到 [0, 1] 区间,表示周期的一半。
# (y2 - y1) :计算 y1 和 y2 之间的差值。
# ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 :将上述值结合起来,计算在周期中的任意步数 x 时的值。
17.def colorstr(*input):
# 这段代码定义了一个名为 colorstr 的函数,它用于给字符串添加 ANSI 转义代码,以便在支持 ANSI 颜色代码的终端中以彩色输出。这个函数接受一个或多个参数,其中最后一个参数被视为要着色的字符串,其余参数指定颜色和样式。
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
# 参数处理。
# 如果输入参数的数量大于1,那么所有的参数(除了最后一个)被视为颜色和样式,最后一个参数是待着色的字符串。
# 如果只有一个参数,那么默认使用 'blue' 和 'bold' 样式,并使用这个参数作为字符串。
# *args :这是一个可变参数列表,它会捕获所有未被特别命名的参数。在函数定义中使用 * 来收集所有未命名的参数到一个元组中。
# string :这是一个普通参数,它将接收最后一个传入的参数值。
# input :这是函数 colorstr 的参数列表,它接受任意数量的参数。
# if len(input) > 1 else ('blue', 'bold', input[0]) :这是一个条件表达式,它检查传入的参数数量。
# 如果传入的参数超过一个( len(input) > 1 ),那么 input 将被解包,其中所有参数(除了最后一个)将被收集到 args 元组中,最后一个参数将被赋值给 string 。
# 如果传入的参数只有一个或没有额外的参数( len(input) <= 1 ),那么将使用默认值 'blue' 和 'bold' 作为颜色和样式,并且将这个单一的参数作为 string 。
# 这种参数处理方式使得 colorstr 函数非常灵活,可以处理不同数量的参数输入,并提供默认值以简化常见用例的调用。
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
# 颜色和样式字典。 colors 字典定义了各种颜色和样式对应的 ANSI 转义代码。
colors = {'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'underline': '\033[4m'}
# 函数返回一个字符串,它由指定颜色和样式的 ANSI 代码、待着色的字符串和重置颜色的 ANSI 代码( colors['end'] )组成。
# colors[x] for x in args :这是一个生成器表达式,它遍历 args 元组中的每个元素(这些元素是颜色或样式的名称),并在 colors 字典中查找对应的 ANSI 转义代码。
# ''.join(...) :这个 join 方法将上一步得到的 ANSI 转义代码列表连接成一个单独的字符串。因为 colors 字典中的每个条目都是一个字符串,所以 join 方法将它们串联起来,没有任何分隔符。
# f'{string}' :这是一个格式化字符串字面量(f-string),它将变量 string 的值插入到字符串中。
# colors['end'] :这是在文本显示完毕后用于重置终端颜色的 ANSI 转义代码。
# 首先,遍历 args 中的所有颜色和样式名称,查找对应的 ANSI 转义代码,并将它们连接成一个字符串。
# 然后,将这个字符串与要着色的文本 string 结合。
# 最后,添加一个重置颜色的 ANSI 转义代码,以确保文本之后的输出不会受到颜色设置的影响。
# 这样,当这个返回值被打印到支持 ANSI 颜色代码的终端时, string 中的文本将以指定的颜色和样式显示,并且在文本显示完毕后,终端的颜色设置会恢复到默认状态。
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
18.def labels_to_class_weights(labels, nc=80):
# 这段代码定义了一个名为 labels_to_class_weights 的函数,它用于根据训练标签计算类别权重。类别权重通常用于在目标检测任务中平衡不同类别的样本数量,使得模型不会偏向于样本数量较多的类别。
# 1.labels :一个包含标签的列表,其中每个标签是一个数组,形状为 (n, 5) ,包含类别和边界框坐标( class xywh )。
# 2.nc :数据集中类别的总数,默认为 80。
def labels_to_class_weights(labels, nc=80):
# 从训练标签中获取类别权重(逆频率)。
# Get class weights (inverse frequency) from training labels
# 检查是否有标签被加载。
if labels[0] is None: # no labels loaded 未加载标签。
# 检查 labels 列表的第一个元素是否为 None ,如果是,则表示没有加载任何标签,函数返回一个空的 PyTorch 张量。
return torch.Tensor()
# 合并标签。这行使用 NumPy 的 concatenate 函数将 labels 列表中的所有数组合并成一个大数组。这个数组的形状为 (866643, 5) ,其中 866643 是样本数量,5 表示每个样本有五个值(类别和四个边界框坐标)。
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO 对于 COCO,labels.shape = (866643,5)。
# 提取类别索引。这行从合并后的标签数组中提取第一列,即类别索引,并将其转换为整数类型。
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
# numpy.bincount(x, minlength=None)
# np.bincount 是 NumPy 库中的一个函数,它用于计算非负整数数组中每个值的出现次数。
# 参数 :
# x :输入数组,其中的元素必须是非负整数。
# minlength (可选) :输出数组的最小长度。如果提供,数组 x 中小于 minlength 的值将被忽略,而 x 中等于或大于 minlength 的值将导致数组被扩展以包含这些值。如果未提供或为 None ,则输出数组的长度将与 x 中的最大值加一相匹配。
# 返回值 :
# 返回一个数组,其中第 i 个元素代表输入数组 x 中值 i 出现的次数。
# 功能 :
# np.bincount 函数对输入数组 x 中的每个值进行计数,返回一个一维数组,其长度至少与 x 中的最大值一样大。
# 如果 x 中的某个值没有出现,那么在返回的数组中对应的位置将为 0。
# 计算每个类别的出现次数。这行使用 NumPy 的 bincount 函数计算每个类别的出现次数。 minlength=nc 确保权重数组的长度至少为 nc ,即使某些类别在标签中从未出现。
weights = np.bincount(classes, minlength=nc) # occurrences per class 每类发生次数。
# 可选的网格点计数(已注释)。
# 这部分代码被注释掉了,它用于在权重数组前添加网格点计数,这在某些特定的训练场景(如 uCE 训练)中可能需要。这里计算每个图像的网格点数量,并据此调整权重。
# Prepend gridpoint count (for uCE training)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
# 替换空类别计数。这行将权重数组中为 0 的元素替换为 1,以避免在后续计算中出现除以零的错误。
weights[weights == 0] = 1 # replace empty bins with 1 将空容器替换为 1。
# 计算类别权重。这行计算每个类别的权重,即每个类别的目标数量的倒数。
weights = 1 / weights # number of targets per class 每类目标数量。
# 标准化权重。这行将权重标准化,使得所有权重的和为 1。
weights /= weights.sum() # normalize 标准化。
# 返回一个 PyTorch 张量,包含每个类别的权重。
return torch.from_numpy(weights)
# 这个函数的目的是计算类别权重,这些权重可以用于训练过程中的损失函数,以平衡不同类别的样本数量,特别是在类别不平衡的数据集中。
19.def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# 这段代码定义了一个名为 labels_to_image_weights 的函数,它根据类别权重和图像内容生成图像权重。这个函数可以用来在目标检测任务中对图像样本进行加权,以便在训练过程中对某些类别的样本给予更多的关注。
# 这行定义了一个名为 labels_to_image_weights 的函数,它接受三个参数: labels , nc ,和 class_weights 。
# 1.labels :是一个包含图像标签的列表。
# 2.nc :是类别总数的默认值(80)。
# 3.class_weights :是一个包含每个类别权重的数组,默认为全1的80维数组。
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# 根据 class_weights 和图像内容生成图像权重。
# Produces image weights based on class_weights and image contents
# 计算每个图像的类别计数。
# 这行代码使用列表推导式遍历 labels 列表中的每个标签数组 x ,并对每个数组执行以下操作。
# x[:, 0] :提取每个标签数组中的第一列,即类别索引。
# .astype(np.int) :确保类别索引是整数类型。
# np.bincount :计算每个类别的出现次数, minlength=nc 确保结果数组的长度至少为 nc ,即使某些类别在标签中从未出现。
# np.array :将列表推导式的结果转换为 NumPy 数组, class_counts 的形状为 (len(labels), nc) 。
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
# 计算每个图像的权重。
# class_weights.reshape(1, nc) :将 class_weights 重塑为 (1, nc) 形状,以便与 class_counts 进行广播乘法。
# * :运算符执行 class_weights 和 class_counts 之间的元素级乘法,结果的形状为 (len(labels), nc) 。
# .sum(1) :对结果数组的每一行(即每个图像)进行求和,得到每个图像的总权重,结果的形状为 (len(labels),) 。
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
# 注释掉的代码。这行代码被注释掉了,它展示了如何使用 random.choices 函数根据 image_weights 选择一个图像索引。 k=1 表示选择一个索引。
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
# 返回图像权重。这行代码返回包含每个图像权重的数组。
return image_weights
# 这个函数的目的是计算每个图像的权重,这些权重基于类别权重和图像内容。这在处理类别不平衡的数据集时非常有用,可以帮助模型更关注样本数量较少的类别。
# 通过计算每个图像中每个类别的出现次数,并将这些次数与类别权重相乘,可以得到每个图像的总权重。这些权重可以用于训练过程中的图像采样,以平衡不同类别的样本数量。
20.def coco80_to_coco91_class():
# 这段代码定义了一个名为 coco80_to_coco91_class 的函数,它用于将 COCO 数据集中的 80 类索引转换为原始论文中的 91 类索引。这个转换在需要与原始 COCO 论文中的对象类别进行对比时非常有用。
# 功能 :该函数返回一个列表 x ,其中包含了 COCO 2014 数据集(80 类)中每个类别对应的 COCO 论文(91 类)中的索引。
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# 这两行代码被注释掉了,它们原本用于从文件中加载 COCO 数据集的类别名称和 COCO 论文中的类别名称。
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# 这行代码被注释掉了,它原本用于创建一个从 COCO 80 类索引到 COCO 91 类索引的映射列表。
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# 这行代码被注释掉了,它原本用于创建一个从 COCO 91 类索引到 COCO 80 类索引的映射列表。
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
# 这个列表直接提供了从 COCO 80 类索引到 COCO 91 类索引的映射。列表中的每个数字代表 COCO 80 类索引对应的 COCO 91 类索引。
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
# 返回一个列表,包含从 80 类索引到 91 类索引的映射。
return x
21.def xyxy2xywh(x):
# 这段代码定义了一个名为 xyxy2xywh 的函数,它用于将边界框的坐标从 xyxy 格式(即左上角和右下角的坐标)转换为 xywh 格式(即中心点坐标加上宽度和高度)。
# 1.x :是一个包含边界框坐标的数组,形状为 nx4 ,其中 n 是边界框的数量,4 代表 (x1, y1, x2, y2) 。
def xyxy2xywh(x):
# 将 nx4 个框从 [x1, y1, x2, y2] 转换为 [x, y, w, h],其中 xy1=左上角,xy2=右下角。
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
# 如果输入 x 是一个 PyTorch 张量,则使用 clone() 方法创建一个副本;如果是 NumPy 数组,则使用 np.copy() 创建一个副本。这样做是为了避免在原数组上直接修改。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 坐标转换。
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
y[:, 3] = x[:, 3] - x[:, 1] # height
# 返回结果。返回转换后的 xywh 格式的边界框坐标。
return y
22.def xywh2xyxy(x):
# 用于将边界框的坐标从中心点坐标加宽高(x, y, w, h)格式转换为两个角点坐标(x1, y1, x2, y2)格式。这里的x和y代表边界框的中心点坐标,w和h代表边界框的宽度和高度。转换后的x1和y1代表边界框左上角的坐标,x2和y2代表右下角的坐标。
# 1.x :是一个包含边界框坐标的数组。
def xywh2xyxy(x):
# 将 nx4 个框从 [x, y, w, h] 转换为 [x1, y1, x2, y2],其中 xy1=左上角,xy2=右下角
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
# 这行代码检查输入 x 的类型,如果是PyTorch张量( torch.Tensor ),则使用 clone() 方法创建 x 的一个副本;如果不是,假设 x 是一个NumPy数组,并使用 np.copy() 创建一个副本。这样做是为了避免在原数组上直接修改,可能会导致一些不可预见的问题。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 坐标转换。
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
# 返回转换后的坐标数组。
return y
# 这个函数通常用于目标检测任务中,将边界框的坐标格式统一,以便于后续的处理和评估。
23.def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
# 这段代码定义了一个名为 xywhn2xyxy 的函数,它用于将边界框的坐标从中心点和宽度/高度(即 xywh 格式)转换为左上角和右下角的坐标(即 xyxy 格式)。这种转换在图像处理和目标检测任务中很常见,因为不同的库和框架可能使用不同的坐标格式。
# 函数参数。
# 1.x :一个形状为 nx4 的数组或张量,其中 n 是边界框的数量,4 表示 x, y, w, h (中心点的 x 坐标、中心点的 y 坐标、宽度、高度)。
# 2.w :图像的宽度,默认为 640。
# 3.h :图像的高度,默认为 640。
# 4.padw :图像的左右填充,默认为 0。
# 5.padh :图像的上下填充,默认为 0。
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
# 将 nx4 个框从 [x, y, w, h] 标准化为 [x1, y1, x2, y2],其中 xy1=左上角,xy2=右下角
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
# 如果输入 x 是一个 PyTorch 张量,则使用 clone() 方法创建一个副本;如果 x 是一个 NumPy 数组,则使用 copy() 方法创建一个副本。这是为了避免在原地修改输入数据。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
# 返回一个新的数组或张量 y ,其中包含了转换后的 xyxy 格式的边界框坐标。
return y
# 这个函数是一个坐标转换工具,它将边界框的坐标从 xywh 格式转换为 xyxy 格式,这对于图像处理和目标检测任务非常有用。通过这种方式,可以确保边界框的坐标与图像的尺寸和填充保持一致。
24.def xyn2xy(x, w=640, h=640, padw=0, padh=0):
# 这段代码定义了一个名为 xyn2xy 的函数,它用于将归一化的分割掩码坐标转换为像素坐标。这个函数特别适用于将图像分割任务中使用的归一化坐标转换为实际的像素坐标,这对于图像处理和计算机视觉任务非常有用。
# 函数参数。
# 1.x :一个数组或张量,包含归一化的分割掩码坐标。这些坐标通常是相对于图像宽度和高度的比例。
# 2.w :图像的宽度,默认为 640。
# 3.h :图像的高度,默认为 640。
# 4.padw :图像的左右填充,默认为 0。
# 5.padh :图像的上下填充,默认为 0。
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
# 将规范化段转换为像素段,形状为 (n,2)
# Convert normalized segments into pixel segments, shape (n,2)
# 如果输入 x 是一个 PyTorch 张量,则使用 clone() 方法创建一个副本;如果 x 是一个 NumPy 数组,则使用 copy() 方法创建一个副本。这是为了避免在原地修改输入数据。
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
# 计算每个分割点的左上角 x 坐标。 x[:, 0] 是归一化的 x 坐标,通过乘以图像宽度 w 并加上左右填充 padw ,可以得到实际的像素坐标。
y[:, 0] = w * x[:, 0] + padw # top left x
# 计算每个分割点的左上角 y 坐标。 x[:, 1] 是归一化的 y 坐标,通过乘以图像高度 h 并加上上下填充 padh ,可以得到实际的像素坐标。
y[:, 1] = h * x[:, 1] + padh # top left y
# 返回一个新的数组或张量 y ,其中包含了转换后的像素坐标。
return y
# 这个函数是一个坐标转换工具,它将分割掩码的归一化坐标转换为像素坐标,这对于图像分割和目标检测任务非常有用。通过这种方式,可以确保分割掩码的坐标与图像的实际尺寸相匹配,从而可以正确地在图像上应用分割掩码。
25.def segment2box(segment, width=640, height=640):
# 这段代码定义了一个名为 segment2box 的函数,它将一个分割标签(通常是由一系列点组成的线段)转换为一个边界框(box),这个边界框是由这些点形成的最小外接矩形。这个函数还确保了边界框完全位于图像内部,即它的坐标不会超出图像的尺寸。
# 1.segment :输入的分割标签,形状为 (2, m) 的 numpy 数组,其中 m 是点的数量,第一行包含 x 坐标,第二行包含 y 坐标。
# 2.width 和 3.height :图像的宽度和高度,默认为 640x640。
def segment2box(segment, width=640, height=640):
# 将 1 个段标签转换为 1 个框标签,应用图像内部约束,即 (xy1, xy2, ...) 到 (xyxy)。
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
# 转置分割标签。将输入的分割标签转置,使得 x 和 y 分别包含所有的 x 和 y 坐标。
x, y = segment.T # segment xy
# 应用图像内部约束。创建一个布尔数组,标记所有位于图像内部的点。
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
# 只保留位于图像内部的点。
x, y, = x[inside], y[inside]
# 创建边界框。如果至少有一个有效的 x 坐标,计算边界框的最小 x、最小 y、最大 x 和最大 y 坐标,并返回一个包含这些坐标的数组。如果没有有效的 x 坐标(即所有点都在图像外部),则返回一个全零的数组。
# 返回一个包含边界框坐标的 numpy 数组,形状为 (1, 4) ,其中包含 [x_min, y_min, x_max, y_max] 。
return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
# segment2box 函数通过找到分割标签中所有位于图像内部的点,并计算这些点形成的最小外接矩形,从而将分割标签转换为边界框。这个函数在图像分割和目标检测任务中非常有用,因为它可以将分割数据转换为边界框格式,便于后续的处理和分析。
26.def segments2boxes(segments):
# 这段代码定义了一个名为 segments2boxes 的函数,它将分割标签(通常是一系列点的坐标)转换为边界框标签。边界框标签是一种更常用的格式,用于目标检测任务,它表示为 (cls, x, y, w, h),其中 (x, y) 是边界框的中心点坐标,w 和 h 分别是边界框的宽度和高度。
# 函数参数 :
# 1.segments :一个包含分割坐标的列表或数组,每个分割坐标是一个形状为 (N, 2) 的数组,其中 N 是点的数量,每行包含一个点的 x 和 y 坐标。
def segments2boxes(segments):
# 将线段标签转换为框标签,即 (cls, xy1, xy2, ...) 转换为 (cls, xywh)。
# Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
# 初始化一个空列表 boxes ,用于存储转换后的边界框坐标。
boxes = []
# 遍历 segments 中的每个分割坐标数组 s 。
for s in segments:
# 将分割坐标数组 s 转置,并分别获取 x 和 y 坐标。 s.T 是 s 的转置,这样 x 包含所有点的 x 坐标, y 包含所有点的 y 坐标。
x, y = s.T # segment xy 线段 xy。
# 对于每个分割,计算包围所有点的最小和最大 x、y 坐标,并将这些值作为边界框的坐标添加到 boxes 列表中。这里使用的是 xyxy 格式,即 (x_min, y_min, x_max, y_max)。
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
# 将 boxes 列表转换为 NumPy 数组,并调用 xyxy2xywh 函数将其从 xyxy 格式转换为 xywh 格式。 xywh 格式表示为 (x, y, w, h),其中 (x, y) 是边界框的中心点坐标,w 和 h 分别是边界框的宽度和高度。
# def xyxy2xywh(x): -> 它用于将边界框的坐标从 xyxy 格式(即左上角和右下角的坐标)转换为 xywh 格式(即中心点坐标加上宽度和高度)。返回结果。返回转换后的 xywh 格式的边界框坐标。 -> return y
return xyxy2xywh(np.array(boxes)) # cls, xywh
27.def resample_segments(segments, n=1000):
# 这段代码定义了一个名为 resample_segments 的函数,它用于对输入的分割点(segments)进行上采样。上采样是指增加数据点的数量,以获得更平滑或更详细的数据表示。这个函数特别适用于图像分割任务中,对分割边界的点进行插值,以便在图像变换后保持边界的连续性和平滑性。
# 1.segments :输入的分割点数组,形状为 (n, 2, m) ,其中 n 是分割的数量, m 是每个分割中的点的数量, 2 表示每个点的 x 和 y 坐标。
# 2.n :可选参数,指定上采样后每个分割中的点的数量,默认为 1000。
def resample_segments(segments, n=1000):
# Up-sample an (n,2) segment
# 遍历每个分割。遍历输入的分割点数组 segments , i 是索引, s 是当前分割的点。
for i, s in enumerate(segments):
# numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0)
# np.linspace 是 NumPy 库中的一个函数,用于在指定的区间内生成等间隔的数列。这个函数非常有用,当你需要在两个数值之间生成一个具有固定步长的数组时。
# 参数 :
# start :数列的起始值。
# stop :数列的结束值。
# num :要生成的等间隔样本数量,默认为 50。
# endpoint :布尔值,如果为 True,则数列中包含 stop 值;如果为 False,则不包含, stop 之前的最后一个值会小于 stop 。
# retstep :布尔值,如果为 True,则返回样本和步长;如果为 False,则只返回样本。
# dtype :输出数组的数据类型。如果未指定,则由输入数据的类型决定。
# axis :指定沿着哪个轴返回样本数组,0 为默认值。
# 返回值 :
# 如果 retstep 为 False(默认情况),返回一个一维数组,包含从 start 到 stop 的等间隔数字。
# 如果 retstep 为 True ,则返回一个元组,包含样本数组和步长。
# 计算插值点。
# 生成一个从 0 到 len(s) - 1 的等间隔序列,长度为 n ,用于确定插值点的位置。
x = np.linspace(0, len(s) - 1, n)
# numpy.arange([start,] stop[, step,], dtype=None)
# np.arange 是 NumPy 库中的一个函数,用于生成一个等差数列。这个函数类似于 Python 内置的 range 函数,但它返回的是 NumPy 数组,而不是一个迭代器。
# 参数 :
# start :数列的起始值。如果不提供,则默认为 0。
# stop :数列的结束值,生成的数列不会包含这个值。
# step :两个值之间的间隔,默认为 1。
# dtype :输出数组的数据类型。如果未指定,则由输入数据的类型决定。
# 返回值 :
# 返回一个一维数组,包含从 start 开始到 stop (不包含)的等间隔数字。
# 生成一个从 0 到 len(s) - 1 的序列,用于原始点的位置。
xp = np.arange(len(s))
# 对每个分割进行插值。
# np.interp(x, xp, s[:, i]) :对每个分割的 x 和 y 坐标分别进行插值。 np.interp 是 NumPy 中的插值函数,它根据给定的 x 值和 xp 、 s 值进行线性插值。
# for i in range(2) :因为每个点有两个坐标(x 和 y),所以分别对它们进行插值。
# 重新组合插值后的点。
# np.concatenate([...]).reshape(2, -1).T :将插值后的 x 和 y 坐标重新组合成一个新的分割点数组,形状为 (2, n) ,然后转置回 (n, 2) 。
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
# 返回上采样后的分割点数组。
return segments
# resample_segments 函数通过对分割点进行线性插值,增加了每个分割中的点的数量,从而使得分割边界在图像变换后更加平滑和连续。这对于图像分割和目标检测任务中保持边界的精确性非常重要。
28.def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
# 这段代码定义了一个名为 scale_coords 的函数,它用于将坐标从一张图片的形状 ( img1_shape ) 缩放到另一张图片的形状 ( img0_shape )。这在图像处理中很常见,尤其是在进行图像缩放、裁剪或任何改变图像尺寸的操作后,需要对图像中对象的边界框坐标进行相应的调整。
# img1_shape :第一个图像(缩放后的图像)的形状,通常是一个元组 (height, width) 。
# coords :要缩放的坐标,格式为 xyxy (即 [x1, y1, x2, y2] ),其中 (x1, y1) 是左上角坐标, (x2, y2) 是右下角坐标。
# img0_shape :第二个图像(原始图像)的形状,通常是一个元组 (height, width) 。
# ratio_pad :一个可选参数,用于提供缩放比例和填充。如果提供,它应该是一个元组 (scale_ratio, (wh_padding, hh_padding)) 。
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
# 将坐标 (xyxy) 从 img1_shape 重新缩放为 img0_shape。
# Rescale coords (xyxy) from img1_shape to img0_shape
# 检查是否提供了 ratio_pad 参数。
if ratio_pad is None: # calculate from img0_shape 从 img0_shape 计算。
# 如果没有提供,计算缩放比例 gain 和填充 pad 。
# gain 是 img1_shape 和 img0_shape 之间的最小缩放比例。
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
# pad 是宽度和高度方向上的填充量。
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
# 如果提供了 ratio_pad ,则直接使用提供的缩放比例和填充。
else:
# 从 ratio_pad 元组中的第一个元素(即 ratio_pad[0] )取出第一个值(即 ratio_pad[0][0] )并赋给变量 gain 。这个值代表了缩放比例,用于调整坐标的大小。
gain = ratio_pad[0][0]
# 将 ratio_pad 元组的第二个元素(即 ratio_pad[1] )赋给变量 pad 。这个值是一个元组或列表,包含了宽度和高度方向上的填充量,用于调整坐标的位置。
pad = ratio_pad[1]
# 分别对 x 和 y 坐标应用填充。
coords[:, [0, 2]] -= pad[0] # x padding
coords[:, [1, 3]] -= pad[1] # y padding
# 对所有坐标应用缩放比例。
coords[:, :4] /= gain
# 确保坐标不会超出 img0_shape 定义的边界。
# def clip_coords(boxes, img_shape): -> 它用于将边界框坐标裁剪或限制在图像的尺寸内。
clip_coords(coords, img0_shape)
# 返回调整后的坐标。
return coords
29.def clip_coords(boxes, img_shape):
# 这段代码定义了一个名为 clip_coords 的函数,它用于将边界框坐标裁剪或限制在图像的尺寸内。这是一个常见的操作,用于确保边界框不会超出图像的边界。
# 1.boxes :要裁剪的边界框坐标,格式为 xyxy ,即 [x1, y1, x2, y2] ,其中 (x1, y1) 是左上角坐标, (x2, y2) 是右下角坐标。
# 2.img_shape :图像的形状,通常是一个元组 (height, width) 。
def clip_coords(boxes, img_shape):
# 将 xyxy 边界框剪裁为图像形状(高度、宽度)。
# Clip bounding xyxy bounding boxes to image shape (height, width)
# 将所有边界框的 x1 坐标限制在 0 到图像宽度 img_shape[1] 之间。
boxes[:, 0].clamp_(0, img_shape[1]) # x1
# 将所有边界框的 y1 坐标限制在 0 到图像高度 img_shape[0] 之间。
boxes[:, 1].clamp_(0, img_shape[0]) # y1
# 将所有边界框的 x2 坐标限制在 0 到图像宽度 img_shape[1] 之间。
boxes[:, 2].clamp_(0, img_shape[1]) # x2
# 将所有边界框的 y2 坐标限制在 0 到图像高度 img_shape[0] 之间。
boxes[:, 3].clamp_(0, img_shape[0]) # y2
# clamp_ 方法是 PyTorch 中的一个就地操作(in-place operation),它将张量的元素限制在指定的范围内,并直接修改原始张量。
30.def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
这段代码是用于计算两个边界框(bounding boxes)之间的交并比(IoU),以及它们的广义交并比(Generalized IoU, GIoU)、距离交并比(Distance IoU, DIoU)和完备交并比(Complete IoU, CIoU)。这些度量在目标检测任务中非常重要,尤其是在评估模型性能时。
# 定义一个函数 bbox_iou ,它接受两个边界框 1.box1 和 2.box2 ,以及一些可选参数来指定是否计算IoU、4.GIoU、5.DIoU或6.CIoU。 7.eps 是一个小的常数,用于防止除以零的错误。
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
# 返回 box1 与 box2 的 IoU。box1 为 4,box2 为 nx4。
# Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
# 将 box2 转置,使其形状从 nx4 变为 4xn ,以便于计算。
box2 = box2.T
# Get the coordinates of bounding boxes 获取边界框的坐标。
# 如果 box1 和 box2 的坐标格式是 (x1, y1, x2, y2) ,则直接使用这些坐标。
if x1y1x2y2: # x1, y1, x2, y2 = box1
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
# 如果坐标格式不是 (x1, y1, x2, y2) ,则将其从 (x, y, w, h) 转换为 (x1, y1, x2, y2) 。
else: # transform from xywh to xyxy 从 xywh 转换为 xyxy。
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
# Intersection area 交叉区域。
# 计算两个边界框的交集面积。
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
# Union Area 交叉区域面积。
# 计算 box1 的宽度和高度,并加上 eps 以避免除以零。
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# 计算两个边界框的并集面积。
union = w1 * h1 + w2 * h2 - inter + eps
# 计算IoU。
iou = inter / union
# 如果需要计算GIoU、DIoU或CIoU,则继续执行以下步骤。
if GIoU or DIoU or CIoU:
# 计算包含两个边界框的最小凸框的宽度。
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
# 计算包含两个边界框的最小凸框的高度。
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
# 如果需要计算 DIoU 或 CIoU ,则计算中心点之间的距离和宽高比的一致性。
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
(b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
if DIoU:
return iou - rho2 / c2 # DIoU
elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
# 如果只需要计算GIoU,则计算GIoU。
else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
c_area = cw * ch + eps # convex area
return iou - (c_area - union) / c_area # GIoU
else:
# return iou 或 return iou - (...) :根据需要返回IoU、GIoU、DIoU或CIoU。
return iou # IoU
# 这个函数是一个非常通用的边界框IoU计算工具,可以用于多种不同的目标检测和评估场景。
31.def bbox_alpha_iou(box1, box2, x1y1x2y2=False, GIoU=False, DIoU=False, CIoU=False, alpha=2, eps=1e-9):
# 这段代码是一个Python函数,用于计算两个边界框(bounding boxes)之间的交并比(IoU),并在此基础上扩展到广义交并比(Generalized IoU, GIoU)、距离交并比(Distance IoU, DIoU)和完备交并比(Complete IoU, CIoU)。
# 此外,该函数还引入了一个参数 alpha ,用于对IoU进行幂运算,这可能是为了调整IoU的敏感度或强调某些区域。
# 定义一个函数 bbox_alpha_iou ,它接受两个边界框 1.box1 和 2.box2 ,以及一些可选参数来指定是否计算IoU、4.GIoU、5.DIoU或7.CIoU,以及IoU的幂 8.alpha 和一个小的常数 9.eps 。
def bbox_alpha_iou(box1, box2, x1y1x2y2=False, GIoU=False, DIoU=False, CIoU=False, alpha=2, eps=1e-9):
# 返回 box1 到 box2 的 tsqrt_he IoU。box1 为 4,box2 为 nx4。
# Returns tsqrt_he IoU of box1 to box2. box1 is 4, box2 is nx4
# 将 box2 转置,使其形状从 nx4 变为 4xn ,以便于计算。
box2 = box2.T
# Get the coordinates of bounding boxes 获取边界框的坐标。
# 如果 box1 和 box2 的坐标格式是 (x1, y1, x2, y2) ,则直接使用这些坐标。
if x1y1x2y2: # x1, y1, x2, y2 = box1
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
# 如果坐标格式不是 (x1, y1, x2, y2) ,则将其从 (x, y, w, h) 转换为 (x1, y1, x2, y2) 。
else: # transform from xywh to xyxy 从 xywh 转换为 xyxy。
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
# Intersection area 交叉区域。
# 计算两个边界框的交集面积。
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
# Union Area 交叉区域面积。
# 计算 box1 的宽度和高度,并加上 eps 以避免除以零。
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
# 计算 box2 的宽度和高度,并加上 eps 以避免除以零。
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# 计算两个边界框的并集面积。
union = w1 * h1 + w2 * h2 - inter + eps
# change iou into pow(iou+eps) 将 iou 更改为 pow(iou+eps)。
# iou = inter / union
# 计算IoU,并将其提升到 alpha 次幂。
iou = torch.pow(inter/union + eps, alpha)
# beta = 2 * alpha
# 如果需要计算GIoU、DIoU或CIoU,则继续执行以下步骤。
if GIoU or DIoU or CIoU:
# 计算包含两个边界框的最小凸框的宽度。
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width 凸(最小封闭矩形)宽度。
# 计算包含两个边界框的最小凸框的高度。
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height 凸(最小封闭矩形)高度。
# 如果需要计算DIoU或CIoU,则计算中心点之间的距离和宽高比的一致性。
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 距离或完整 IoU https://arxiv.org/abs/1911.08287v1。
c2 = (cw ** 2 + ch ** 2) ** alpha + eps # convex diagonal 凸对角线。
rho_x = torch.abs(b2_x1 + b2_x2 - b1_x1 - b1_x2)
rho_y = torch.abs(b2_y1 + b2_y2 - b1_y1 - b1_y2)
rho2 = ((rho_x ** 2 + rho_y ** 2) / 4) ** alpha # center distance 中心距。
if DIoU:
return iou - rho2 / c2 # DIoU
elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha_ciou = v / ((1 + eps) - inter / union + v)
# return iou - (rho2 / c2 + v * alpha_ciou) # CIoU
return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)) # CIoU
# 如果只需要计算GIoU,则计算GIoU。
else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
# c_area = cw * ch + eps # convex area
# return iou - (c_area - union) / c_area # GIoU
c_area = torch.max(cw * ch + eps, union) # convex area
return iou - torch.pow((c_area - union) / c_area + eps, alpha) # GIoU
else:
# return iou 或 return iou - (...) :根据需要返回IoU、GIoU、DIoU或CIoU。
return iou # torch.log(iou+eps) or iou
# 这个函数通过引入 alpha 参数,提供了一种灵活的方式来调整IoU的计算方式,可能有助于在不同的应用场景中获得更好的性能。
32.def box_iou(box1, box2):
# 定义了一个名为 box_iou 的函数,用于计算两个边界框集合之间的交并比(Intersection over Union, IoU)。IoU 是一个衡量两个边界框重叠程度的指标,常用于目标检测和图像分割任务中。
# box1 和 box2 是两个边界框集合,它们都是形状为 (N, 4) 和 (M, 4) 的张量,分别代表 N 个和 M 个边界框。每个边界框由四个值组成: (x1, y1, x2, y2) ,其中 (x1, y1) 是左上角坐标, (x2, y2) 是右下角坐标。
def box_iou(box1, box2):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
# 辅助函数:计算边界框面积。
# 这个内部函数 box_area 计算一个边界框的面积。输入 box 是一个形状为 (4, n) 的张量,其中 n 是边界框的数量。
def box_area(box):
# box = 4xn
# 面积计算公式为: 宽度 * 高度 = (x2 - x1) * (y2 - y1) 。
return (box[2] - box[0]) * (box[3] - box[1])
# box1.T 和 box2.T 分别是 box1 和 box2 的转置,使得边界框的坐标维度与边界框数量维度对齐,以便于计算面积。
# area1 和 area2 分别是两个边界框集合的面积张量,形状为 (N,) 和 (M,) 。
area1 = box_area(box1.T)
area2 = box_area(box2.T)
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
# 计算交集。
# clamp(0) 确保计算结果非负。 prod(2) 计算交集的面积,即宽度乘以高度。
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
# 计算交并比。area1[:, None] 和 area2 将面积张量增加一个新的维度以便于广播。 IoU 的计算公式为: 交集面积 / (第一个集合的面积 + 第二个集合的面积 - 交集面积) 。
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
# 这个函数返回一个形状为 (N, M) 的张量,其中包含了 box1 中每个边界框与 box2 中每个边界框之间的 IoU 值。这个函数在目标检测任务中非常有用,尤其是在非极大值抑制(NMS)过程中,用于确定哪些边界框应该被保留或抑制。
33.def wh_iou(wh1, wh2):
# 这段代码定义了一个名为 wh_iou 的函数,用于计算两个宽度和高度集合之间的交并比(IoU)矩阵。IoU 是目标检测和计算机视觉中常用的一个指标,用于衡量两个边界框的重叠程度。
# 1.wh1 : 一个张量,形状为 (nx2) ,其中 n 是第一个集合中边界框的数量, 2 表示每个边界框的宽度和高度。
# 2.wh2 : 一个张量,形状为 (mx2) ,其中 m 是第二个集合中边界框的数量, 2 表示每个边界框的宽度和高度。
def wh_iou(wh1, wh2):
# 返回 nxm IoU 矩阵。wh1 是 nx2,wh2 是 mx2
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
# 调整形状以便广播。
# 将 wh1 张量增加一个维度,使其形状变为 (nx1x2) 。
wh1 = wh1[:, None] # [N,1,2]
# 将 wh2 张量增加一个维度,使其形状变为 (1xmx2) 。
wh2 = wh2[None] # [1,M,2]
# 计算交集。首先使用 torch.min 计算两个边界框在每个维度上的最小值,这代表了交集的边界。然后沿着最后一个维度(维度2)计算乘积,得到每个边界框对的交集面积。
inter = torch.min(wh1, wh2).prod(2) # [N,M]
# 计算并返回 IoU。计算 IoU 值,公式为 iou = inter / (area1 + area2 - inter) ,其中 area1 和 area2 分别是 wh1 和 wh2 中边界框的面积。
# 返回一个形状为 (nxm) 的 IoU 矩阵,其中每个元素 (i, j) 表示 wh1 中第 i 个边界框与 wh2 中第 j 个边界框之间的 IoU 值。
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
34.def box_giou(box1, box2):
# 这段代码定义了一个名为 box_giou 的函数,用于计算两组边界框之间的广义交并比(Generalized Intersection over Union,简称 GIoU)。GIoU 是一种衡量两个边界框重叠程度的指标,它不仅考虑了交集面积,还考虑了包含两个边界框的最小边界框的面积。
# 定义一个函数 box_giou ,接受两组边界框 1.box1 和 2.box2 作为输入。
def box_giou(box1, box2):
# 返回两组框之间的广义交并比(Jaccard 指数)。
"""
Return generalized intersection-over-union (Jaccard index) between two sets of boxes.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
for every element in boxes1 and boxes2
"""
# 定义一个内部函数 box_area ,用于计算单个边界框的面积。
def box_area(box):
# box = 4xn
# 计算边界框的面积,公式为 width * height 。
return (box[2] - box[0]) * (box[3] - box[1])
# 计算第一组边界框的面积。
area1 = box_area(box1.T)
# 计算第二组边界框的面积。
area2 = box_area(box2.T)
# 计算两组边界框之间的交集面积。首先找到每个边界框对的最小外接矩形,然后计算交集面积。
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
# 计算两组边界框的并集面积。
union = (area1[:, None] + area2 - inter)
# 计算两组边界框的 IoU(交并比)。
iou = inter / union
# 找到两组边界框中每个边界框对的左上角坐标。
lti = torch.min(box1[:, None, :2], box2[:, :2])
# 找到两组边界框中每个边界框对的右下角坐标。
rbi = torch.max(box1[:, None, 2:], box2[:, 2:])
# 计算包含两组边界框的最小边界框的宽度和高度。
whi = (rbi - lti).clamp(min=0) # [N,M,2]
# 计算包含两组边界框的最小边界框的面积。
areai = whi[:, :, 0] * whi[:, :, 1]
# 返回 GIoU 值,即 IoU 减去包含两组边界框的最小边界框面积与两组边界框并集面积之差的比值。
return iou - (areai - union) / areai
# 这个函数可以用来计算任意两组边界框之间的 GIoU 值,这在目标检测和图像分割任务中非常有用,因为它提供了一个更全面的衡量重叠程度的指标。
35.def box_ciou(box1, box2, eps: float = 1e-7):
# 这段代码定义了一个名为 box_ciou 的函数,用于计算两组边界框之间的完备交并比(Complete Intersection over Union,简称 CIoU)。CIoU 是一种改进的 IoU 指标,它不仅考虑了交集面积和并集面积,还考虑了边界框中心点之间的距离以及宽高比的一致性。
# 定义一个函数 box_ciou ,接受两组边界框 1.box1 和 2.box2 作为输入,以及一个可选的小数 3.eps 用于防止除以零。
def box_ciou(box1, box2, eps: float = 1e-7):
"""
Return complete intersection-over-union (Jaccard index) between two sets of boxes.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
for every element in boxes1 and boxes2
"""
# 定义一个内部函数 box_area ,用于计算单个边界框的面积。
def box_area(box):
# box = 4xn
# 计算边界框的面积,公式为 width * height 。
return (box[2] - box[0]) * (box[3] - box[1])
# 计算第一组边界框的面积。
area1 = box_area(box1.T)
# 计算第二组边界框的面积。
area2 = box_area(box2.T)
# 计算两组边界框之间的交集面积。
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
# 计算两组边界框的并集面积。
union = (area1[:, None] + area2 - inter)
# 计算两组边界框的 IoU(交并比)。
iou = inter / union
# 找到两组边界框中每个边界框对的左上角坐标。
lti = torch.min(box1[:, None, :2], box2[:, :2])
# 找到两组边界框中每个边界框对的右下角坐标。
rbi = torch.max(box1[:, None, 2:], box2[:, 2:])
# 计算包含两组边界框的最小边界框的宽度和高度。
whi = (rbi - lti).clamp(min=0) # [N,M,2]
# 计算包含两组边界框的最小边界框的对角线距离的平方。
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
# centers of boxes
# 计算两组边界框中心点的坐标。
x_p = (box1[:, None, 0] + box1[:, None, 2]) / 2
y_p = (box1[:, None, 1] + box1[:, None, 3]) / 2
x_g = (box2[:, 0] + box2[:, 2]) / 2
y_g = (box2[:, 1] + box2[:, 3]) / 2
# The distance between boxes' centers squared.
# 计算两组边界框中心点之间的距离的平方。
centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
# 计算两组边界框的宽度和高度。
w_pred = box1[:, None, 2] - box1[:, None, 0]
h_pred = box1[:, None, 3] - box1[:, None, 1]
w_gt = box2[:, 2] - box2[:, 0]
h_gt = box2[:, 3] - box2[:, 1]
# 计算宽高比的一致性。
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
# 在不计算梯度的上下文中执行以下计算,以避免影响自动微分。
with torch.no_grad():
# 计算调整因子 alpha 。
alpha = v / (1 - iou + v + eps)
# 返回 CIoU 值,即 IoU 减去中心点距离的惩罚项再减去宽高比不一致的惩罚项。
return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
# 这个函数可以用来计算任意两组边界框之间的 CIoU 值,这在目标检测任务中非常有用,因为它提供了一个更全面的衡量重叠程度的指标,特别是在处理不同尺度和形状的边界框时。
36.def box_diou(box1, box2, eps: float = 1e-7):
# 这段代码定义了一个名为 box_diou 的函数,用于计算两组边界框之间的距离交并比(Distance Intersection over Union,简称 DIoU)。
# DIoU 是一种改进的 IoU 指标,它在传统的 IoU 基础上增加了对边界框中心点距离的惩罚项,使得模型在训练时不仅考虑重叠区域,还考虑预测框与真实框之间的距离。
# 定义一个函数 box_diou ,接受两组边界框 1.box1 和 2.box2 作为输入,以及一个可选的小数 3.eps 用于防止除以零。
def box_diou(box1, box2, eps: float = 1e-7):
"""
Return distance intersection-over-union (Jaccard index) between two sets of boxes.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
for every element in boxes1 and boxes2
"""
def box_area(box):
# box = 4xn
return (box[2] - box[0]) * (box[3] - box[1])
area1 = box_area(box1.T)
area2 = box_area(box2.T)
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
union = (area1[:, None] + area2 - inter)
iou = inter / union
lti = torch.min(box1[:, None, :2], box2[:, :2])
rbi = torch.max(box1[:, None, 2:], box2[:, 2:])
whi = (rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
# centers of boxes
x_p = (box1[:, None, 0] + box1[:, None, 2]) / 2
y_p = (box1[:, None, 1] + box1[:, None, 3]) / 2
x_g = (box2[:, 0] + box2[:, 2]) / 2
y_g = (box2[:, 1] + box2[:, 3]) / 2
# The distance between boxes' centers squared.
# 计算两组边界框中心点之间的距离的平方。
centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
# 返回 DIoU 值,即 IoU 减去中心点距离的惩罚项。
return iou - (centers_distance_squared / diagonal_distance_squared)
# 这个函数可以用来计算任意两组边界框之间的 DIoU 值,这在目标检测任务中非常有用,因为它提供了一个更全面的衡量重叠程度的指标,特别是在处理不同位置和尺度的边界框时。通过考虑中心点距离,DIoU 能够更好地指导模型学习预测框的位置,从而提高检测精度。
37.def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=()):
# 该函数实现了非最大抑制(NMS)算法,用于在目标检测任务中去除重叠的检测框。
# 函数参数。
# 1.prediction : 模型的输出,包含了检测框、置信度和类别概率。
# 2.conf_thres=0.25 : 置信度阈值,只有置信度高于这个值的检测框才会被考虑。
# 3.iou_thres=0.45 : 交并比(IoU)阈值,用于确定两个检测框是否重叠过多。
# 4.classes=None : 一个可选的列表,用于过滤特定类别的检测框。
# 5.agnostic=False : 是否忽略类别信息,进行类别无关的NMS。
# 6.multi_label=False : 是否允许一个检测框有多个类别标签。
# 7.labels=() : 一个可选的元组,包含需要考虑的类别标签。
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
# 对推理结果运行非最大抑制 (NMS)。
# 返回 :
# 检测列表,每个图像的 (n,6) 张量 [xyxy, conf, cls]。
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
# 计算类别数。计算预测结果中类别的数量。这里假设 prediction 的形状是 [batch_size, num_anchors, num_classes + 5] ,其中 num_classes + 5 包括了4个坐标值、1个置信度值和 num_classes 个类别概率值。
nc = prediction.shape[2] - 5 # number of classes 类别数量。
# 筛选置信度高于阈值的候选框。 根据置信度阈值筛选出候选框。
xc = prediction[..., 4] > conf_thres # candidates 候选框。
# Settings 设置。
# 设置参数。
# 设置检测框的最小和最大宽度和高度。
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height (像素)框的最小和最大宽度和高度。
# 每个图像最多检测的数量。
max_det = 300 # maximum number of detections per image 每幅图像的最大检测次数。
# 传递给 torchvision.ops.nms() 的最大检测框数量。
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() 放入 torchvision.ops.nms() 的最大框数。
# 超时限制,超过这个时间后退出。
time_limit = 10.0 # seconds to quit after 秒后退出。
# 是否需要冗余检测。
redundant = True # require redundant detections 需要冗余检测。
# 如果类别数大于1,则允许每个检测框有多个标签。
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) 每个边界框多个标签(每张图片增加 0.5ms)。
# 是否使用合并NMS。
merge = False # use merge-NMS 使用合并 NMS 。
# 初始化输出。
# 记录当前时间,用于测量NMS处理所需的时间。
t = time.time()
# 初始化一个输出列表,每个图像一个元素,每个元素是一个形状为 (0, 6) 的张量,表示没有检测结果。这里的6对应于 [x1, y1, x2, y2, conf, cls] ,即每个检测框的坐标、置信度和类别。
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
# 遍历每个图像的预测结果。 遍历 prediction 中的每个元素, xi 是图像索引, x 是该图像的预测结果。
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
# 应用置信度约束。
# 从当前图像的预测结果中筛选出置信度高于 conf_thres 的检测框
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
# 合并先验标签(自标注模式)。
# 检查是否存在先验标签(自标注模式)。
if labels and len(labels[xi]):
# 获取当前图像的先验标签。
l = labels[xi]
# 创建一个零张量,用于存储先验标签的信息。
v = torch.zeros((len(l), nc + 5), device=x.device)
# 将先验标签的坐标复制到 v 中。
v[:, :4] = l[:, 1:5] # box
# 设置先验标签的置信度为1.0。
v[:, 4] = 1.0 # conf
# 设置先验标签的类别概率为1.0。这里的 l[:, 0] 是类别索引, +5 是因为类别概率从第6个通道开始。
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
# 合并 预测结果 和 先验标签 。将预测结果和先验标签合并,以便一起进行NMS处理。
x = torch.cat((x, v), 0)
# 如果没有剩余,则处理下一个图像。
# If none remain process next image
# 条件检查。 这个条件检查当前图像的预测结果 x 的第一个维度(即检测框的数量)是否为0。如果为0,意味着没有检测框的置信度高于设定的阈值,或者所有检测框都被过滤掉了。
if not x.shape[0]:
# 继续处理下一张图像。如果没有剩余的检测框, continue 语句将跳过当前循环的剩余部分,直接进入下一个循环迭代,处理下一张图像的预测结果。
continue
# Compute conf
# 计算置信度。将物体置信度( obj_conf )与类别置信度( cls_conf )相乘,得到每个类别的最终置信度。这里假设 x 的第5列及以后是类别概率,第4列是物体置信度。
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
# def xywh2xyxy(x): -> 用于将边界框的坐标从中心点坐标加宽高(x, y, w, h)格式转换为两个角点坐标(x1, y1, x2, y2)格式。返回转换后的坐标数组。 -> return y
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
# 转换边界框格式。将边界框从中心点坐标及宽高(x, y, w, h)格式转换为两个角点坐标(x1, y1, x2, y2)格式。
box = xywh2xyxy(x[:, :4])
# Detections matrix nx6 (xyxy, conf, cls)
# 构建检测矩阵。
# 如果启用多标签( multi_label=True )。
if multi_label:
# torch.nonzero(input, as_tuple=False) → LongTensor or tuple of LongTensor
# nonzero 函数通常是指一个用于找出非零元素的函数。在深度学习和张量操作中, nonzero 函数经常用来获取满足特定条件的元素的索引。这个函数在PyTorch中是一个内置函数,而在NumPy中也有一个同名的函数。
# input :一个张量, nonzero 函数将在其中寻找非零元素。
# as_tuple :一个布尔值,如果设置为 True ,则返回一个包含非零元素索引的元组;如果设置为 False (默认值),则返回一个包含非零元素索引的张量。
# 函数返回一个张量或元组,其中包含输入张量中所有非零元素的索引。
# 找出所有置信度超过阈值的类别。 x[:, 5:] 表示从第6列开始的所有类别置信度, > conf_thres 创建一个布尔矩阵,其中置信度超过阈值的位置为 True 。 nonzero() 返回这些 True 值的索引, T 转置这些索引,使得每一行代表一个检测框的所有有效类别。
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
# 将边界框、置信度和类别索引合并成一个矩阵。
# 这行代码将 边界框坐标 、 置信度 和 类别索引 合并成一个矩阵。
# box[i] 表示对应于索引 i 的边界框坐标, x[i, j + 5, None] 表示对应于索引 i 和 j 的置信度( j + 5 是因为类别置信度从第6列开始), j[:, None].float() 表示类别索引,转换为浮点数并增加一个维度以便合并。
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
# 单标签情况(Best class only)。
else: # best class only
# 找出每个边界框置信度最高的类别。这行代码找出每个边界框置信度最高的类别。 x[:, 5:] 表示所有类别置信度, max(1, keepdim=True) 返回每个边界框置信度最高的 值 和对应的类别 索引 。
conf, j = x[:, 5:].max(1, keepdim=True)
# 将边界框、置信度和类别索引合并成一个矩阵,并过滤。
# 这行代码将边界框坐标、置信度最高的值和对应的类别索引合并成一个矩阵,然后过滤出置信度超过阈值的检测结果。 box
# 表示所有边界框坐标, conf 表示置信度最高的值, j.float() 表示类别索引。最后,通过 conf.view(-1) > conf_thres 过滤出置信度超过阈值的检测结果。
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
# 按类别过滤。如果提供了需要考虑的类别列表。
if classes is not None:
# 只保留属于指定类别的检测框。这里 x[:, 5:6] 表示检测框的类别索引, torch.tensor(classes, device=x.device) 将类别列表转换为张量并确保它与 x 在同一设备上。 any(1) 表示只要检测框属于列表中的任何一个类别,就保留该检测框。
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# 应用有限约束(注释掉的部分)
# Apply finite constraint
# 检查 x 中的所有值是否都是有限数(即没有 NaN 或 Inf)。
# if not torch.isfinite(x).all():
# 如果 x 中有任何非有限数,则只保留所有值都是有限数的检测框。
# x = x[torch.isfinite(x).all(1)]
# Check shape
# 检查形状。获取剩余检测框的数量。
n = x.shape[0] # number of boxes
# 无检测框。如果没有检测框, continue 跳过当前循环,处理下一张图像。
if not n: # no boxes
continue
# 检测框数量过多。如果检测框的数量超过了 max_nms 设置的最大值。
elif n > max_nms: # excess boxes
# 根据置信度对检测框进行排序,并只保留置信度最高的 max_nms 个检测框。这里 x[:, 4] 表示检测框的置信度, argsort(descending=True) 按置信度降序排列索引, [:max_nms] 选取置信度最高的 max_nms 个检测框。
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
# 这行代码计算每个边界框的类别偏移量。 x[:, 5:6] 表示从输入张量 x 中提取第6列(因为列的索引从0开始),这通常代表类别的索引。
# agnostic 是一个布尔值,如果为 True ,则表示模型是类别不可知的(agnostic),即不考虑类别进行NMS,此时 c 被设置为0。 如果为 False ,则 c 被设置为 max_wh ,这是一个预定义的最大宽度和高度的值,用于缩放类别索引。
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
# 这行代码将边界框的坐标和得分分开。
# x[:, :4] 表示提取 x 的前4列,这通常是边界框的坐标(x中心,y中心,宽度,高度)。
# x[:, 4] 表示提取 x 的第5列,这通常是边界框的得分。
# boxes 通过将 x[:, :4] 与 c 相加来计算,这样可以将类别索引加到边界框的坐标上,这是YOLOv7中处理多类别问题的一种方式。
# scores 直接从 x 的第5列获取。
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
# keep = torchvision.ops.nms(boxes, scores, iou_threshold)
# 参数:
# boxes (Tensor[N, 4])) :bounding boxes坐标. 格式:(x1, y1, x2, y2)
# scores (Tensor[N]) :bounding boxes得分
# iou_threshold (float) :IoU过滤阈值
# 返回值:
# keep :NMS过滤后的bouding boxes索引(降序排列)
# 这行代码应用非极大值抑制(NMS),以去除重叠的边界框。 torchvision.ops.nms 是PyTorch的一个函数,它接受边界框、得分和一个IOU阈值 iou_thres ,并返回 保留的 边界框的 索引 。
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
# 这行代码检查 NMS 后保留的边界框数量是否超过了最大检测数量 max_det 。如果是,只保留前 max_det 个边界框。
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
# 如果设置了 merge 标志,并且边界框的数量 n 在1到3000之间,那么执行边界框合并。
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# def box_iou(box1, box2): -> 用于计算两个边界框集合之间的交并比(Intersection over Union, IoU)。返回一个形状为 (N, M) 的张量,其中包含了 box1 中每个边界框与 box2 中每个边界框之间的 IoU 值。 -> return inter / (area1[:, None] + area2 - inter)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# 计算保留的边界框与所有边界框之间的交并比(IOU),只保留大于 iou_thres 的IOU值。
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# 这行代码计算每个边界框的权重,权重是IOU值与得分的乘积。
weights = iou * scores[None] # box weights
# 这行代码使用加权平均值更新保留的边界框的坐标。 torch.mm 执行矩阵乘法, weights 作为行向量, x[:, :4] 作为列向量。
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# 如果设置了 redundant 标志,那么执行以下操作。
if redundant:
# 这行代码过滤掉那些没有冗余(即没有与其他边界框重叠)的边界框,只保留那些至少与一个其他边界框重叠的边界框。
i = i[iou.sum(1) > 1] # require redundancy
# 这行代码将 NMS 处理后的边界框结果赋值给输出张量 output 的对应位置。 xi 是输出张量的索引, x[i] 是经过NMS处理后保留的边界框。
output[xi] = x[i]
# 这行代码检查自NMS开始以来经过的时间是否超过了设定的时间限制 time_limit 。 time.time() 返回当前时间的时间戳, t 是NMS开始前记录的时间戳。
if (time.time() - t) > time_limit:
# 如果超过时间限制,打印一条警告信息,提示NMS的时间限制被超过了。这里使用了格式化字符串 f-string 来插入 time_limit 的值。
print(f'WARNING: NMS time limit {time_limit}s exceeded')
# 如果超过时间限制,则使用 break 语句退出循环。这意味着NMS处理将不会继续进行,即使还有更多的边界框需要处理。
break # time limit exceeded
# 在循环结束后,返回包含NMS处理结果的输出张量 output 。
return output
38.def strip_optimizer(f='best.pt', s=''):
# 这段代码定义了一个名为 strip_optimizer 的函数,其目的是从训练好的模型文件中移除优化器(optimizer)和其他非必要的信息,以便减小模型文件的大小,使其更适合部署。这个函数还允许你将处理后的模型保存为一个新的文件。
# 定义一个函数 strip_optimizer ,它接受两个参数。
# 1.f :是要处理的模型文件的名称,默认为 'best.pt'。
# 2.s :是处理后的模型文件的保存名称,默认为空字符串,意味着将覆盖原文件。
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer() 从 utils.general 导入 *; strip_optimizer()。
# Strip optimizer from 'f' to finalize training, optionally save as 's' 从“f”中剥离优化器以完成训练,可选择保存为“s”。
# 使用 PyTorch 的 torch.load 函数加载模型文件 f ,并将其映射到 CPU 上。
x = torch.load(f, map_location=torch.device('cpu'))
# 检查加载的模型字典 x 中是否包含 ema (指数移动平均)键。
if x.get('ema'):
# 如果存在 ema ,则用 ema 的值替换模型,这样可以确保保存的是经过 EMA 处理的模型。
x['model'] = x['ema'] # replace model with ema
# 遍历一个元组,包含需要从模型字典中移除的键。
for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
# 将这些键对应的值设置为 None ,从而在保存时不会包含这些信息。
x[k] = None
# 将 epoch 的值设置为 -1 ,表示训练已经结束。
x['epoch'] = -1
# 将模型的权重转换为半精度(FP16),这有助于进一步减小模型文件的大小。
x['model'].half() # to FP16
# 遍历模型的所有参数。
for p in x['model'].parameters():
# 将所有参数的 requires_grad 属性设置为 False ,这样在推理时不会计算梯度。
p.requires_grad = False
# 将处理后的模型字典 x 保存到文件,如果 s 不为空,则保存到 s ,否则覆盖原文件 f 。
torch.save(x, s or f)
# 计算保存后的文件大小(以兆字节为单位)。
mb = os.path.getsize(s or f) / 1E6 # filesize 文件大小。
# 打印一条消息,显示优化器已被移除,并显示处理后的文件大小。
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB") # 优化器从 {f} 中剥离,{(' 另存为 %s,' % s) if s else ''} {mb:.1f}MB。
# 这个函数在模型部署前非常有用,因为它可以减少模型文件的大小,同时确保模型仍然可以用于推理。通过移除优化器和其他不必要的信息,你可以得到一个更轻量级的模型,这对于资源受限的环境(如移动设备或嵌入式系统)尤其重要。
39.def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
# 这段代码定义了一个名为 print_mutation 的函数,它用于处理和记录超参数进化(hyperparameter evolution)的结果。超参数进化是一种自动化机器学习(AutoML)技术,用于优化模型的超参数。
# 这个函数将进化过程中的超参数设置和相应的结果打印出来,并保存到文件中,以便进一步分析和使用。
# 定义一个函数 print_mutation ,接受四个参数。
# 1.hyp :是超参数字典。
# 2.results 是结果元组。
# 3.yaml_file 是保存超参数的 YAML 文件名,默认为 'hyp_evolved.yaml' 。
# 4.bucket 是 Google Cloud Storage 桶的名称,默认为空字符串。
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
# 将突变结果打印到 evolve.txt(用于 train.py --evolve)。
# Print mutation results to evolve.txt (for use with train.py --evolve)
# 创建一个格式化字符串,包含所有超参数的键名。
a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys 超参数键。
# 创建一个格式化字符串,包含所有超参数的值。
b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values 超参数值。
# 创建一个格式化字符串,包含所有结果值。
c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
# 打印超参数键名、值和结果。
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c)) # \n%s\n%s\n进化适应度: %s\n。
# 如果提供了 Google Cloud Storage 桶名称,执行以下操作。
if bucket:
# 构建 Google Cloud Storage 中 evolve.txt 文件的 URL。
url = 'gs://%s/evolve.txt' % bucket
# 比较云端和本地 evolve.txt 文件的大小。
# def gsutil_getsize(url=''):
# -> 它用于获取 Google Cloud Storage(GCS)中指定 URL 文件的大小(以字节为单位)。如果输出字符串 s 不为空,则将其分割并取第一个元素(即文件大小),使用 eval 函数将其转换为整数。如果 s 为空,则返回 0。
# -> return eval(s.split(' ')[0]) if len(s) else 0 # bytes
if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
# os.system(command)
# os.system() 函数是 Python 的 os 模块中的一个函数,用于执行指定的命令行字符串。这个函数会调用系统的命令行解释器(通常是 shell)来执行命令,并返回命令执行后的退出状态码。
# 参数说明 :
# command :一个字符串,包含要执行的命令。
# 返回值 :返回命令执行后的退出状态码。在 Unix 和类 Unix 系统中,通常 0 表示成功,非 0 表示失败。在 Windows 系统中,返回值的具体含义取决于命令本身。
# 异常 :
# 如果发生错误(例如,无法找到命令解释器),可能会抛出 OSError 异常。
# 需要注意的是, os.system() 函数会创建一个新的 shell 来执行命令,这意味着它可能会受到当前工作目录的影响,并且不会继承当前 Python 进程的环境变量。
# 此外,由于安全原因,通常不推荐在程序中使用 os.system() ,因为它可能会执行任意命令,存在安全风险。
# 在可能的情况下,可以考虑使用 subprocess 模块中的函数,如 subprocess.run() 或 subprocess.call() ,因为它们提供了更多的控制和安全性。
# 如果云端文件更大,则下载云端文件。
os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local 如果大于本地,则下载 evolve.txt。
# 以追加模式打开 evolve.txt 文件。
with open('evolve.txt', 'a') as f: # append result 追加结果。
# 将结果和超参数写入文件。
f.write(c + b + '\n')
# numpy.loadtxt(fname, dtype=float, comments='#', delimiter=None, converters=None, skiprows=0, usecols=None, unpack=False, ndmin=0)
# np.loadtxt 是 NumPy 库中的一个函数,用于从文本文件中加载数据。这个函数可以读取文本文件中的数值数据,并将它们转换为 NumPy 数组。它提供了灵活的参数设置,以适应不同的数据格式和需求。
# 参数说明 :
# fname :文件名或文件对象。可以是字符串路径、文件路径列表或类似文件的对象。
# dtype :数据类型,默认为 float 。指定加载数据的期望数据类型。
# comments :注释字符,默认为 '#' 。指定用作注释的字符,注释行将被忽略。
# delimiter :分隔符,默认为 None 。指定字段的分隔符,如果为 None ,则自动检测空白字符作为分隔符。
# converters :转换函数字典,默认为 None 。为每列数据提供自定义的转换函数。
# skiprows :跳过的行数,默认为 0 。指定在开始读取数据之前要跳过的行数。
# usecols :使用列的索引或范围,默认为 None 。指定要读取的列。
# unpack :布尔值,默认为 False 。如果为 True ,则返回一个元组,其中包含解包的数组。
# ndmin :输出数组的最小维度,默认为 0 。 0 表示返回一维数组, 1 表示返回二维数组,即使只有一列数据。
# 返回值 :
# 返回一个 NumPy 数组,其形状和数据类型取决于输入文件和参数。
# 异常 :
# 如果文件无法读取或数据格式不正确,可能会抛出异常。
# np.loadtxt 是处理文本数据文件的常用函数,特别适用于那些由数值数据组成的文件,如 CSV 文件。它为数据加载和预处理提供了一个简单而有效的方法。
# 读取 evolve. 文件中的唯一行。
x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows 加载唯一行。
# numpy.argsort(a, axis=-1, kind=None, order=None)
# np.argsort 是 NumPy 库中的一个函数,用于对数组中的元素进行排序,并返回排序后元素的索引。
# 参数说明 :
# a :要排序的数组。
# axis :指定排序的轴,默认为 -1 ,即最后一个轴。如果数组是一维的,则此参数无效。
# kind :指定排序算法的类型,可选的值有 'quicksort' 、 'mergesort' 、 'heapsort' 等,默认为 None ,在这种情况下,NumPy 会自动选择最合适的排序算法。
# order :当数组元素是结构化数组时,此参数用于指定根据哪个字段进行排序。
# 返回值 :
# 返回一个数组,其中包含排序后元素的索引。
# 异常 :
# 如果输入数组 a 不是 NumPy 数组,可能会抛出异常。
# np.argsort 是一个非常有用的函数,它允许你在不改变原始数组的情况下,快速获取排序后的元素索引,这在很多数据处理和分析任务中都非常有用。
# 根据 fitness 函数对结果进行排序。
# def fitness(x):
# -> 它用于计算模型的适应度(fitness),作为一个加权组合的度量指标。计算适应度值。首先, x[:, :4] 从输入数组 x 中选择所有行的前四列,即对应于度量指标的列。然后,将这些度量指标与相应的权重相乘,最后沿数组的第一个轴(即行)求和,得到每个样本的适应度值。
# -> return (x[:, :4] * w).sum(1)
x = x[np.argsort(-fitness(x))] # sort
# numpy.savetxt(fname, X, fmt='%s', delimiter=' ', newline='\n', header='', footer='', comments='# ', encoding=None)
# np.savetxt 是 NumPy 库中的一个函数,用于将 NumPy 数组保存到文本文件中。这个函数提供了多种参数来自定义保存数据的格式和布局。
# 参数说明 :
# fname :文件名或文件对象。可以是字符串路径或类似文件的对象。
# X :要保存的 NumPy 数组。
# fmt :格式字符串,默认为 '%s' 。用于指定保存数据时的格式,可以是单个格式字符串或格式字符串的序列。
# delimiter :分隔符,默认为空格 ' ' 。指定数组元素之间的分隔符。
# newline :新行字符,默认为 '\n' 。指定行与行之间的分隔符。
# header :文件开头的字符串,默认为空字符串。在保存的数据之前写入文件的文本。
# footer :文件末尾的字符串,默认为空字符串。在保存的数据之后写入文件的文本。
# comments :注释字符串,默认为 '# ' 。指定注释字符,用于标记注释行。
# encoding :编码类型,默认为 None 。指定文件的编码类型,如 'utf-8' 。
# 返回值 :
# 无返回值,因为 np.savetxt 直接将数组数据写入到指定的文件中。
# 异常 :
# 如果文件无法写入或数据格式不正确,可能会抛出异常。
# np.savetxt 是一个非常有用的函数,它允许你将数据以文本格式导出,便于在其他程序或环境中使用,如在 Excel 或其他数据分析工具中查看和处理数据。
# 将排序后的结果保存回 evolve.txt 文件。
np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
# 保存超参数到 YAML 文件。
# Save yaml
# 遍历超参数字典的键。
for i, k in enumerate(hyp.keys()):
# 将排序后的第一行(最佳结果)的超参数值赋值给 hyp 字典。
hyp[k] = float(x[0, i + 7])
# 以写入模式打开 YAML 文件。
with open(yaml_file, 'w') as f:
# 获取排序后的第一行结果。
results = tuple(x[0, :7])
# 创建一个格式化字符串,包含所有结果值。
c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
# 写入 YAML 文件的头部信息。
f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n') # 超参数演化结果\n# 代:%g\n# 指标:
# json.dump(obj, fp, *, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, cls=None, indent=None, separators=None, default=None, sort_keys=False, **kw)
# 在Python中, dump() 函数通常与序列化相关,特别是在处理JSON数据时。 dump() 函数可以将Python对象序列化为JSON格式的字符串。这个函数是 json 模块的一部分,用于将Python的数据结构转换为JSON格式。
# 参数说明 :
# obj :要序列化的Python对象。
# fp :一个文件类对象,具有 .write() 方法,例如 open() 函数返回的文件对象。
# skipkeys :如果为 True ,则 dict 类型的键值对中,如果键不是字符串,则跳过该键值对。
# ensure_ascii :如果为 True ,则所有非ASCII字符将被转义。
# check_circular :如果为 True ,则检查循环引用的对象,如果发现循环引用,则抛出 TypeError 。
# allow_nan :如果为 True ,则允许序列化 NaN 、 Infinity 和 -Infinity 。
# cls :一个自定义的 JSONEncoder 类,用于自定义序列化过程。
# indent :用于美化输出的缩进级别,如果为 None ,则不美化输出。
# separators :一个元组,包含分隔符,用于控制美化输出时的行和值之间的分隔符。
# default :一个函数,用于处理无法被序列化的类型。
# sort_keys :如果为 True ,则在美化输出时,字典的键将被排序。
# **kw :传递给 JSONEncoder 的其他参数。
# 返回值 :
# 无返回值,因为 dump() 直接将序列化后的JSON数据写入到提供的文件对象中。
# 异常 :
# 如果在序列化过程中遇到无法被序列化的类型,且未提供 default 函数,则抛出 TypeError 。
# 需要注意的是, dump() 函数是用于写入文件的,如果你需要获取JSON字符串而不是写入文件,可以使用 json.dumps() 函数,它与 json.dump() 类似,但返回序列化后的JSON字符串而不是写入文件。
# 将超参数字典写入 YAML 文件。
yaml.dump(hyp, f, sort_keys=False)
# 如果提供了 Google Cloud Storage 桶名称,执行以下操作。
if bucket:
# 上传 evolve.txt 和 YAML 文件到云端。
os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
# 这个函数的主要作用是处理超参数进化的结果,将最佳超参数和结果保存到本地文件和云端存储中,以便进一步分析和使用。
40.def apply_classifier(x, model, img, im0):
# 这段代码定义了一个名为 apply_classifier 的函数,它用于对 YOLO(You Only Look Once)目标检测模型的输出应用第二阶段分类器。这个函数的目的是提高检测的准确性,特别是在需要对检测到的目标进行分类的场景中。
# 定义一个函数 apply_classifier ,接受四个参数。
# 1.x :是 YOLO 输出的检测结果。
# 2.model :是用于分类的模型。
# 3.img :是原始图像的尺寸。
# 4.im0 :是预处理后的图像。
def apply_classifier(x, model, img, im0):
# isinstance(object, classinfo)
# isinstance() 是 Python 内置的一个函数,用于检查一个对象是否是一个已知的类型或者是从该类型派生出来的子类的一个实例。
# 参数说明 :
# object :要检查的对象。
# classinfo :要检查的对象类型,可以是一个类或者包含多个类的元组。
# 返回值 :
# 如果 object 是 classinfo 的实例或者是 classinfo 子类的实例,则返回 True ;否则返回 False 。
# isinstance() 函数在类型检查时比直接使用 type() 函数更推荐,因为它支持类的继承检查。如果 classinfo 是一个元组, isinstance() 会检查 object 是否是元组中任一类型的实例。这使得 isinstance() 成为一个灵活且强大的工具,用于在 Python 程序中进行类型检查。
# 将第二阶段分类器应用于 yolo 输出。
# applies a second stage classifier to yolo outputs
# 如果 im0 是一个 NumPy 数组,则将其转换为列表,以便于后续处理。
im0 = [im0] if isinstance(im0, np.ndarray) else im0
# 遍历 x 中的每个元素,每个元素代表一张图像的检测结果。
for i, d in enumerate(x): # per image
# 如果检测结果 d 不为空,则继续处理。
if d is not None and len(d):
# torch.clone(input, *, memory_format=torch.preserve_format)
# 在 PyTorch 中, clone() 函数是用于创建一个张量(tensor)的副本的函数。这个副本与原始张量共享数据,但它们在内存中是独立的,这意味着对副本的修改不会影响原始张量,反之亦然。
# 参数说明 :
# input :要克隆的张量。
# memory_format :可选参数,用于指定内存格式。默认为 torch.preserve_format ,意味着克隆的张量将保留原始张量的内存格式。如果设置为 torch.contiguous_format ,则克隆的张量将被重新排列为连续的内存格式。
# 返回值 :
# 返回一个新的张量,它是输入张量的副本。
# clone() 函数在需要保留原始数据不变时非常有用,比如在训练模型时,你可能需要保存模型的原始状态,同时对模型的参数进行修改或实验。
# 创建检测结果 d 的副本,以避免修改原始数据。
d = d.clone()
# Reshape and pad cutouts 重塑和填充切口。
# 将检测框的坐标从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式。
# def xyxy2xywh(x): -> 它用于将边界框的坐标从 xyxy 格式(即左上角和右下角的坐标)转换为 xywh 格式(即中心点坐标加上宽度和高度)。返回结果。返回转换后的 xywh 格式的边界框坐标。 -> return y
b = xyxy2xywh(d[:, :4]) # boxes
# 将检测框调整为正方形。
# 用于处理一个名为 b 的张量,它包含了边界框(bounding boxes)的坐标信息。具体来说,这行代码将 b 张量中每个边界框的宽度和高度调整为该边界框最大维度的值,并将这个最大维度值扩展为正方形的边界框。
# b[:, 2:] :这部分选择了张量 b 中所有行的最后两列,即每个边界框的宽度(第3列)和高度(第4列)。
# .max(1) :沿着维度1(列)计算最大值,即对于每个边界框,找到宽度和高度中的最大值。这将返回一个包含每个边界框最大维度值的张量,其形状为 (n,) ,其中 n 是边界框的数量。
# [0] :从上一步返回的最大值张量中取出第一个元素,即最大值本身。这是因为 max() 函数返回的是一个元组,其中第一个元素是最大值,第二个元素是最大值的索引。
# .unsqueeze(1) unsqueeze 函数用于在指定维度上增加一个维度,从而使张量的形状发生变化。在这里, unsqueeze(1) 在第二个维度(索引为1)上增加一个维度,将形状从 (n,) 变为 (n, 1) 。这样做是为了使张量可以与原始的 b 张量进行广播操作。
# b[:, 2:] = ... :最后,将修改后的宽度和高度值赋值回 b 张量的对应列。
# 这行代码的作用是将每个边界框的宽度和高度统一设置为该边界框的最大维度值,从而将每个边界框转换为正方形。这种操作在某些情况下是有用的,比如在将边界框输入到一个需要固定尺寸输入的模型之前,或者在进行某些图像处理任务时。
# 在YOLOv7的 general.py 文件中的 apply_classifier 函数里,将检测框调整为正方形的原因主要是为了确保输入到分类器的图像区域具有统一的形状和尺寸。
# 这样做有几个好处 :
# 1. 统一输入尺寸:大多数深度学习模型,尤其是分类器,都要求输入数据具有固定的尺寸。将检测框调整为正方形可以确保从原图中裁剪出来的图像区域在送入分类器之前具有统一的尺寸,便于模型处理。
# 2. 简化处理流程:通过将检测框调整为正方形,可以简化后续的处理流程。例如,不需要为不同长宽比的检测框编写特殊的处理代码,从而降低代码复杂度。
# 3. 提高分类器性能:分类器通常对输入图像的尺寸和比例有一定的要求。通过将检测框调整为正方形,可以确保分类器接收到的图像区域在空间尺寸上是一致的,这有助于提高分类器的性能和准确性。
# 4. 减少边界效应:在某些情况下,非正方形的检测框可能会导致边界效应,即分类器对图像边缘的响应不如中心区域强烈。通过将检测框调整为正方形,可以减少这种边界效应,使得分类器对整个图像区域的响应更加均匀。
# 5. 便于后续操作:在实际应用中,将检测框调整为正方形后,便于进行后续的操作,如图像增强、特征提取等,因为这些操作往往需要固定尺寸的输入。
# 这里, xyxy2xywh 函数将检测框的坐标从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式,然后通过取宽度和高度的最大值并将它们设置为相等,从而将矩形检测框调整为正方形。这样的处理是为了确保后续分类器的输入一致性。
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square 矩形转正方形。
# 对正方形检测框进行缩放和填充。
b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
# 将调整后的检测框坐标转换回 (x1, y1, x2, y2) 格式。
# def xywh2xyxy(x): -> 用于将边界框的坐标从中心点坐标加宽高(x, y, w, h)格式转换为两个角点坐标(x1, y1, x2, y2)格式。返回转换后的坐标数组。 -> return y
d[:, :4] = xywh2xyxy(b).long()
# Rescale boxes from img_size to im0 size 将框从 img_size 重新缩放为 im0 大小。
# 将坐标从 img 的尺寸调整到 im0 的尺寸。
# def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
# -> 它用于将坐标从一张图片的形状 ( img1_shape ) 缩放到另一张图片的形状 ( img0_shape )。返回调整后的坐标。
# -> return coords
scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
# Classes
# 获取预测的类别。
pred_cls1 = d[:, 5].long()
# 创建一个空列表,用于存储裁剪后的图像。
ims = []
# 遍历每个检测结果。
for j, a in enumerate(d): # per item
# 根据检测框坐标裁剪图像。
cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
# 将裁剪的图像调整为分类器输入的大小。
im = cv2.resize(cutout, (224, 224)) # BGR
# cv2.imwrite('test%i.jpg' % j, cutout)
# 将图像从 BGR 转换为 RGB,并调整维度。
im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
# 将图像数据类型转换为浮点数。
im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
# 将图像像素值归一化到 [0, 1] 范围内。
im /= 255.0 # 0 - 255 to 0.0 - 1.0
# 将处理后的图像添加到列表中。
ims.append(im)
# numpy.argmax(a, axis=None, out=None)
# 在Python的NumPy库中, np.argmax() 函数用于返回数组中最大值的索引。
# 参数 :
# a :输入数组。
# axis :可选参数,指定沿哪个轴寻找最大值。如果为 None ,则在扁平化后的数组中寻找最大值。
# out :可选参数,用于存放输出结果的数组。
# 返回值 :
# 返回数组中最大值的索引。
# 注意事项 :
# 如果数组中有多个相同的最大值, np.argmax() 会返回第一个遇到的最大值的索引。
# 如果输入数组为空, np.argmax() 会抛出异常。
# np.argmax() 是NumPy中处理数组时非常常用的函数,可以帮助我们快速找到数组中的最大值及其索引。
# 使用分类器模型对图像进行预测,并获取预测类别。
pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction 分类器预测。
# 只保留预测类别与分类器预测类别匹配的检测结果。
x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections 保留匹配的类别检测。
# 返回处理后的检测结果。
return x
# 这个函数通过应用第二阶段分类器来提高检测的准确性,特别是在目标类别较多或需要更精细分类时。通过裁剪、调整尺寸、归一化和分类,最终只保留与分类器预测类别匹配的检测结果。
41.def increment_path(path, exist_ok=True, sep=''):
# 这段代码定义了一个名为 increment_path 的函数,它用于生成一个唯一的文件路径。如果指定的路径已经存在,函数会通过添加一个数字后缀来创建一个新的路径。
# 函数定义。
# 1.path :要检查和增量的路径字符串。
# 2.exist_ok :一个布尔值,指示如果路径已经存在是否允许函数返回原始路径而不是生成一个新的路径。
# 3.sep :用于分隔路径和数字后缀的字符串,默认为空字符串。
def increment_path(path, exist_ok=True, sep=''):
# 增加路径,即运行/exp --> 运行/exp{sep}0、运行/exp{sep}1 等。
# Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
# 路径转换为 Path 对象。使用 Path 类将字符串路径转换为一个 pathlib 路径对象,使其跨操作系统兼容。
path = Path(path) # os-agnostic
# 检查路径是否存在。
# 如果路径存在且 exist_ok 为 True ,或者路径不存在,则直接返回原始路径。
if (path.exists() and exist_ok) or (not path.exists()):
return str(path)
# 生成新的路径。
else:
# 如果路径存在且 exist_ok 为 False ,则使用 glob.glob 搜索所有与指定路径相似的路径。
dirs = glob.glob(f"{path}{sep}*") # similar paths
# 使用正则表达式 re.search 从这些路径中提取数字后缀。
# path.stem 获取路径的基本名称,不包括扩展名。
matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
# i 是一个包含所有提取的数字后缀的列表。
i = [int(m.groups()[0]) for m in matches if m] # indices
# n 是最大的数字后缀加1,如果没有找到任何数字后缀,则默认为2。
n = max(i) + 1 if i else 2 # increment number
# 返回一个新的路径,它是原始路径加上一个更新的数字后缀。
return f"{path}{sep}{n}" # update path