YOLOv11-ultralytics-8.3.67部分代码阅读笔记-plotting.py
plotting.py
ultralytics\utils\plotting.py
目录
plotting.py
1.所需的库和模块
2.class Colors:
3.class Annotator:
4.def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
5.def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
6.def plot_images(images: Union[torch.Tensor, np.ndarray], batch_idx: Union[torch.Tensor, np.ndarray], cls: Union[torch.Tensor, np.ndarray], bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32), confs: Optional[Union[torch.Tensor, np.ndarray]] = None, masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8), kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), paths: Optional[List[str]] = None, fname: str = "images.jpg", names: Optional[Dict[int, str]] = None, on_plot: Optional[Callable] = None, max_size: int = 1920, max_subplots: int = 16, save: bool = True, conf_thres: float = 0.25, ) -> Optional[np.ndarray]:
7.def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
8.def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
9.def plot_tune_results(csv_file="tune_results.csv"):
10.def output_to_target(output, max_det=300):
11.def output_to_rotated_target(output, max_det=300):
12.def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
1.所需的库和模块
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import math
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
from ultralytics.utils.checks import check_font, check_version, is_ascii
from ultralytics.utils.files import increment_path
2.class Colors:
# 这段代码定义了一个名为 Colors 的类,用于管理颜色的调色板,并提供颜色的转换和访问功能。
# 定义了一个名为 Colors 的类。
class Colors:
# Ultralytics 调色板 https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors。
# 此类提供使用 Ultralytics 调色板的方法,包括将十六进制颜色代码转换为 RGB 值。
# ## Ultralytics 调色板
# | Index 索引 | Color 颜色 | HEX | RGB |
# ## 姿势调色板
# | Index 索引 | Color 颜色 | HEX | RGB |
"""
Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
RGB values.
Attributes:
palette (list of tuple): List of RGB color values.
n (int): The number of colors in the palette.
pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
## Ultralytics Color Palette
| Index | Color | HEX | RGB |
|-------|-------------------------------------------------------------------|-----------|-------------------|
| 0 | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255) |
| 1 | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235) |
| 2 | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243) |
| 3 | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183) |
| 4 | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104) |
| 5 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221) |
| 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79) |
| 7 | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0) |
| 8 | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68) |
| 9 | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255) |
| 10 | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255) |
| 11 | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186) |
| 12 | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255) |
| 13 | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0) |
| 14 | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179) |
| 15 | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255) |
| 16 | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104) |
| 17 | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108) |
| 18 | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47) |
| 19 | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11) |
## Pose Color Palette
| Index | Color | HEX | RGB |
|-------|-------------------------------------------------------------------|-----------|-------------------|
| 0 | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0) |
| 1 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51) |
| 2 | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102) |
| 3 | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0) |
| 4 | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255) |
| 5 | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255) |
| 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255) |
| 7 | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255) |
| 8 | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255) |
| 9 | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255) |
| 10 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153) |
| 11 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102) |
| 12 | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51) |
| 13 | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153) |
| 14 | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102) |
| 15 | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51) |
| 16 | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0) |
| 17 | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255) |
| 18 | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0) |
| 19 | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255) |
!!! note "Ultralytics Brand Colors"
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
"""
# 定义了类的初始化方法 __init__ ,当创建类的实例时会自动调用此方法。
def __init__(self):
# 将颜色初始化为 hex = matplotlib.colors.TABLEAU_COLORS.values()。
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
# 定义了一个元组 hexs ,存储了一系列颜色的十六进制字符串。这些颜色将被用于后续生成调色板。
hexs = (
# 列出了十六进制颜色代码,每个颜色代码是一个6位的十六进制字符串,例如 "042AFF" 表示一种蓝色。
"042AFF",
"0BDBEB",
"F3F3F3",
"00DFB7",
"111F68",
"FF6FDD",
"FF444F",
"CCED00",
"00F344",
"BD00FF",
"00B4FF",
"DD00BA",
"00FFFF",
"26C000",
"01FFB3",
"7D24FF",
"7B0068",
"FF1B6C",
"FC6D2F",
"A2FF0B",
)
# 使用列表推导式将 hexs 中的每个十六进制颜色字符串转换为 RGB 格式的元组,并存储在 self.palette 列表中。 self.hex2rgb(f"#{c}") 调用了类中的 hex2rgb 方法,将十六进制颜色转换为 RGB 格式。
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
# 计算调色板中颜色的数量,并将其存储在 self.n 中。
self.n = len(self.palette)
# 定义了一个 NumPy 数组 self.pose_palette ,用于存储另一组预定义的颜色。
self.pose_palette = np.array(
[
# 列出了颜色的 RGB 值,每个颜色是一个包含三个整数(红、绿、蓝)的列表。
[255, 128, 0],
[255, 153, 51],
[255, 178, 102],
[230, 230, 0],
[255, 153, 255],
[153, 204, 255],
[255, 102, 255],
[255, 51, 255],
[102, 178, 255],
[51, 153, 255],
[255, 153, 153],
[255, 102, 102],
[255, 51, 51],
[153, 255, 153],
[102, 255, 102],
[51, 255, 51],
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[255, 255, 255],
],
# 指定 self.pose_palette 数组的数据类型为 uint8 ,即无符号8位整数,范围是 0 到 255。
dtype=np.uint8,
)
# 定义了一个特殊方法 __call__ ,允许类的实例像函数一样被调用。 参数 :
# 1.i :颜色索引, bgr 是一个布尔值,用于指定是否返回 BGR 格式。
def __call__(self, i, bgr=False):
# 将十六进制颜色代码转换为 RGB 值。
"""Converts hex color codes to RGB values."""
# 从调色板中获取颜色索引 i 对应的颜色。 使用 int(i) % self.n 确保索引在调色板范围内,即使索引超出范围也会循环。
c = self.palette[int(i) % self.n]
# 如果 bgr 参数为 True ,则将 RGB 格式的颜色转换为 BGR 格式并返回。
# 如果 bgr 为 False ,则直接返回 RGB 格式。
return (c[2], c[1], c[0]) if bgr else c
# 定义了一个静态方法 hex2rgb ,该方法不依赖于类实例的属性。
@staticmethod
# 定义了静态方法 hex2rgb ,用于将十六进制颜色字符串转换为 RGB 格式。
def hex2rgb(h):
# 将十六进制颜色代码转换为 RGB 值(即默认 PIL 顺序)。
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
# 使用列表推导式将十六进制颜色字符串中的每两个字符(分别表示红、绿、蓝)转换为整数,并返回一个包含三个整数的元组。
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
# 这段代码实现了一个颜色管理类 Colors ,主要用于生成和管理颜色调色板。它提供了以下功能。初始化调色板:通过将十六进制颜色字符串转换为 RGB 格式,生成一个颜色调色板。颜色索引访问:通过索引访问调色板中的颜色,并支持返回 RGB 或 BGR 格式。十六进制到 RGB 的转换:提供了一个静态方法 hex2rgb ,用于将十六进制颜色字符串转换为 RGB 格式。额外的预定义颜色:定义了一个额外的 NumPy 数组 pose_palette ,用于存储另一组预定义的颜色,用于特定的用途(如姿态估计等)。这个类在需要管理多种颜色的场景中非常有用,例如在图像处理、可视化或机器学习任务中。
# 这行代码的作用是创建了一个 Colors 类的实例,并将其命名为 colors 。
# Colors() :调用了 Colors 类的构造函数 __init__ ,创建了一个 Colors 类的实例。
# colors :将创建的实例赋值给变量 colors ,以便后续可以通过这个变量访问类的方法和属性。
colors = Colors() # create instance for 'from utils.plots import colors' 为'from utils.plots import colors'创建实例。
# 这行代码的作用是初始化一个 Colors 类的实例,并将其存储在变量 colors 中。这个实例可以用于后续的颜色管理操作,例如通过索引获取颜色、将十六进制颜色转换为 RGB 格式等。
# 使用场景 :
# 在实际应用中, Colors 类通常用于图像可视化或颜色管理任务。例如,在一个图像处理项目中,可能需要从一个预定义的调色板中选择颜色来绘制边界框、关键点或其他图形元素。通过创建 Colors 类的实例,可以方便地访问和使用这些颜色。
3.class Annotator:
# 这段代码定义了一个名为 Annotator 的类,用于在图像上进行标注和绘制各种图形、文本和关键点等。
# 定义了一个名为 Annotator 的类,用于在图像上进行标注和绘制。
class Annotator:
# Ultralytics Annotator 用于训练/验证马赛克和 JPG 以及预测注释。
"""
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
Attributes:
im (Image.Image or numpy array): The image to annotate.
pil (bool): Whether to use PIL or cv2 for drawing annotations.
font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
lw (float): Line width for drawing.
skeleton (List[List[int]]): Skeleton structure for keypoints.
limb_color (List[int]): Color palette for limbs.
kpt_color (List[int]): Color palette for keypoints.
"""
# 这段代码是 Annotator 类的初始化方法 __init__ 的实现,用于设置标注器的基本属性和初始化绘图环境。
# 定义了 Annotator 类的初始化方法,接受以下参数 :
# 1.im :输入图像,可以是 PIL 图像或 NumPy 数组。
# 2.line_width :绘制线条的宽度,默认为 None 。
# 3.font_size :字体大小,默认为 None 。
# 4.font :字体文件路径,默认为 "Arial.ttf"。
# 5.pil :是否使用 PIL 进行绘制,默认为 False 。
# 6.example :示例文本,用于检查是否包含非 ASCII 字符,默认为 "abc"。
def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
# 使用图像和线宽以及关键点和肢体的调色板初始化 Annotator 类。
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
# 检查 example 参数是否包含非 ASCII 字符(例如亚洲、阿拉伯或西里尔文字)。如果包含非 ASCII 字符,则 non_ascii 为 True 。
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
# 检查输入图像是否为 PIL 图像。如果是,则 input_is_pil 为 True 。
input_is_pil = isinstance(im, Image.Image)
# 根据参数和输入图像类型,决定是否使用 PIL 进行绘制。
# 如果 pil=True ,则使用 PIL。
# 如果 example 包含非 ASCII 字符,则使用 PIL(因为 PIL 支持 Unicode 字符)。
# 如果输入图像已经是 PIL 图像,则使用 PIL。
self.pil = pil or non_ascii or input_is_pil
# 设置 绘制线条的宽度 。 如果 line_width 参数已指定,则使用该值。 如果未指定,则根据图像尺寸动态计算线条宽度,确保线条宽度至少为 2。
self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
# 如果使用 PIL 绘制。
if self.pil: # use PIL
# 如果输入图像是 PIL 图像,则直接使用。 如果输入图像是 NumPy 数组,则将其转换为 PIL 图像。
self.im = im if input_is_pil else Image.fromarray(im)
# 创建一个 ImageDraw.Draw 对象,用于在图像上绘制图形和文本。
self.draw = ImageDraw.Draw(self.im)
# 尝试加载字体文件。
try:
# 如果 example 包含非 ASCII 字符,则加载支持 Unicode 的字体 "Arial.Unicode.ttf"。
font = check_font("Arial.Unicode.ttf" if non_ascii else font)
# 如果未指定 font_size ,则根据图像尺寸动态计算字体大小,确保字体大小至少为 12。
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
self.font = ImageFont.truetype(str(font), size)
# 当 self.pil 为 False 时,即选择使用 OpenCV 进行图像绘制时,执行以下代码。
except Exception:
self.font = ImageFont.load_default()
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
# 修复 PIL 9.2.0 版本中 getsize 方法的弃用问题。从 getbbox 方法中提取文本的宽度和高度。
if check_version(pil_version, "9.2.0"):
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
# 如果使用 OpenCV 绘制。
else: # use cv2
# 检查输入图像 im 的数据是否是连续的( contiguous )。在 OpenCV 中,图像数据通常需要是连续的,以确保高效的内存访问和操作。如果图像数据不是连续的,可能会导致错误或性能问题。这里使用 assert 语句进行检查,并在数据不连续时提供提示,建议使用 np.ascontiguousarray(im) 将图像数据转换为连续的格式。
assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." # 图像不连续。将 np.ascontiguousarray(im) 应用于 Annotator 输入图像。
# 根据输入图像 im 的可写性( writeable )来决定是否需要创建副本。
# 如果输入图像 im 是可写的( im.flags.writeable 为 True ),则直接使用原始图像 im 。
# 如果输入图像是只读的( im.flags.writeable 为 False ),则创建一个可写的副本( im.copy() )。这是因为 OpenCV 的某些操作需要对图像进行修改,而只读图像无法直接修改。
self.im = im if im.flags.writeable else im.copy()
# 计算 字体厚度 ( font thickness ),并将其存储在 self.tf 中。 字体厚度基于线条宽度( self.lw )计算,公式为 self.lw - 1 。 使用 max 函数确保字体厚度至少为 1,避免字体过细而难以看清。
self.tf = max(self.lw - 1, 1) # font thickness
# 计算 字体缩放比例 ( font scale ),并将其存储在 self.sf 中。 字体缩放比例基于线条宽度( self.lw )计算,公式为 self.lw / 3 。 这种计算方式可以确保字体大小与线条宽度成比例,从而使标注在视觉上更加协调。
self.sf = self.lw / 3 # font scale
# Pose
# 定义了 人体关键点的骨架结构 ,用于 绘制关键点之间的连接线 。每个元素是一个包含两个关键点索引的列表,表示这两个关键点之间有一条连接线。
self.skeleton = [
[16, 14],
[14, 12],
[17, 15],
[15, 13],
[12, 13],
[6, 12],
[7, 13],
[6, 7],
[6, 8],
[7, 9],
[8, 10],
[9, 11],
[2, 3],
[1, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
[5, 7],
]
# 定义了 肢体 和 关键点 的颜色调色板。
# limb_color 用于绘制 肢体连接线的颜色 。
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
# kpt_color 用于绘制 关键点的颜色 。
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
# 定义了深色和浅色的颜色集合,用于根据背景颜色选择合适的文本颜色。
# dark_colors 深色背景颜色集合。
self.dark_colors = {
(235, 219, 11),
(243, 243, 243),
(183, 223, 0),
(221, 111, 255),
(0, 237, 204),
(68, 243, 0),
(255, 255, 0),
(179, 255, 1),
(11, 255, 162),
}
# light_colors 浅色背景颜色集合。
self.light_colors = {
(255, 42, 4),
(79, 68, 255),
(255, 0, 189),
(255, 180, 0),
(186, 0, 221),
(0, 192, 38),
(255, 36, 125),
(104, 0, 123),
(108, 27, 255),
(47, 109, 252),
(104, 31, 17),
}
# 这段代码实现了 Annotator 类的初始化方法,主要功能包括。输入图像处理:支持 PIL 图像和 NumPy 数组作为输入,并根据输入类型和参数决定是否使用 PIL 或 OpenCV 进行绘制。动态参数计算:根据图像尺寸动态计算线条宽度和字体大小,确保标注效果在不同尺寸的图像上都能保持一致。字体加载:支持加载指定字体文件,并在加载失败时回退到默认字体。骨架结构定义:定义了人体关键点的骨架结构,用于绘制关键点之间的连接线。颜色调色板定义:定义了肢体和关键点的颜色调色板,以及用于选择文本颜色的深色和浅色集合。这些初始化操作为后续的标注功能(如绘制边界框、文本标签、关键点等)提供了基础支持。
# 这段代码定义了 Annotator 类中的一个方法 get_txt_color ,用于根据背景颜色选择合适的文本颜色。
# 定义了一个方法 get_txt_color ,接受两个参数。
# 1.color :背景颜色,默认值为 (128, 128, 128) ,即灰度颜色。
# 2.txt_color :默认的文本颜色,默认值为 (255, 255, 255) ,即白色。
# 该方法的目的是根据背景颜色 color 自动选择一个合适的文本颜色,以确保文本在背景上清晰可见。
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
# 根据背景颜色分配文本颜色。
"""
Assign text color based on background color.
Args:
color (tuple, optional): The background color of the rectangle for text (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
Returns:
txt_color (tuple): Text color for label
"""
# 如果 背景颜色 color 在 self.dark_colors 集合中(即背景颜色是深色),则返回一个浅色的文本颜色 (104, 31, 17) 。这种颜色组合可以确保文本在深色背景上清晰可见。
if color in self.dark_colors:
return 104, 31, 17
# 如果 背景颜色 color 在 self.light_colors 集合中(即背景颜色是浅色),则返回白色 (255, 255, 255) 作为文本颜色。白色文本在浅色背景上通常具有良好的对比度。
elif color in self.light_colors:
return 255, 255, 255
# 如果背景颜色既不在 self.dark_colors 也不在 self.light_colors 中,则返回默认的文本颜色 txt_color 。这为其他颜色提供了灵活性,允许用户自定义文本颜色。
else:
return txt_color
# 这个方法的主要目的是根据背景颜色自动选择合适的文本颜色,以确保文本在不同背景颜色下都能清晰可见。它通过以下逻辑实现。深色背景:如果背景颜色是深色(在 self.dark_colors 中),则选择一个浅色文本颜色 (104, 31, 17) 。浅色背景:如果背景颜色是浅色(在 self.light_colors 中),则选择白色 (255, 255, 255) 作为文本颜色。其他颜色:如果背景颜色不属于上述两种情况,则返回用户指定的默认文本颜色 txt_color 。这种方法可以有效避免在深色背景上使用深色文本或在浅色背景上使用浅色文本,从而提高标注的可读性。
# 这段代码定义了 Annotator 类中的一个方法 circle_label ,用于在给定的边界框内绘制一个圆形背景的文本标签。
# 定义了一个方法 circle_label ,接受以下参数 :
# 1.box :边界框的坐标,格式为 (x1, y1, x2, y2) 。
# 2.label :要显示的文本标签,默认为空字符串。
# 3.color :圆形背景的颜色,默认为灰色 (128, 128, 128) 。
# 4.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
# 5.margin :文本与圆形边缘的间距,默认为 2。
def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
# 绘制一个带有背景圆的标签,背景圆位于给定边界框的中心。
"""
Draws a label with a background circle centered within a given bounding box.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# If label have more than 3 characters, skip other characters, due to circle size
# 如果标签长度超过 3 个字符,则只取前 3 个字符。这是因为圆形背景的大小有限,无法容纳过多字符。同时,打印一条提示信息,告知用户标签被截断。
if len(label) > 3:
print(
f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!" # 标签的长度为 {len(label)},最初的 3 个标签字符将被视为圆形注释!
)
label = label[:3]
# Calculate the center of the box
# 计算 边界框的中心点坐标 (x_center, y_center) ,用于确定圆形背景的中心位置。
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the text size
# 使用 OpenCV 的 getTextSize 方法获取 文本的宽度和高度 。这里使用了 cv2.FONT_HERSHEY_SIMPLEX 字体,字体缩放比例为 self.sf - 0.15 ,字体厚度为 self.tf 。
text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
# Calculate the required radius to fit the text with the margin
# 计算 圆形背景的半径 ,确保文本和间距都能被完整包含。
required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
# Draw the circle with the required radius
# 使用 OpenCV 的 circle 方法绘制圆形背景。
# self.im :目标图像。
# (x_center, y_center) :圆形的中心点。
# required_radius :圆形的半径。
# color :圆形的填充颜色。
# -1 :表示填充圆形。
cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
# Calculate the position for the text
# 计算 文本的起始位置 ,确保文本居中显示在圆形背景内。
# text_x :中心点的 x 坐标减去文本宽度的一半。
text_x = x_center - text_size[0] // 2
# text_y :中心点的 y 坐标加上文本高度的一半。
text_y = y_center + text_size[1] // 2
# Draw the text
# 使用 OpenCV 的 putText 方法绘制文本。
cv2.putText(
# 目标图像。
self.im,
# 要绘制的文本内容。
str(label),
# 文本的起始位置。
(text_x, text_y),
# 字体类型。
cv2.FONT_HERSHEY_SIMPLEX,
# 字体缩放比例。
self.sf - 0.15,
# 根据背景颜色选择合适的文本颜色。
self.get_txt_color(color, txt_color),
# 字体厚度。
self.tf,
# 抗锯齿线条类型。
lineType=cv2.LINE_AA,
)
# 这个方法的主要功能是在给定的边界框内绘制一个圆形背景的文本标签,具体步骤如下。限制标签长度:如果标签超过 3 个字符,则截断为前 3 个字符。计算中心点:根据边界框坐标计算中心点。计算文本尺寸:使用 OpenCV 获取文本的宽度和高度。计算圆形半径:根据文本尺寸和间距计算圆形背景的半径。绘制圆形背景:使用 OpenCV 绘制圆形背景。计算文本位置:确保文本居中显示在圆形背景内。绘制文本:使用 OpenCV 绘制文本,确保文本颜色与背景颜色对比明显。这种方法适用于在图像中标注简短的文本信息,尤其是在需要突出显示某些关键信息时。
# 这段代码定义了 Annotator 类中的一个方法 text_label ,用于在给定的边界框内绘制一个带有背景矩形的文本标签。
# 定义了一个方法 text_label ,接受以下参数 :
# 1.box :边界框的坐标,格式为 (x1, y1, x2, y2) 。
# 2.label :要显示的文本标签,默认为空字符串。
# 3.color :背景矩形的颜色,默认为灰色 (128, 128, 128) 。
# 4.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
# 5.margin :文本与背景矩形边缘的间距,默认为 5。
def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
# 绘制一个带有背景矩形的标签,该矩形位于给定边界框的中心。
"""
Draws a label with a background rectangle centered within a given bounding box.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# Calculate the center of the bounding box
# 计算 边界框的中心点坐标 (x_center, y_center) ,用于确定文本的中心位置。
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the size of the text
# 使用 OpenCV 的 getTextSize 方法获取 文本的宽度和高度 。这里使用了 cv2.FONT_HERSHEY_SIMPLEX 字体,字体缩放比例为 self.sf - 0.1 ,字体厚度为 self.tf 。
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
# Calculate the top-left corner of the text (to center it)
# 计算 文本的起始位置 ,确保文本居中显示在边界框内。
# text_x :中心点的 x 坐标减去文本宽度的一半。
text_x = x_center - text_size[0] // 2
# text_y :中心点的 y 坐标加上文本高度的一半。
text_y = y_center + text_size[1] // 2
# Calculate the coordinates of the background rectangle
# 计算 背景矩形的坐标 ,确保矩形能够完整包含文本,并且有适当的间距。
# rect_x1 :文本起始点的 x 坐标减去间距。
rect_x1 = text_x - margin
# rect_y1 :文本起始点的 y 坐标减去文本高度和间距。
rect_y1 = text_y - text_size[1] - margin
# rect_x2 :文本起始点的 x 坐标加上文本宽度和间距。
rect_x2 = text_x + text_size[0] + margin
# rect_y2 :文本起始点的 y 坐标加上间距。
rect_y2 = text_y + margin
# Draw the background rectangle
# 使用 OpenCV 的 rectangle 方法绘制背景矩形。 self.im :目标图像。 (rect_x1, rect_y1) :矩形的左上角坐标。 (rect_x2, rect_y2) :矩形的右下角坐标。 color :矩形的填充颜色。 -1 :表示填充矩形。
cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
# Draw the text on top of the rectangle
# 使用 OpenCV 的 putText 方法在背景矩形上绘制文本。
cv2.putText(
# 目标图像。
self.im,
# 要绘制的文本内容。
label,
# 文本的起始位置。
(text_x, text_y),
# 字体类型。
cv2.FONT_HERSHEY_SIMPLEX,
# 字体缩放比例。
self.sf - 0.1,
# 根据背景颜色选择合适的文本颜色。
self.get_txt_color(color, txt_color),
# 字体厚度。
self.tf,
# 抗锯齿线条类型。
lineType=cv2.LINE_AA,
)
# 这个方法的主要功能是在给定的边界框内绘制一个带有背景矩形的文本标签,具体步骤如下。计算中心点:根据边界框坐标计算中心点。获取文本尺寸:使用 OpenCV 获取文本的宽度和高度。计算文本位置:确保文本居中显示在边界框内。计算背景矩形坐标:确保背景矩形能够完整包含文本,并且有适当的间距。绘制背景矩形:使用 OpenCV 绘制背景矩形。绘制文本:在背景矩形上绘制文本,确保文本颜色与背景颜色对比明显。这种方法适用于在图像中标注文本信息,尤其是在需要突出显示某些关键信息时。背景矩形可以提高文本的可读性,尤其是在复杂的图像背景上。
# 这段代码定义了 Annotator 类中的一个方法 box_label ,用于在图像上绘制边界框(bounding box)和标签。该方法支持绘制普通的矩形框和旋转的多边形框,并且可以选择使用 PIL 或 OpenCV 进行绘制。
# 定义了一个方法 box_label ,接受以下参数 :
# 1.box :边界框的坐标,格式为 (x1, y1, x2, y2) 或旋转框的多边形点。
# 2.label :要显示的文本标签,默认为空字符串。
# 3.color :边界框和背景矩形的颜色,默认为灰色 (128, 128, 128) 。
# 4.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
# 5.rotated :是否绘制旋转的边界框,默认为 False 。
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
# 使用标签在图像上绘制边界框。
"""
Draws a bounding box to image with label.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
rotated (bool, optional): Variable used to check if task is OBB
"""
# 这段代码是 box_label 方法的一部分,用于在 PIL 模式下绘制边界框和文本标签。
# 根据背景颜色 color 自动选择合适的文本颜色,确保文本在背景上清晰可见。 self.get_txt_color 方法会根据背景颜色是深色还是浅色,返回一个合适的文本颜色。
txt_color = self.get_txt_color(color, txt_color)
# 如果边界框 box 是一个 PyTorch 张量,则将其转换为 Python 列表,以便后续操作。
if isinstance(box, torch.Tensor):
box = box.tolist()
# 如果使用 PIL 绘制,或者标签包含非 ASCII 字符(例如中文、阿拉伯文等),则进入 PIL 绘制逻辑。
if self.pil or not is_ascii(label):
# 如果边界框是旋转的多边形框。
if rotated:
# p1 是多边形的第一个点。
p1 = box[0]
# 使用 PIL 的 polygon 方法绘制多边形。 box 需要是一个包含多边形点的列表,每个点是一个元组。 self.lw 是线条宽度, color 是边界框的颜色。
self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box
# 如果边界框是普通的矩形框。
else:
# p1 是矩形的左上角点。
p1 = (box[0], box[1])
# 使用 PIL 的 rectangle 方法绘制矩形框。 box 是一个包含矩形框坐标 (x1, y1, x2, y2) 的列表。
self.draw.rectangle(box, width=self.lw, outline=color) # box
# 如果需要绘制标签。
if label:
# 使用 self.font.getsize 方法获取文本的宽度和高度。
w, h = self.font.getsize(label) # text width, height
# 判断标签是否应该绘制在边界框的外部( outside )。如果 p1[1] (左上角的 y 坐标)大于文本高度 h ,则认为标签应该绘制在边界框的外部。
outside = p1[1] >= h # label fits outside box
# 如果标签的宽度超出了图像的右侧边界,则调整标签的起始点,使其不超出图像范围。 self.im.size[0] 是图像的宽度。
if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image
p1 = self.im.size[0] - w, p1[1]
# 绘制背景矩形,确保标签有合适的背景颜色。根据 outside 的值, 背景矩形的位置 会有所不同。
self.draw.rectangle(
# 如果 outside 为 True ,背景矩形绘制在边界框的上方。
# 如果 outside 为 False ,背景矩形绘制在边界框的下方。
(p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
fill=color,
)
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
# 在背景矩形上绘制文本标签。根据 outside 的值,文本的位置会有所不同。
# 如果 outside 为 True ,文本绘制在背景矩形的上方。
# 如果 outside 为 False ,文本绘制在背景矩形的下方。
self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
# 这段代码的主要功能是在 PIL 模式下绘制边界框和文本标签,具体步骤如下。选择文本颜色:根据背景颜色选择合适的文本颜色。处理边界框类型:如果边界框是旋转的多边形框,使用 polygon 方法绘制。如果边界框是普通的矩形框,使用 rectangle 方法绘制。处理文本标签:计算文本的宽度和高度。判断标签是否应该绘制在边界框的外部。调整标签的位置,确保不超出图像范围。绘制背景矩形,确保标签有合适的背景颜色。在背景矩形上绘制文本标签。这种方法适用于在图像中标注边界框和文本信息,尤其是在需要突出显示某些关键信息时。支持普通矩形框和旋转框的绘制,使其更加灵活。
# 这段代码是 box_label 方法的一部分,用于在 OpenCV 模式下绘制边界框和文本标签。
# 如果当前的绘图模式是 OpenCV( self.pil 为 False ),则进入以下逻辑。
else: # cv2
# 如果边界框是旋转的多边形框。
if rotated:
# 将多边形的第一个点 box[0] 转换为整数列表 p1 。
p1 = [int(b) for b in box[0]]
# 使用 OpenCV 的 polylines 方法绘制多边形。 box 需要是一个包含多边形点的列表,每个点是一个整数数组。
# np.asarray(box, dtype=int) 将边界框的点转换为 NumPy 数组。 True 表示闭合多边形。 color 是边界框的颜色。 self.lw 是线条宽度。
cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box
# 如果边界框是普通的矩形框。
else:
# 将矩形框的左上角点 (box[0], box[1]) 和右下角点 (box[2], box[3]) 转换为整数元组 p1 和 p2 。
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
# 使用 OpenCV 的 rectangle 方法绘制矩形框。 color 是边界框的颜色。 self.lw 是线条宽度。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
# 如果需要绘制标签。
if label:
# 使用 OpenCV 的 getTextSize 方法 获取文本的宽度和高度 。 fontScale=self.sf 是字体缩放比例。 thickness=self.tf 是字体厚度。
w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
# 增加 3 个像素以填充文本。
h += 3 # add pixels to pad text
# 判断标签是否应该绘制在边界框的外部( outside )。如果 p1[1] (左上角的 y 坐标)大于文本高度 h ,则认为标签应该绘制在边界框的外部。
outside = p1[1] >= h # label fits outside box
# 如果标签的宽度超出了图像的右侧边界,则调整标签的起始点,使其不超出图像范围。 self.im.shape[1] 是图像的宽度。
if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image
p1 = self.im.shape[1] - w, p1[1]
# 绘制 背景矩形 ,确保标签有合适的背景颜色。根据 outside 的值,背景矩形的位置会有所不同。
# 如果 outside 为 True ,背景矩形绘制在边界框的上方。
# 如果 outside 为 False ,背景矩形绘制在边界框的下方。
p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
# -1 表示填充矩形。
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
# 在背景矩形上绘制文本标签。根据 outside 的值,文本的位置会有所不同。
cv2.putText(
self.im,
label,
# 根据 outside 的值,文本的位置会有所不同。如果 outside 为 True ,文本绘制在背景矩形的上方。 如果 outside 为 False ,文本绘制在背景矩形的下方。
(p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
# 字体类型( cv2.FONT_HERSHEY_SIMPLEX )。
0,
# 字体缩放比例。
self.sf,
# 文本颜色。
txt_color,
# 字体厚度。
thickness=self.tf,
# 表示使用抗锯齿线条。
lineType=cv2.LINE_AA,
)
# 这段代码的主要功能是在 OpenCV 模式下绘制边界框和文本标签,具体步骤如下。处理边界框类型:如果边界框是旋转的多边形框,使用 polylines 方法绘制。如果边界框是普通的矩形框,使用 rectangle 方法绘制。处理文本标签:计算文本的宽度和高度。判断标签是否应该绘制在边界框的外部。调整标签的位置,确保不超出图像范围。绘制背景矩形,确保标签有合适的背景颜色。在背景矩形上绘制文本标签。这种方法适用于在图像中标注边界框和文本信息,尤其是在需要突出显示某些关键信息时。支持普通矩形框和旋转框的绘制,使其更加灵活。
# 这个方法的主要功能是在图像上绘制边界框和标签,具体步骤如下。选择绘制工具:根据是否使用 PIL 或标签是否包含非 ASCII 字符,选择使用 PIL 或 OpenCV 进行绘制。处理旋转框:如果边界框是旋转的多边形框,使用 polygon 或 polylines 方法绘制。处理普通矩形框:如果边界框是普通的矩形框,使用 rectangle 方法绘制。绘制背景矩形:根据文本的宽度和高度,绘制一个背景矩形,确保标签有合适的背景颜色。绘制文本标签:在背景矩形上绘制文本标签,确保文本颜色与背景颜色对比明显。这种方法适用于在图像中标注边界框和文本信息,尤其是在需要突出显示某些关键信息时。支持普通矩形框和旋转框的绘制,使其更加灵活。
# 这段代码定义了 Annotator 类中的一个方法 masks ,用于在图像上绘制预测的掩码(masks)。该方法支持将掩码与原始图像进行融合,并可以选择是否使用高分辨率掩码(retina masks)。
# 定义了一个方法 masks ,接受以下参数 :
# 1.masks :预测的掩码,形状为 (n, h, w) ,其中 n 是掩码的数量, h 和 w 是图像的高度和宽度。
# 2.colors :掩码的颜色,形状为 (n, 3) ,每个掩码对应一个 RGB 颜色。
# 3.im_gpu :原始图像,存储在 GPU 上,形状为 (3, h, w) ,范围为 [0, 1] 。
# 4.alpha :掩码的透明度,默认为 0.5(0.0 完全透明,1.0 完全不透明)。
# 5.retina_masks :是否使用高分辨率掩码,默认为 False 。
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
# 在图像上绘制蒙版。
# 参数:
# mask (tensor):cuda 上的预测蒙版,形状:[n, h, w]
# colors (List[List[Int]]):预测蒙版的颜色,[[r, g, b] * n]
# im_gpu (tensor):图像在 cuda 中,形状:[3, h, w],范围:[0, 1]
# alpha (float):蒙版透明度:0.0 完全透明,1.0 不透明
# retina_masks (bool):是否使用高分辨率蒙版。默认为 False。
"""
Plot masks on image.
Args:
masks (tensor): Predicted masks on cuda, shape: [n, h, w]
colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
"""
# 如果当前使用的是 PIL 模式,则将 PIL 图像转换为 NumPy 数组,以便后续操作。
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
# 如果没有掩码( masks 为空),则直接将原始图像 im_gpu 转换为 NumPy 数组,并将其赋值给 self.im 。 im_gpu 需要先进行通道翻转(从 (3, h, w) 转换为 (h, w, 3) ),然后移动到 CPU 并转换为 NumPy 数组,最后乘以 255 以恢复原始像素值。
if len(masks) == 0:
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
# 如果原始图像 im_gpu 和掩码 masks 不在同一个设备上(例如一个在 GPU 上,另一个在 CPU 上),则将 im_gpu 移动到与 masks 相同的设备上。
if im_gpu.device != masks.device:
im_gpu = im_gpu.to(masks.device)
# 将颜色列表 colors 转换为 PyTorch 张量,并将其归一化到 [0, 1] 范围内。
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
# 然后通过增加两个维度,将其形状从 (n, 3) 转换为 (n, 1, 1, 3) ,以便后续与掩码进行逐元素相乘。
colors = colors[:, None, None] # shape(n,1,1,3)
# 将掩码 masks 的形状从 (n, h, w) 转换为 (n, h, w, 1) ,以便后续与颜色张量进行逐元素相乘。
masks = masks.unsqueeze(3) # shape(n,h,w,1)
# 将 每个掩码 与 对应的颜色 相乘,并应用透明度 alpha ,得到 带有颜色的掩码 。结果的形状为 (n, h, w, 3) 。
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
# torch.cumprod(input, dim, *, out=None) -> Tensor
# torch.cumprod() 是 PyTorch 库中的一个函数,用于计算张量(tensor)的累积乘积(cumulative product)。这个函数会返回一个新的张量,其中每个元素是输入张量中到当前位置为止的所有元素的乘积。
# 参数 :
# input :输入的张量。
# dim :沿着哪个维度计算累积乘积。
# out :(可选)输出张量,用于存储结果。
# 返回值 :
# 返回一个新的张量,包含了输入张量沿着指定维度的累积乘积。
# torch.cumprod() 函数在处理序列数据或者需要累积计算的场景中非常有用,例如在某些类型的递归神经网络(RNN)或者特定的统计计算中。
# 计算每个掩码的 逆透明度累积乘积 ,用于 在多个掩码叠加时保持背景图像的可见性 。 inv_alpha_masks 的形状为 (n, h, w, 1) 。
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
# 计算 所有掩码颜色的最大值 ,以确保在多个掩码叠加时,每个像素点的颜色是所有掩码中最强的颜色。结果的形状为 (h, w, 3) 。
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
# 这段代码是 masks 方法的一部分,用于将融合后的掩码图像与原始图像结合,并根据需要调整图像格式和大小。
# torch.flip(input, dims)
# torch.flip() 是 PyTorch 中的一个函数,用于沿指定维度翻转张量(tensor)。这个函数可以对张量进行类似于 NumPy 中 np.flip() 的操作。
# 参数 :
# input ( Tensor ) : 需要被翻转的张量。
# dims ( int 或 tuple 的 int ): 指定要翻转的维度。如果是单个整数,则只沿该维度翻转;如果是元组,则沿这些维度翻转。
# 返回值 :
# 返回一个新的张量,它是输入张量 input 沿指定维度翻转的结果。
# im_gpu = im_gpu.flip(dims=[0]) # flip channel
# im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
# 示例 :
# 假设有一个形状为 (3, 2, 2) 的张量 im_gpu ,其值如下(假设是 RGB 格式) :
# im_gpu = torch.tensor([
# [[1, 2], [3, 4]], # R通道
# [[5, 6], [7, 8]], # G通道
# [[9, 10], [11, 12]] # B通道
# ])
# 第一行代码 : im_gpu.flip(dims=[0]) 。
# 翻转后的张量(BGR 格式) :
# torch.tensor([
# [[9, 10], [11, 12]], # B通道
# [[5, 6], [7, 8]], # G通道
# [[1, 2], [3, 4]] # R通道
# ])
# 第二行代码 : im_gpu.permute(1, 2, 0).contiguous() 。
# 调整维度后的张量(形状为 (h, w, 3) ,通道顺序为 BGR) :
# torch.tensor([
# [[9, 5, 1], [10, 6, 2]],
# [[11, 7, 3], [12, 8, 4]]
# ])
# 总结 :
# 第一行代码 :将通道顺序从 RGB 翻转为 BGR,但形状仍然是 (3, h, w) 。
# 第二行代码 :将形状从 (3, h, w) 调整为 (h, w, 3) ,通道顺序仍然是 BGR。
# 这行代码的作用是将通道顺序翻转。具体来说 :
# 如果输入图像 im_gpu 的通道顺序是 RGB(即形状为 (3, h, w) ,其中第一个维度表示 RGB 三个通道),那么 flip(dims=[0]) 会将其翻转为 BGR 顺序。
# 这是因为 flip(dims=[0]) 是沿着第一个维度(通道维度)进行翻转,将 [0, 1, 2] 翻转为 [2, 1, 0] 。
# 注意 :这行代码不会改变张量的形状,形状仍然是 (3, h, w) ,只是通道的顺序发生了变化。
im_gpu = im_gpu.flip(dims=[0]) # flip channel
# 这行代码的作用是重新排列张量的维度,将形状从 (3, h, w) 调整为 (h, w, 3) 。
# 具体来说, permute(1, 2, 0) 将第一个维度(通道)移到最后,而将原来的高和宽维度移到前面。 contiguous() 确保张量在内存中是连续的,这对于后续操作(如转换为 NumPy 数组)是必要的。
# 综合效果 :
# 第一行代码:将通道顺序从 RGB 翻转为 BGR,但形状仍然是 (3, h, w) 。
# 第二行代码:将形状从 (3, h, w) 调整为 (h, w, 3) ,通道顺序仍然是 BGR。
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
# 将 原始图像 与 掩码颜色 进行融合。 inv_alpha_masks[-1] 是最后一个掩码的逆透明度累积乘积,用于保持背景图像的可见性。 mcs 是所有掩码颜色的最大值,表示每个像素点的最终颜色。 通过 im_gpu * inv_alpha_masks[-1] + mcs ,将原始图像与掩码颜色融合。
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
# 将融合后的图像 im_gpu 的像素值从 [0, 1] 范围恢复到 [0, 255] 范围,以便后续保存为图像文件。
im_mask = im_gpu * 255
# 将融合后的图像 im_mask 转换为 NumPy 数组。 byte() 将张量转换为无符号 8 位整数( uint8 )。 cpu() 将张量从 GPU 移动到 CPU。 numpy() 将张量转换为 NumPy 数组。
im_mask_np = im_mask.byte().cpu().numpy()
# 根据 是否使用高分辨率掩码 ( retina_masks ),将融合后的图像赋值给 self.im 。
# 如果 retina_masks=True ,直接使用融合后的图像 im_mask_np 。
# 如果 retina_masks=False ,使用 ops.scale_image 方法将融合后的图像缩放到 self.im 的原始形状。
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
# 如果当前使用的是 PIL 模式,则将 NumPy 数组转换回 PIL 图像,并更新 self.draw 对象。
if self.pil:
# Convert im back to PIL and update draw
# 将 NumPy 数组转换为 PIL 图像,并更新 self.im 和 self.draw 。
self.fromarray(self.im)
# 这段代码的主要功能是将融合后的掩码图像与原始图像结合,并根据需要调整图像格式和大小。具体步骤如下。通道翻转:将原始图像的通道顺序从 CHW 翻转为 HWC 。融合图像:将原始图像与掩码颜色进行融合。恢复像素值:将融合后的图像的像素值从 [0, 1] 范围恢复到 [0, 255] 范围。转换为 NumPy 数组:将融合后的图像转换为 NumPy 数组。调整图像大小:根据是否使用高分辨率掩码,调整图像大小。转换回 PIL:如果使用 PIL 模式,将 NumPy 数组转换回 PIL 图像,并更新绘图对象。这种方法适用于在图像上绘制和融合多个掩码,尤其是在需要突出显示某些区域时。支持高分辨率掩码,使其更加灵活。
# 这个方法的主要功能是在图像上绘制预测的掩码,并将掩码与原始图像进行融合。具体步骤如下。处理 PIL 模式:如果使用 PIL 模式,将 PIL 图像转换为 NumPy 数组。处理空掩码:如果没有掩码,则直接使用原始图像。设备同步:确保原始图像和掩码在同一个设备上(CPU 或 GPU)。颜色和掩码处理:将颜色归一化并调整形状。将掩码调整形状。将掩码与颜色相乘,并应用透明度。逆透明度累积乘积:计算逆透明度累积乘积,用于保持背景图像的可见性。融合图像:将原始图像与掩码颜色进行融合。恢复像素值:将融合后的图像恢复为原始像素值。缩放图像:如果需要,将融合后的图像缩放到原始图像的形状。转换回 PIL:如果使用 PIL 模式,将 NumPy 数组转换回 PIL 图像。这种方法适用于在图像上绘制和融合多个掩码,尤其是在需要突出显示某些区域时。支持高分辨率掩码,使其更加灵活。
# 这段代码定义了 Annotator 类中的一个方法 kpts ,用于在图像上绘制关键点(keypoints)和关键点之间的连接线(skeleton)。
# 定义了一个方法 kpts ,接受以下参数 :
# 1.kpts :关键点数据,形状为 (n, 2) 或 (n, 3) ,其中 n 是关键点的数量,每个关键点包含 (x, y) 坐标,可选的第三个值为置信度。
# 2.shape :图像的形状 (h, w) ,默认为 (640, 640) 。
# 3.radius :关键点的绘制半径,默认为 None ,如果未指定,则使用 self.lw 。
# 4.kpt_line :是否绘制关键点之间的连接线,默认为 True 。
# 5.conf_thres :关键点的置信度阈值,默认为 0.25 ,低于此阈值的关键点将不被绘制。
# 6.kpt_color :关键点的颜色,默认为 None ,如果未指定,则使用 self.kpt_color 或默认颜色。
def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
"""
Plot keypoints on the image.
Args:
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 5.
kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
Note:
- `kpt_line=True` currently only supports human pose plotting.
- Modifies self.im in-place.
- If self.pil is True, converts image to numpy array and back to PIL.
"""
# 这段代码是 kpts 方法的一部分,用于在图像上绘制关键点。
# 设置关键点的绘制半径。 如果 radius 参数已指定,则使用该值。 如果未指定,则使用 self.lw (线条宽度)作为默认值。
radius = radius if radius is not None else self.lw
# 如果当前使用的是 PIL 模式。
if self.pil:
# Convert to numpy first
# 则将 PIL 图像转换为 NumPy 数组,以便后续操作。这是因为 OpenCV 操作通常需要 NumPy 数组作为输入。
self.im = np.asarray(self.im).copy()
# 获取关键点数据的形状。 nkpt 是关键点的数量。 ndim 是每个关键点的维度(通常是 2 或 3,表示 (x, y) 坐标,可选的第三个值为置信度)。
nkpt, ndim = kpts.shape
# 检查 是否为人体姿态关键点 。 如果关键点数量为 17 且每个关键点的维度为 2 或 3,则认为是人体姿态关键点。
is_pose = nkpt == 17 and ndim in {2, 3}
# 如果 kpt_line 为 True ,则只有在当前绘制的是人体姿态关键点时, 才会绘制关键点之间的连接线 。
kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
# 遍历每个关键点 k 。
for i, k in enumerate(kpts):
# 选择关键点的颜色。 如果指定了 kpt_color ,则使用 kpt_color 。 如果未指定且当前绘制的是人体姿态关键点,则使用 self.kpt_color[i] 。 否则,使用默认颜色 colors(i) 。
color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
# 获取 关键点的坐标 (x_coord, y_coord) 。
x_coord, y_coord = k[0], k[1]
# 检查关键点是否在图像范围内。 x_coord % shape[1] != 0 确保 x 坐标不超出图像宽度。 y_coord % shape[0] != 0 确保 y 坐标不超出图像高度。
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
# 如果关键点有置信度( len(k) == 3 ),且置信度低于阈值 conf_thres ,则跳过该关键点。
if len(k) == 3:
conf = k[2]
if conf < conf_thres:
continue
# 使用 OpenCV 的 circle 方法绘制关键点。 self.im 是目标图像。 (int(x_coord), int(y_coord)) 是关键点的坐标。 radius 是关键点的半径。 color_k 是关键点的颜色。 -1 表示填充圆形。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
# 这段代码的主要功能是在图像上绘制关键点,具体步骤如下。设置关键点半径:如果未指定 radius ,则使用 self.lw 。处理 PIL 模式:如果使用 PIL 模式,将 PIL 图像转换为 NumPy 数组。检查关键点数量和维度:确定是否为人体姿态关键点。遍历关键点:选择关键点的颜色。获取关键点的坐标。检查关键点是否在图像范围内。如果关键点有置信度,且置信度低于阈值,则跳过该关键点。使用 OpenCV 的 circle 方法绘制关键点。这种方法适用于在图像中标注关键点,尤其是在需要突出显示人体姿态或其他关键点信息时。
# 这段代码是 kpts 方法的一部分,用于在图像上绘制关键点之间的连接线(skeleton)。
# 如果 kpt_line 为 True ,则绘制关键点之间的连接线。 kpt_line 是一个布尔参数,用于控制是否绘制连接线。
if kpt_line:
# 获取 每个关键点的维度 ndim 。 kpts.shape[-1] 表示每个关键点的最后一个维度的大小,通常是 2( (x, y) 坐标)或 3( (x, y, confidence) )。
ndim = kpts.shape[-1]
# 遍历 self.skeleton ,这是一个包含关键点连接关系的列表。每个元素 sk 是一个包含两个关键点索引的列表,表示这两个关键点之间有一条连接线。
for i, sk in enumerate(self.skeleton):
# 获取连接线的两个关键点的坐标。
# pos1 是第一个关键点的坐标。
# sk[0] - 1 和 sk[1] - 1 是关键点的索引,减 1 是因为索引从 0 开始。
pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
# pos2 是第二个关键点的坐标。
pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
# 如果每个关键点有置信度( ndim == 3 )。
if ndim == 3:
conf1 = kpts[(sk[0] - 1), 2]
conf2 = kpts[(sk[1] - 1), 2]
# 则检查两个关键点的置信度是否低于阈值 conf_thres 。
if conf1 < conf_thres or conf2 < conf_thres:
# 如果任何一个关键点的置信度低于阈值,则跳过该连接线。
continue
# 检查两个关键点的坐标是否在图像范围内。
# pos1[0] % shape[1] == 0 和 pos1[1] % shape[0] == 0 检查坐标是否超出图像边界。
# pos1[0] < 0 和 pos1[1] < 0 检查坐标是否为负值。
# 如果任何一个关键点的坐标超出图像范围,则跳过该连接线。
if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
continue
if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
continue
# 使用 OpenCV 的 line 方法绘制连接线。
cv2.line(
# 目标图像。
self.im,
# pos1 和 pos2 是连接线的两个端点。
pos1,
pos2,
# 连接线的颜色,如果未指定,则使用 self.limb_color[i] 。
kpt_color or self.limb_color[i].tolist(),
# 连接线的宽度,设置为 int(np.ceil(self.lw / 2)) 。
thickness=int(np.ceil(self.lw / 2)),
# 表示使用抗锯齿线条。
lineType=cv2.LINE_AA,
)
# 如果当前使用的是 PIL 模式,则将 NumPy 数组转换回 PIL 图像,并更新 self.draw 对象。
if self.pil:
# Convert im back to PIL and update draw
self.fromarray(self.im)
# 这段代码的主要功能是在图像上绘制关键点之间的连接线,具体步骤如下。检查是否绘制连接线:如果 kpt_line 为 True ,则继续绘制连接线。获取关键点维度:确定每个关键点是否有置信度。遍历连接关系:遍历 self.skeleton ,获取每条连接线的两个关键点的坐标。检查置信度:如果关键点有置信度,且置信度低于阈值,则跳过该连接线。检查坐标范围:如果关键点的坐标超出图像范围,则跳过该连接线。绘制连接线:使用 OpenCV 的 line 方法绘制连接线。转换回 PIL:如果使用 PIL 模式,将 NumPy 数组转换回 PIL 图像,并更新绘图对象。这种方法适用于在图像中标注关键点之间的连接线,尤其是在需要突出显示人体姿态或其他关键点信息时。
# 这个方法的主要功能是在图像上绘制关键点和关键点之间的连接线,具体步骤如下。设置关键点半径:如果未指定 radius ,则使用 self.lw 。处理 PIL 模式:如果使用 PIL 模式,将 PIL 图像转换为 NumPy 数组。检查关键点数量和维度:确定是否为人体姿态关键点。绘制关键点:遍历每个关键点。选择关键点的颜色。检查关键点是否在图像范围内。如果关键点有置信度,且置信度低于阈值,则跳过该关键点。使用 OpenCV 的 circle 方法绘制关键点。绘制连接线:遍历每个连接线。获取两个关键点的坐标。如果关键点有置信度,且置信度低于阈值,则跳过该连接线。检查两个关键点是否在图像范围内。使用 OpenCV 的 line 方法绘制连接线。转换回 PIL:如果使用 PIL 模式,将 NumPy 数组转换回 PIL 图像,并更新绘图对象。这种方法适用于在图像中标注关键点和关键点之间的连接线,尤其是在需要突出显示人体姿态或其他关键点信息时。
# 这段代码定义了 Annotator 类中的一个方法 rectangle ,用于在 PIL 模式下向图像中添加一个矩形。
# 定义了一个方法 rectangle ,接受以下参数 :
# 1.xy :矩形的坐标,格式为 (x1, y1, x2, y2) ,其中 (x1, y1) 是矩形的左上角坐标, (x2, y2) 是矩形的右下角坐标。
# 2.fill :矩形的填充颜色,默认为 None ,表示不填充。
# 3.outline :矩形的轮廓颜色,默认为 None ,表示不绘制轮廓。
# 4.width :轮廓的宽度,默认为 1。
def rectangle(self, xy, fill=None, outline=None, width=1):
# 向图像添加矩形(仅限 PIL)。
"""Add rectangle to image (PIL-only)."""
# 使用 PIL 的 ImageDraw.Draw.rectangle 方法在图像上绘制矩形。
# self.draw 是一个 ImageDraw.Draw 对象,用于在 PIL 图像上进行绘制。
# xy 是矩形的坐标。
# fill 是矩形的填充颜色。
# outline 是矩形的轮廓颜色。
# width 是轮廓的宽度。
self.draw.rectangle(xy, fill, outline, width)
# 这个方法的主要功能是在 PIL 模式下向图像中添加一个矩形,具体步骤如下。参数解析: xy :矩形的坐标。 fill :填充颜色。 outline :轮廓颜色。 width :轮廓宽度。绘制矩形:使用 PIL 的 ImageDraw.Draw.rectangle 方法在图像上绘制矩形。这种方法适用于在图像中标注矩形区域,尤其是在需要突出显示某些区域时。由于该方法仅在 PIL 模式下有效,因此在使用时需要确保 self.pil 为 True 。
# 这段代码定义了 Annotator 类中的一个方法 text ,用于在图像上绘制文本。该方法支持在 PIL 模式和 OpenCV 模式下绘制文本,并且可以选择是否为文本添加背景框。
# 定义了一个方法 text ,接受以下参数 :
# 1.xy :文本的起始坐标 (x, y) 。
# 2.text :要绘制的文本内容。
# 3.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
# 4.anchor :文本的锚点位置,默认为 "top" ,表示文本从顶部开始绘制。如果设置为 "bottom" ,则从底部开始绘制。
# 5.box_style :是否为文本添加背景框,默认为 False 。
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
# 使用 PIL 或 cv2 向图像添加文本。
"""Adds text to an image using PIL or cv2."""
# 如果 anchor 设置为 "bottom" ,则调整 y 坐标,使文本从底部开始绘制。
if anchor == "bottom": # start y from font bottom
# 使用 self.font.getsize 获取文本的 宽度 和 高度 。
w, h = self.font.getsize(text) # text width, height
# 调整 y 坐标, xy[1] += 1 - h ,使文本的底部与指定的 y 坐标对齐。
# 这行代码的作用是调整文本的垂直位置,使得文本的底部与指定的 y 坐标对齐。
# 示例说明 :
# 假设有一个文本 "Hello" ,需要将其绘制在图像上,且文本的起始坐标为 (x, y) 。希望文本的底部与 y 坐标对齐,而不是文本的顶部。
# 初始设置 :
# 文本内容 : "Hello" 。
# 初始坐标 : xy = [100, 150] (即 x = 100 , y = 150 ) 。
# 文本高度 :假设通过 self.font.getsize("Hello") 获取到的文本高度为 h = 20 。
# 计算调整后的 y 坐标需要将文本的底部与 y = 150 对齐。文本的高度为 20 ,因此文本的底部实际上位于 y + h = 150 + 20 = 170 。
# 为了使文本的底部与 y = 150 对齐,需要将 y 坐标向上调整 h 个像素。具体计算如下 :
# xy[1] += 1 - h
# xy[1] 是当前的 y 坐标,即 150 。
# h 是文本的高度,即 20 。
# 1 - h 的值为 1 - 20 = -19 。因此,调整后的 y 坐标为 :
# xy[1] = 150 + (-19) = 131
# 结果 :调整后的坐标为 (100, 131) 。这样,文本 "Hello" 的底部将与原始的 y = 150 对齐。
xy[1] += 1 - h
# 这段代码是 text 方法的一部分,用于在 PIL 模式下绘制文本,并根据需要为文本添加背景框。
# 如果当前使用的是 PIL 模式,则进入以下逻辑。
if self.pil:
# 如果 box_style 为 True ,则为文本添加背景框。
if box_style:
# 使用 self.font.getsize 方法获取文本的 宽度 w 和 高度 h 。
w, h = self.font.getsize(text)
# 绘制背景框。
# 背景框的左上角坐标为 (xy[0], xy[1]) 。
# 背景框的右下角坐标为 (xy[0] + w + 1, xy[1] + h + 1) ,其中 +1 是为了确保文本不会超出背景框。
# 使用 fill=txt_color 填充背景框。
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
# 将文本颜色设置为白色 (255, 255, 255) ,以确保文本在背景框上清晰可见。
txt_color = (255, 255, 255)
# 如果文本中包含换行符 \n ,则逐行绘制文本。
if "\n" in text:
# 将文本按换行符分割为多行。
lines = text.split("\n")
# 获取 文本的高度 h ,用于计算每行文本的间距。
_, h = self.font.getsize(text)
# 遍历每一行文本。
for line in lines:
# 使用 self.draw.text 方法绘制当前行的文本。
self.draw.text(xy, line, fill=txt_color, font=self.font)
# 每绘制一行后,将 y 坐标 xy[1] 增加 h ,以便下一行文本在正确的位置绘制。
xy[1] += h
# 如果文本中不包含换行符,则直接绘制文本。
else:
# 使用 self.draw.text 方法绘制文本。 xy 是文本的起始坐标。 text 是要绘制的文本内容。 fill=txt_color 是文本颜色。 font=self.font 是字体对象。
self.draw.text(xy, text, fill=txt_color, font=self.font)
# 这段代码的主要功能是在 PIL 模式下绘制文本,并根据需要为文本添加背景框。具体步骤如下。检查是否使用 PIL 模式:如果使用 PIL 模式,则进入后续逻辑。添加背景框:如果 box_style 为 True ,则计算文本尺寸并绘制背景框。将文本颜色设置为白色,以确保文本在背景框上清晰可见。处理多行文本:如果文本中包含换行符,则逐行绘制文本。每绘制一行后,更新 y 坐标,以便下一行文本在正确的位置绘制。直接绘制文本:如果文本中不包含换行符,则直接绘制文本。这种方法适用于在图像中标注多行文本,并为文本添加背景框,以提高可读性。
# 这段代码是 text 方法的一部分,用于在 OpenCV 模式下绘制文本,并根据需要为文本添加背景框。
# 如果当前使用的是 OpenCV 模式,则进入以下逻辑。
else:
# 如果 box_style 为 True ,则为文本添加背景框。
if box_style:
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度 w 和高度 h 。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 self.tf 是字体厚度。 cv2.getTextSize 返回一个元组,其中第一个元素是 (w, h) 。
w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
# 增加一些像素以填充文本,确保文本不会紧贴背景框的边缘。
h += 3 # add pixels to pad text
# 判断文本是否应该绘制在指定位置的外部。 如果 xy[1] (文本的起始 y 坐标)大于或等于文本高度 h ,则认为文本应该绘制在外部。
outside = xy[1] >= h # label fits outside box
# 计算 背景框的右下角坐标 p2 。如果 outside 为 True ,则背景框的右下角坐标为 (xy[0] + w, xy[1] - h) 。 如果 outside 为 False ,则背景框的右下角坐标为 (xy[0] + w, xy[1] + h) 。
p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
# 使用 OpenCV 的 rectangle 方法绘制背景框。 self.im 是目标图像。 xy 是背景框的左上角坐标。 p2 是背景框的右下角坐标。 txt_color 是背景框的颜色。 -1 表示填充矩形。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
# Using `txt_color` for background and draw fg with white color
# 将文本颜色设置为白色 (255, 255, 255) ,以确保文本在背景框上清晰可见。
txt_color = (255, 255, 255)
# 使用 OpenCV 的 putText 方法绘制文本。 self.im 是目标图像。 text 是要绘制的文本内容。 xy 是文本的起始坐标。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 txt_color 是文本颜色。 self.tf 是字体厚度。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
# 这段代码的主要功能是在 OpenCV 模式下绘制文本,并根据需要为文本添加背景框。具体步骤如下。检查是否使用 OpenCV 模式:如果使用 OpenCV 模式,则进入后续逻辑。添加背景框:如果 box_style 为 True ,则计算文本尺寸并绘制背景框。将文本颜色设置为白色,以确保文本在背景框上清晰可见。绘制文本:使用 OpenCV 的 putText 方法绘制文本。这种方法适用于在图像中标注文本信息,并为文本添加背景框,以提高可读性。
# 这个方法的主要功能是在图像上绘制文本,并根据需要为文本添加背景框。具体步骤如下。调整文本位置:根据 anchor 参数调整文本的起始位置。处理 PIL 模式:如果 box_style 为 True ,则为文本添加背景框。如果文本包含换行符,则逐行绘制文本。否则,直接绘制文本。处理 OpenCV 模式:如果 box_style 为 True ,则为文本添加背景框。使用 OpenCV 的 putText 方法绘制文本。这种方法适用于在图像中标注文本信息,尤其是在需要突出显示某些关键信息时。支持多行文本和背景框,使其更加灵活。
# 这段代码定义了 Annotator 类中的一个方法 fromarray ,用于将 NumPy 数组转换为 PIL 图像,并更新绘图对象。
# 定义了一个方法 fromarray ,接受一个参数。
# 1.im :该参数可以是 NumPy 数组或 PIL 图像。
def fromarray(self, im):
# "从 numpy 数组更新 self.im。
"""Update self.im from a numpy array."""
# 检查传入的 im 是否已经是 PIL 图像( isinstance(im, Image.Image) )。 如果 im 已经是 PIL 图像,则直接将其赋值给 self.im 。 如果 im 是 NumPy 数组,则使用 Image.fromarray 方法将其转换为 PIL 图像,并赋值给 self.im 。
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
# 创建一个 ImageDraw.Draw 对象,用于在 PIL 图像上进行绘制。 将该对象赋值给 self.draw ,以便后续可以使用它进行绘图操作。
self.draw = ImageDraw.Draw(self.im)
# 这个方法的主要功能是将 NumPy 数组转换为 PIL 图像,并更新绘图对象。具体步骤如下。检查输入类型:判断传入的 im 是否已经是 PIL 图像。转换为 PIL 图像:如果 im 是 NumPy 数组,则使用 Image.fromarray 方法将其转换为 PIL 图像。如果 im 已经是 PIL 图像,则直接使用。更新绘图对象:创建一个新的 ImageDraw.Draw 对象,并将其赋值给 self.draw 。这种方法适用于在图像处理过程中需要将 NumPy 数组转换为 PIL 图像的情况,尤其是在需要使用 PIL 的绘图功能时。
# 这段代码定义了 Annotator 类中的一个方法 result ,用于将当前的 PIL 图像转换为 NumPy 数组,并返回该数组。
# 定义了一个方法 result ,该方法不接受任何参数。
def result(self):
# 以数组形式返回带注释的图像。
"""Return annotated image as array."""
# 使用 np.asarray 方法将 self.im (当前的 PIL 图像)转换为 NumPy 数组。 返回转换后的 NumPy 数组。
return np.asarray(self.im)
# 这个方法的主要功能是将当前的 PIL 图像转换为 NumPy 数组。具体步骤如下。转换图像:使用 np.asarray 方法将 PIL 图像 self.im 转换为 NumPy 数组。返回结果:返回转换后的 NumPy 数组。这种方法适用于在需要将 PIL 图像转换为 NumPy 数组进行进一步处理的场景,例如在图像处理或机器学习任务中。
# 这段代码定义了 Annotator 类中的一个方法 show ,用于显示当前的图像。该方法支持在不同的环境中显示图像,包括 Google Colab、Kaggle 和本地环境。
# 定义了一个方法 show ,接受一个可选参数。
# 1.title :用于设置显示图像的标题。
def show(self, title=None):
# 显示带注释的图像。
"""Show the annotated image."""
# 将 self.im (当前的图像)转换为 NumPy 数组。 使用 [..., ::-1] 将 RGB 格式转换为 BGR 格式(因为 OpenCV 默认使用 BGR 格式)。 使用 Image.fromarray 将 NumPy 数组转换回 PIL 图像。
im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
# 检查当前环境是否为 Google Colab 或 Kaggle。这些环境通常不支持直接调用 im.show() ,因此需要使用 display 函数。
if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
# 在 Colab 或 Kaggle 环境中,尝试使用 display 函数显示图像。
try:
display(im) # noqa - display() function only available in ipython environments
# 如果 display 函数不可用(例如在非 Jupyter 环境中),捕获 ImportError 异常,并记录警告信息。
except ImportError as e:
LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}") # 无法在 Jupyter 笔记本中显示图像:{e}。
# 如果当前环境不是 Colab 或 Kaggle,则直接使用 PIL 的 show 方法显示图像。 title 参数用于设置显示窗口的标题。
else:
im.show(title=title)
# 这个方法的主要功能是显示当前的图像,具体步骤如下。转换图像:将当前图像 self.im 转换为 NumPy 数组。将 RGB 格式转换为 BGR 格式。将 NumPy 数组转换回 PIL 图像。检查环境:如果当前环境是 Google Colab 或 Kaggle,使用 display 函数显示图像。如果 display 函数不可用,记录警告信息。显示图像:如果当前环境不是 Colab 或 Kaggle,直接使用 PIL 的 show 方法显示图像。这种方法适用于在不同的环境中显示图像,尤其是在需要兼容多种运行环境时。
# 这段代码定义了 Annotator 类中的一个方法 save ,用于将当前的图像保存到指定的文件路径。
# 定义了一个方法 save ,接受一个可选参数。
# 1.filename :用于指定保存图像的文件名。默认值为 "image.jpg" 。
def save(self, filename="image.jpg"):
# 将带注释的图像保存至‘文件名’。
"""Save the annotated image to 'filename'."""
# 将 self.im (当前的图像)转换为 NumPy 数组。
# 使用 OpenCV 的 cv2.imwrite 方法将 NumPy 数组保存到指定的文件路径 filename 。
# cv2.imwrite 方法会根据文件扩展名自动选择合适的格式保存图像(例如 .jpg 、 .png 等)。
cv2.imwrite(filename, np.asarray(self.im))
# 这个方法的主要功能是将当前的图像保存到指定的文件路径。具体步骤如下。转换图像:将当前图像 self.im 转换为 NumPy 数组。保存图像:使用 OpenCV 的 cv2.imwrite 方法将 NumPy 数组保存到指定的文件路径。这种方法适用于在处理图像后需要将结果保存为文件的场景。使用 OpenCV 的 cv2.imwrite 方法可以方便地保存图像,并且支持多种图像格式。
# 这段代码定义了 Annotator 类中的一个静态方法 get_bbox_dimension ,用于计算边界框的宽度、高度和面积。
@staticmethod
# 定义了一个静态方法 get_bbox_dimension ,接受一个可选参数。
# 1.bbox :该参数是一个包含边界框坐标的元组或列表,格式为 (x_min, y_min, x_max, y_max) 。如果未指定 bbox ,则默认为 None 。
def get_bbox_dimension(bbox=None):
# 计算边界框的面积。
"""
Calculate the area of a bounding box.
Args:
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
Returns:
width (float): Width of the bounding box.
height (float): Height of the bounding box.
area (float): Area enclosed by the bounding box.
"""
# 从 bbox 中解包边界框的坐标。
# x_min :边界框的左上角 x 坐标。
# y_min :边界框的左上角 y 坐标。
# x_max :边界框的右下角 x 坐标。
# y_max :边界框的右下角 y 坐标。
x_min, y_min, x_max, y_max = bbox
# 计算 边界框的宽度 。 宽度等于右下角 x 坐标减去左上角 x 坐标。
width = x_max - x_min
# 计算 边界框的高度 。 高度等于右下角 y 坐标减去左上角 y 坐标。
height = y_max - y_min
# 返回 边界框的宽度 、 高度 和 面积 。 面积等于宽度乘以高度。
return width, height, width * height
# 这个静态方法的主要功能是计算边界框的宽度、高度和面积。具体步骤如下。解包边界框坐标:从 bbox 中提取左上角和右下角的坐标。计算宽度和高度:宽度等于右下角 x 坐标减去左上角 x 坐标。高度等于右下角 y 坐标减去左上角 y 坐标。计算面积:面积等于宽度乘以高度。返回结果:返回宽度、高度和面积。这种方法适用于在图像处理或目标检测任务中需要计算边界框尺寸的场景。使用静态方法可以方便地调用该功能,而无需创建类的实例。
# 这段代码定义了 Annotator 类中的一个方法 draw_region ,用于在图像上绘制一个多边形区域,并在多边形的每个顶点处绘制小圆圈。
# 定义了一个方法 draw_region ,接受以下参数 :
# 1.reg_pts :多边形区域的顶点坐标,格式为一个列表,每个元素是一个包含两个整数的元组 (x, y) 。如果未指定,则默认为 None 。
# 2.color :多边形和顶点圆圈的颜色,默认为绿色 (0, 255, 0) 。
# 3.thickness :多边形边的厚度,默认为 5。
def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
# 绘制区域线。
"""
Draw region line.
Args:
reg_pts (list): Region Points (for line 2 points, for region 4 points)
color (tuple): Region Color value
thickness (int): Region area thickness value
"""
# 将 reg_pts 转换为 NumPy 数组,并确保其数据类型为 np.int32 。
# 使用 OpenCV 的 polylines 方法绘制多边形。
# self.im 是目标图像。
# [np.array(reg_pts, dtype=np.int32)] 是一个包含多边形顶点的列表。
# isClosed=True 表示多边形是闭合的。
# color 是多边形的颜色。
# thickness 是多边形边的厚度。
cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
# Draw small circles at the corner points
# 遍历每个顶点 point 。
for point in reg_pts:
# 使用 OpenCV 的 circle 方法在每个顶点处绘制一个小圆圈。 self.im 是目标图像。 (point[0], point[1]) 是圆圈的中心坐标。 thickness * 2 是圆圈的半径,这里设置为边厚度的两倍。 color 是圆圈的颜色。 -1 表示填充圆圈。
cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle
# 这个方法的主要功能是在图像上绘制一个多边形区域,并在多边形的每个顶点处绘制小圆圈。具体步骤如下。绘制多边形:将顶点坐标转换为 NumPy 数组。使用 OpenCV 的 polylines 方法绘制多边形。绘制顶点圆圈:遍历每个顶点。使用 OpenCV 的 circle 方法在每个顶点处绘制一个小圆圈。这种方法适用于在图像中标注多边形区域,尤其是在需要突出显示区域的顶点时。
# 这段代码定义了 Annotator 类中的一个方法 draw_centroid_and_tracks ,用于在图像上绘制轨迹(tracks)和轨迹的最后一个点(centroid)。
# 定义了一个方法 draw_centroid_and_tracks ,接受以下参数 :
# 1.track :轨迹点的列表,每个轨迹点是一个包含两个浮点数的列表或元组,表示 (x, y) 坐标。
# 2.color :轨迹和中心点的颜色,默认为紫色 (255, 0, 255) 。
# 3.track_thickness :轨迹线的厚度,默认为 2。
def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
# 绘制质心点和轨迹。
"""
Draw centroid point and track trails.
Args:
track (list): object tracking points for trails display
color (tuple): tracks line color
track_thickness (int): track line thickness value
"""
# np.hstack(tup)
# np.hstack() 是 NumPy 库中的一个函数,用于水平(按列顺序)堆叠数组。
# 参数 :
# tup :一个元组或列表,包含要水平堆叠的数组。这些数组必须有相同的形状,除了第二维(列)。
# 返回值 :
# 返回一个数组,它是输入数组水平堆叠的结果。
# 说明 :
# np.hstack() 函数将多个数组水平(沿着第二维)堆叠起来。这意味着所有输入数组的第一维(行)必须相同,而第二维(列)可以不同。
# 如果输入数组的维度大于2,那么除了第一维和第二维之外,其他维度的大小必须相同。
# 该函数常用于将具有相同行数的多个数组合并为一个更宽的数组。
# 将 track 中的所有点合并为一个 NumPy 数组。
# 使用 astype(np.int32) 将坐标转换为整数类型,因为 OpenCV 要求坐标为整数。
# 使用 reshape((-1, 1, 2)) 将数组重新调整为形状 (n, 1, 2) ,其中 n 是轨迹点的数量。这是 cv2.polylines 方法所需的格式。
# n :表示轨迹点的数量。例如,如果 track 包含 5 个点,那么 n 就是 5。
# 1 :这是一个额外的维度,用于表示每个点是一个单独的点。在 OpenCV 中, cv2.polylines 方法要求输入的点数组的形状为 (n, 1, 2) ,其中 n 是点的数量。这个额外的维度是为了满足 OpenCV 的要求。
# 2 :表示每个点有 2 个坐标值,即 x 和 y 。
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
# 使用 OpenCV 的 polylines 方法绘制轨迹。
# self.im 是目标图像。
# [points] 是包含轨迹点的列表。
# isClosed=False 表示轨迹不是闭合的。
# color 是轨迹的颜色。
# track_thickness 是轨迹线的厚度。
cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
# 使用 OpenCV 的 circle 方法在轨迹的最后一个点(centroid)处绘制一个小圆圈。
# self.im 是目标图像。
# (int(track[-1][0]), int(track[-1][1])) 是最后一个轨迹点的坐标。
# track_thickness * 2 是圆圈的半径,这里设置为轨迹厚度的两倍。
# color 是圆圈的颜色。
# -1 表示填充圆圈。
cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
# 这个方法的主要功能是在图像上绘制轨迹和轨迹的最后一个点(centroid)。具体步骤如下。转换轨迹点:将轨迹点合并为一个 NumPy 数组,并转换为整数类型。将数组调整为形状 (n, 1, 2) ,以满足 cv2.polylines 方法的要求。绘制轨迹:使用 OpenCV 的 polylines 方法绘制轨迹。绘制最后一个点:使用 OpenCV 的 circle 方法在轨迹的最后一个点处绘制一个小圆圈。这种方法适用于在图像中标注轨迹和轨迹的最后一个点,尤其是在需要突出显示对象的运动轨迹时。
# 这段代码定义了 Annotator 类中的一个方法 queue_counts_display ,用于在图像上显示队列计数信息(例如人数计数)。该方法会在指定的区域中心显示文本标签,并为文本添加一个背景矩形以提高可读性。
# 定义了一个方法 queue_counts_display ,接受以下参数 :
# 1.label :要显示的文本标签,例如队列计数信息。
# 2.points :一个包含点坐标的列表,每个点是一个包含两个整数的元组 (x, y) ,用于计算显示文本的中心位置。
# 3.region_color :背景矩形的颜色,默认为白色 (255, 255, 255) 。
# 4.txt_color :文本颜色,默认为黑色 (0, 0, 0) 。
def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):
# 在以点为中心的图像上显示队列计数,字体大小和颜色可自定义。
"""
Displays queue counts on an image centered at the points with customizable font size and colors.
Args:
label (str): Queue counts label.
points (tuple): Region points for center point calculation to display text.
region_color (tuple): RGB queue region color.
txt_color (tuple): RGB text display color.
"""
# 从 points 中提取所有点的 x 坐标和 y 坐标,分别存储在 x_values 和 y_values 列表中。
x_values = [point[0] for point in points]
y_values = [point[1] for point in points]
# 计算所有点的平均 x 坐标和 y 坐标,得到显示文本的中心位置 (center_x, center_y) 。
center_x = sum(x_values) // len(points)
center_y = sum(y_values) // len(points)
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。
# 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。
# self.sf 是字体缩放比例。
# self.tf 是字体厚度。
text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
# text_size[0] 返回一个元组 (text_width, text_height) 。
text_width = text_size[0]
text_height = text_size[1]
# 计算 背景矩形 的 宽度 和 高度 ,增加一些额外的像素以确保文本不会紧贴矩形边缘。
rect_width = text_width + 20
rect_height = text_height + 20
# 计算背景矩形的 左上角 和 右下角 坐标。
# rect_top_left 是左上角坐标。
rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
# rect_bottom_right 是右下角坐标。
rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
# 使用 OpenCV 的 rectangle 方法绘制背景矩形。 self.im 是目标图像。 rect_top_left 和 rect_bottom_right 是矩形的左上角和右下角坐标。 region_color 是矩形的颜色。 -1 表示填充矩形。
cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
# 计算 文本的起始坐标 ,使文本居中显示在背景矩形内。
text_x = center_x - text_width // 2
text_y = center_y + text_height // 2
# Draw text
# 使用 OpenCV 的 putText 方法绘制文本。
cv2.putText(
# 目标图像。
self.im,
# 要显示的文本内容。
label,
# 文本的起始坐标。
(text_x, text_y),
# 字体类型( cv2.FONT_HERSHEY_SIMPLEX )。
0,
# 字体缩放比例。
fontScale=self.sf,
# 文本颜色。
color=txt_color,
# 字体厚度。
thickness=self.tf,
# 表示使用抗锯齿线条。
lineType=cv2.LINE_AA,
)
# 这个方法的主要功能是在图像上显示队列计数信息,并为文本添加一个背景矩形以提高可读性。具体步骤如下。计算中心位置:从 points 中提取所有点的坐标,计算平均 x 和 y 坐标,得到显示文本的中心位置。获取文本尺寸:使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。计算背景矩形尺寸:根据文本尺寸计算背景矩形的宽度和高度,增加一些额外的像素。绘制背景矩形:使用 OpenCV 的 rectangle 方法绘制背景矩形。计算文本位置:使文本居中显示在背景矩形内。绘制文本:使用 OpenCV 的 putText 方法绘制文本。这种方法适用于在图像中标注队列计数信息,尤其是在需要突出显示某些关键信息时。
# 这段代码定义了 Annotator 类中的一个方法 display_objects_labels ,用于在图像上显示对象的标签。该方法会在指定的中心位置显示文本标签,并为文本添加一个背景矩形以提高可读性。
# 定义了一个方法 display_objects_labels ,接受以下参数 :
# 1.im0 :目标图像。
# 2.text :要显示的文本内容。
# 3.txt_color :文本颜色。
# 4.bg_color :背景矩形的颜色。
# 5.x_center :文本的中心 x 坐标。
# 6.y_center :文本的中心 y 坐标。
# 7.margin :文本与背景矩形边缘的间距。
def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):
# 在停车管理应用中显示边界框标签。
"""
Display the bounding boxes labels in parking management app.
Args:
im0 (ndarray): Inference image.
text (str): Object/class name.
txt_color (tuple): Display color for text foreground.
bg_color (tuple): Display color for text background.
x_center (float): The x position center point for bounding box.
y_center (float): The y position center point for bounding box.
margin (int): The gap between text and rectangle for better display.
"""
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 self.tf 是字体厚度。 text_size[0] 返回一个元组 (text_width, text_height) 。
text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
# 计算文本的起始坐标,使文本居中显示在指定的中心位置。
# text_x 是文本的起始 x 坐标,计算为 x_center - text_width // 2 。
text_x = x_center - text_size[0] // 2
# text_y 是文本的起始 y 坐标,计算为 y_center + text_height // 2 。
text_y = y_center + text_size[1] // 2
# 计算背景矩形的坐标。在文本的每个方向上增加 margin 以确保文本不会紧贴矩形边缘。
# rect_x1 和 rect_y1 是背景矩形的左上角坐标。
rect_x1 = text_x - margin
rect_y1 = text_y - text_size[1] - margin
# rect_x2 和 rect_y2 是背景矩形的右下角坐标。
rect_x2 = text_x + text_size[0] + margin
rect_y2 = text_y + margin
# 使用 OpenCV 的 rectangle 方法绘制背景矩形。 im0 是目标图像。 (rect_x1, rect_y1) 和 (rect_x2, rect_y2) 是矩形的左上角和右下角坐标。 bg_color 是矩形的颜色。 -1 表示填充矩形。
cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
# 使用 OpenCV 的 putText 方法绘制文本。 im0 是目标图像。 text 是要显示的文本内容。 (text_x, text_y) 是文本的起始坐标。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 txt_color 是文本颜色。 self.tf 是字体厚度。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
# 这个方法的主要功能是在图像上显示对象的标签,并为文本添加一个背景矩形以提高可读性。具体步骤如下。获取文本尺寸:使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。计算文本位置:使文本居中显示在指定的中心位置。计算背景矩形坐标:在文本的每个方向上增加 margin 以确保文本不会紧贴矩形边缘。绘制背景矩形:使用 OpenCV 的 rectangle 方法绘制背景矩形。绘制文本:使用 OpenCV 的 putText 方法绘制文本。这种方法适用于在图像中标注对象的标签,尤其是在需要突出显示某些关键信息时。
# 这段代码定义了 Annotator 类中的一个方法 display_analytics ,用于在图像上显示一系列的统计信息(例如停车数量、占用率等)。该方法会在图像的右上角依次显示这些信息,并为每个文本添加一个背景矩形以提高可读性。
# 定义了一个方法 display_analytics ,接受以下参数 :
# 1.im0 :目标图像。
# 2.text :一个字典,包含要显示的统计信息,键是标签,值是对应的数值。
# 3.txt_color :文本颜色。
# 4.bg_color :背景矩形的颜色。
# 5.margin :文本与背景矩形边缘的间距。
def display_analytics(self, im0, text, txt_color, bg_color, margin):
# 显示停车场的总体统计数据。
"""
Display the overall statistics for parking lots.
Args:
im0 (ndarray): Inference image.
text (dict): Labels dictionary.
txt_color (tuple): Display color for text foreground.
bg_color (tuple): Display color for text background.
margin (int): Gap between text and rectangle for better display.
"""
# 计算 水平和垂直方向上的间隙 。
# horizontal_gap 是图像宽度的 2%。
horizontal_gap = int(im0.shape[1] * 0.02)
# vertical_gap 是图像高度的 1%。
vertical_gap = int(im0.shape[0] * 0.01)
# 初始化 text_y_offset ,用于记录当前文本的垂直偏移量。
text_y_offset = 0
# 遍历 text 字典中的每个键值对(标签和数值)。
for label, value in text.items():
# 将标签和数值格式化为字符串,例如 "Total: 100" 。
txt = f"{label}: {value}"
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 self.tf 是字体厚度。
text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
# 如果文本尺寸小于 5x5,将其设置为最小尺寸 (5, 5) 。
if text_size[0] < 5 or text_size[1] < 5:
text_size = (5, 5)
# 计算文本的起始坐标。
# text_x 是文本的起始 x 坐标,计算为图像宽度减去文本宽度、间距和水平间隙。
text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
# text_y 是文本的起始 y 坐标,计算为当前偏移量加上文本高度、间距和垂直间隙。
text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
# 计算背景矩形的坐标。在文本的每个方向上增加 margin * 2 以确保文本不会紧贴矩形边缘。
# rect_x1 和 rect_y1 是背景矩形的左上角坐标。
rect_x1 = text_x - margin * 2
rect_y1 = text_y - text_size[1] - margin * 2
# rect_x2 和 rect_y2 是背景矩形的右下角坐标。
rect_x2 = text_x + text_size[0] + margin * 2
rect_y2 = text_y + margin * 2
# 使用 OpenCV 的 rectangle 方法绘制背景矩形。 im0 是目标图像。 (rect_x1, rect_y1) 和 (rect_x2, rect_y2) 是矩形的左上角和右下角坐标。 bg_color 是矩形的颜色。 -1 表示填充矩形。
cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
# 使用 OpenCV 的 putText 方法绘制文本。 im0 是目标图像。 txt 是要显示的文本内容。 (text_x, text_y) 是文本的起始坐标。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 txt_color 是文本颜色。 self.tf 是字体厚度。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
# 更新 text_y_offset 为当前背景矩形的右下角 y 坐标,以便下一条文本在正确的位置绘制。
text_y_offset = rect_y2
# 这个方法的主要功能是在图像上显示一系列的统计信息,并为每个文本添加一个背景矩形以提高可读性。具体步骤如下。计算间隙:根据图像尺寸计算水平和垂直方向上的间隙。遍历统计信息:遍历 text 字典中的每个键值对(标签和数值)。格式化文本:将标签和数值格式化为字符串。获取文本尺寸:使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。计算文本位置:计算文本的起始坐标,使其靠右上角显示。计算背景矩形坐标:在文本的每个方向上增加间距以确保文本不会紧贴矩形边缘。绘制背景矩形:使用 OpenCV 的 rectangle 方法绘制背景矩形。绘制文本:使用 OpenCV 的 putText 方法绘制文本。更新偏移量:更新文本的垂直偏移量,以便下一条文本在正确的位置绘制。这种方法适用于在图像中标注统计信息,尤其是在需要突出显示某些关键信息时。
# 这段代码定义了 Annotator 类中的一个静态方法 estimate_pose_angle ,用于计算由三个点 a 、 b 和 c 形成的角度。
@staticmethod
# 定义了一个静态方法 estimate_pose_angle ,接受三个参数。
# 1.a 、 2.b 和 3.c :每个参数是一个包含两个浮点数的列表或元组,表示 (x, y) 坐标。
def estimate_pose_angle(a, b, c):
# 计算物体的姿势角度。
"""
Calculate the pose angle for object.
Args:
a (float) : The value of pose point a
b (float): The value of pose point b
c (float): The value o pose point c
Returns:
angle (degree): Degree value of angle between three points
"""
# 将输入的点坐标转换为 NumPy 数组,以便进行向量运算。
a, b, c = np.array(a), np.array(b), np.array(c)
# 计算向量 bc 和向量 ba 之间的角度(以弧度为单位)。
# np.arctan2(c[1] - b[1], c[0] - b[0]) 计算向量 bc 与 x 轴之间的角度。
# np.arctan2(a[1] - b[1], a[0] - b[0]) 计算向量 ba 与 x 轴之间的角度。
# 两个角度相减得到向量 bc 和向量 ba 之间的角度。
radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
# 将 弧度转换为角度 。 radians * 180.0 / np.pi 将弧度转换为角度。 np.abs 取绝对值,确保角度为正数。
angle = np.abs(radians * 180.0 / np.pi)
# 如果计算得到的角度大于 180 度,则用 360 度减去该角度,得到较小的角度值。
if angle > 180.0:
angle = 360 - angle
# 返回计算得到的角度值。
return angle
# 这个静态方法的主要功能是计算由三个点 a 、 b 和 c 形成的角度。具体步骤如下。转换点坐标:将输入的点坐标转换为 NumPy 数组。计算弧度:计算向量 bc 和向量 ba 之间的角度(以弧度为单位)。转换为角度:将弧度转换为角度,并取绝对值。调整角度:如果角度大于 180 度,则用 360 度减去该角度,得到较小的角度值。返回结果:返回计算得到的角度值。这种方法适用于在图像处理或机器学习任务中需要计算角度的场景,例如在人体姿态估计中计算关节角度。使用静态方法可以方便地调用该功能,而无需创建类的实例。
# 这段代码定义了 Annotator 类中的一个方法 draw_specific_points ,用于在图像上绘制特定的关键点,并在这些关键点之间绘制连接线。
# 定义了一个方法 draw_specific_points ,接受以下参数 :
# 1.keypoints :关键点数据,每个关键点是一个包含 (x, y, confidence) 的列表或元组。
# 2.indices :需要绘制的关键点索引,默认为 [2, 5, 7] 。
# 3.radius :关键点的绘制半径,默认为 2。
# 4.conf_thres :关键点的置信度阈值,默认为 0.25。
def draw_specific_points(self, keypoints, indices=None, radius=2, conf_thres=0.25):
# 绘制特定关键点以进行健身房步数计数。
# 注意:
# 关键点格式:[x, y] 或 [x, y, 置信度]。
# 就地修改 self.im。
"""
Draw specific keypoints for gym steps counting.
Args:
keypoints (list): Keypoints data to be plotted.
indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
radius (int, optional): Keypoint radius. Defaults to 2.
conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
Returns:
(numpy.ndarray): Image with drawn keypoints.
Note:
Keypoint format: [x, y] or [x, y, confidence].
Modifies self.im in-place.
"""
# 如果未指定 indices ,则使用默认值 [2, 5, 7] 。
indices = indices or [2, 5, 7]
# 从 keypoints 中提取指定索引 indices 的关键点,并过滤掉置信度低于 conf_thres 的关键点。
# k[0] 和 k[1] 是关键点的 x 和 y 坐标。
# k[2] 是关键点的置信度。
# points 是一个包含过滤后的关键点坐标的列表。
points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thres]
# Draw lines between consecutive points
# 在过滤后的关键点之间绘制连接线。
# points[:-1] 是所有起始点。
# points[1:] 是所有结束点。
for start, end in zip(points[:-1], points[1:]):
# 使用 OpenCV 的 line 方法绘制连接线。 self.im 是目标图像。 start 和 end 是连接线的起始点和结束点。 (0, 255, 0) 是连接线的颜色(绿色)。 2 是连接线的厚度。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)
# Draw circles for keypoints
# 在每个关键点处绘制一个小圆圈。 pt 是关键点的坐标。
for pt in points:
# 使用 OpenCV 的 circle 方法绘制圆圈。 self.im 是目标图像。 pt 是圆圈的中心坐标。 radius 是圆圈的半径。 (0, 0, 255) 是圆圈的颜色(红色)。 -1 表示填充圆圈。 cv2.LINE_AA 表示使用抗锯齿线条。
cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)
# 返回 绘制后的图像 。
return self.im
# 这个方法的主要功能是在图像上绘制特定的关键点,并在这些关键点之间绘制连接线。具体步骤如下。设置默认索引:如果未指定 indices ,则使用默认值 [2, 5, 7] 。过滤关键点:从 keypoints 中提取指定索引的关键点,并过滤掉置信度低于 conf_thres 的关键点。绘制连接线:在过滤后的关键点之间绘制连接线。绘制关键点:在每个关键点处绘制一个小圆圈。返回结果:返回绘制后的图像。这种方法适用于在图像中标注特定的关键点,尤其是在需要突出显示某些关键点及其连接关系时。
# 这段代码定义了 Annotator 类中的一个方法 plot_workout_information ,用于在图像上绘制工作信息(例如健身动作的次数、角度等)。该方法会在指定的位置显示文本标签,并为文本添加一个背景矩形以提高可读性。
# 定义了一个方法 plot_workout_information ,接受以下参数 :
# 1.display_text :要显示的文本内容。
# 2.position :文本的起始位置 (x, y) 。
# 3.color :背景矩形的颜色,默认为 (104, 31, 17) 。
# 4.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)):
# 在图像上绘制带有背景的文本。
"""
Draw text with a background on the image.
Args:
display_text (str): The text to be displayed.
position (tuple): Coordinates (x, y) on the image where the text will be placed.
color (tuple, optional): Text background color
txt_color (tuple, optional): Text foreground color
"""
# (cv2.Size(width, height), baseline) = cv2.getTextSize(text, fontFace, fontScale, thickness)
# cv2.getTextSize 函数是 OpenCV 库中的一个函数,它用于计算给定文本的尺寸(宽度和高度)。这个函数在绘制文本之前非常有用,因为它可以帮助你确定文本的尺寸,从而可以进行适当的定位和布局。
# 参数 :
# text :要测量的文本字符串。
# fontFace :字体类型,可以是 OpenCV 预定义的字体,如 cv2.FONT_HERSHEY_SIMPLEX 、 cv2.FONT_HERSHEY_PLAIN 等。
# fontScale :字体缩放因子,用于调整字体大小。
# thickness :字体线条的厚度。
# 返回值 :
# cv2.Size(width, height) :一个 cv2.Size 对象,包含文本的宽度和高度。
# baseline :文本基线(即文本底部到基线的距离),这个值可以用来确定文本的位置,以确保基线对齐。
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。
# 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。
# self.sf 是字体缩放比例。
# self.tf 是字体厚度。
# cv2.getTextSize 返回一个元组 (text_width, text_height) 和基线高度,这里只使用前两个值。
(text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf)
# Draw background rectangle
# 绘制背景矩形。 position[0] 和 position[1] 是文本的起始位置。
cv2.rectangle(
self.im,
# 背景矩形的左上角坐标为 (position[0], position[1] - text_height - 5) 。
(position[0], position[1] - text_height - 5),
# 背景矩形的右下角坐标为 (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf) 。
(position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),
# 背景矩形的颜色。
color,
# 表示填充矩形。
-1,
)
# Draw text
# 绘制文本。
# self.im 是目标图像。
# display_text 是要显示的文本内容。
# position 是文本的起始位置。
# 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。
# self.sf 是字体缩放比例。
# txt_color 是文本颜色。
# self.tf 是字体厚度。
cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)
# 返回 文本的高度 。
return text_height
# 这个方法的主要功能是在图像上绘制工作信息,并为文本添加一个背景矩形以提高可读性。具体步骤如下。获取文本尺寸:使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。计算背景矩形坐标:根据文本尺寸和位置计算背景矩形的左上角和右下角坐标。绘制背景矩形:使用 OpenCV 的 rectangle 方法绘制背景矩形。绘制文本:使用 OpenCV 的 putText 方法绘制文本。返回文本高度:返回文本的高度。这种方法适用于在图像中标注工作信息,尤其是在需要突出显示某些关键信息时。
# 这段代码定义了 Annotator 类中的一个方法 plot_angle_and_count_and_stage ,用于在图像上绘制角度、计数和阶段信息。这些信息通常用于健身动作的标注,例如显示当前的角度、完成的步数和当前的动作阶段。
# 定义了一个方法 plot_angle_and_count_and_stage ,接受以下参数 :
# 1.angle_text :角度信息,通常是一个浮点数。
# 2.count_text :计数信息,通常是一个整数。
# 3.stage_text :阶段信息,通常是一个字符串。
# 4.center_kpt :中心关键点的坐标 (x, y) ,用于确定文本的起始位置。
# 5.color :背景矩形的颜色,默认为 (104, 31, 17) 。
# 6.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
def plot_angle_and_count_and_stage(
self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)
):
# 绘制姿势角度、计数值和步进阶段。
"""
Plot the pose angle, count value, and step stage.
Args:
angle_text (str): Angle value for workout monitoring
count_text (str): Counts value for workout monitoring
stage_text (str): Stage decision for workout monitoring
center_kpt (list): Centroid pose index for workout monitoring
color (tuple, optional): Text background color
txt_color (tuple, optional): Text foreground color
"""
# Format text
# 格式化文本内容。
# angle_text :角度信息,保留两位小数。
# count_text :计数信息,前缀为 "Steps : "。
# stage_text :阶段信息,前后添加空格以增加可读性。
angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}"
# Draw angle, count and stage text
# 调用 plot_workout_information 方法绘制角度信息。
# angle_text 是要显示的文本内容。 center_kpt 是文本的起始位置。 color 是背景矩形的颜色。 txt_color 是文本颜色。
# 返回值 angle_height 是绘制的角度文本的高度。
angle_height = self.plot_workout_information(
angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color
)
# 调用 plot_workout_information 方法绘制计数信息。
# count_text 是要显示的文本内容。 计数信息的起始位置在角度信息下方,垂直偏移量为 angle_height + 20 。 color 是背景矩形的颜色。 txt_color 是文本颜色。
# 返回值 count_height 是绘制的计数文本的高度。
count_height = self.plot_workout_information(
count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color
)
# 调用 plot_workout_information 方法绘制阶段信息。
# stage_text 是要显示的文本内容。 阶段信息的起始位置在计数信息下方,垂直偏移量为 angle_height + count_height + 40 。 color 是背景矩形的颜色。 txt_color 是文本颜色。
self.plot_workout_information(
stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color
)
# 这个方法的主要功能是在图像上绘制角度、计数和阶段信息,并为每个文本添加一个背景矩形以提高可读性。具体步骤如下。格式化文本:将角度、计数和阶段信息格式化为字符串。绘制角度信息:调用 plot_workout_information 方法绘制角度信息,并获取其高度。绘制计数信息:在角度信息下方绘制计数信息,并获取其高度。绘制阶段信息:在计数信息下方绘制阶段信息。这种方法适用于在图像中标注健身动作的相关信息,尤其是在需要突出显示角度、计数和阶段时。通过逐行绘制文本并动态调整位置,可以确保信息的清晰显示。
# 这段代码定义了 Annotator 类中的一个方法 seg_bbox ,用于在图像上绘制分割掩码(segmentation mask)和可选的标签。
# 定义了一个方法 seg_bbox ,接受以下参数 :
# 1.mask :分割掩码的多边形点,格式为一个二维数组,其中每个点是一个包含两个整数的列表或元组 (x, y) 。
# 2.mask_color :掩码的颜色,默认为紫色 (255, 0, 255) 。
# 3.label :可选的文本标签,默认为 None 。
# 4.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):
# 用于以边界框形状绘制分割对象的函数。
"""
Function for drawing segmented object in bounding box shape.
Args:
mask (np.ndarray): A 2D array of shape (N, 2) containing the contour points of the segmented object.
mask_color (tuple): RGB color for the contour and label background.
label (str, optional): Text label for the object. If None, no label is drawn.
txt_color (tuple): RGB color for the label text.
"""
# 如果掩码为空( mask.size == 0 ),则直接返回,不进行任何绘制操作。
if mask.size == 0: # no masks to plot
return
# 使用 OpenCV 的 polylines 方法 绘制分割掩码的多边形边界 。 self.im 是目标图像。 [np.int32([mask])] 是一个多边形点的列表,每个点的坐标被转换为整数类型。 isClosed=True 表示多边形是闭合的。 mask_color 是多边形的颜色。 thickness=2 是多边形边的厚度。
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
# 如果提供了标签 label ,使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 self.tf 是字体厚度。 text_size 是一个元组 (text_width, text_height) 。
text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
# 如果提供了标签 label ,则绘制背景矩形和文本。
if label:
# 绘制背景矩形。
cv2.rectangle(
self.im,
# 背景矩形的左上角坐标为 (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10) 。
(int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
# 背景矩形的右下角坐标为 (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)) 。
(int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
# 背景矩形的颜色。
mask_color,
# 表示填充矩形。
-1,
)
# 绘制文本。
cv2.putText(
# self.im 是目标图像。 label 是要显示的文本内容。 文本的起始位置为 (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])) 。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 txt_color 是文本颜色。 self.tf 是字体厚度。
self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
)
# 这个方法的主要功能是在图像上绘制分割掩码的多边形边界,并在掩码的第一个点附近绘制可选的文本标签。具体步骤如下。检查掩码是否为空:如果掩码为空,则直接返回。绘制多边形边界:使用 OpenCV 的 polylines 方法绘制掩码的多边形边界。获取文本尺寸:如果提供了标签,使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。绘制背景矩形:在掩码的第一个点附近绘制背景矩形。绘制文本:在背景矩形上绘制文本。这种方法适用于在图像中标注分割掩码和相关标签,尤其是在需要突出显示某些区域时。
# 这段代码定义了 Annotator 类中的一个方法 sweep_annotator ,用于在图像上绘制一条扫描线(sweep line)和可选的标签。
# 定义了一个方法 sweep_annotator ,接受以下参数 :
# 1.line_x :扫描线的 x 坐标,默认为 0。
# 2.line_y :扫描线的 y 坐标限制,默认为 0。
# 3.label :可选的文本标签,默认为 None 。
# 4.color :扫描线和背景矩形的颜色,默认为 (221, 0, 186) 。
# 5.txt_color :文本颜色,默认为白色 (255, 255, 255) 。
def sweep_annotator(self, line_x=0, line_y=0, label=None, color=(221, 0, 186), txt_color=(255, 255, 255)):
# 用于绘制扫描注释线和可选标签的函数。
"""
Function for drawing a sweep annotation line and an optional label.
Args:
line_x (int): The x-coordinate of the sweep line.
line_y (int): The y-coordinate limit of the sweep line.
label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.
color (tuple): RGB color for the line and label background.
txt_color (tuple): RGB color for the label text.
"""
# Draw the sweep line
# 使用 OpenCV 的 line 方法绘制扫描线。 self.im 是目标图像。 (line_x, 0) 是扫描线的起始点。 (line_x, line_y) 是扫描线的结束点。 color 是扫描线的颜色。 self.tf * 2 是扫描线的厚度。
cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)
# Draw label, if provided
# 如果提供了标签 label ,则绘制背景矩形和文本。
if label:
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。 cv2.FONT_HERSHEY_SIMPLEX 是字体类型。 self.sf 是字体缩放比例。 self.tf 是字体厚度。 text_width 和 text_height 是文本的宽度和高度。
(text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)
# 绘制背景矩形。
cv2.rectangle(
self.im,
# 背景矩形的左上角坐标为 (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10) 。
(line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),
# 背景矩形的右下角坐标为 (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10) 。
(line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),
# 背景矩形的颜色。
color,
# 表示填充矩形。
-1,
)
# 绘制文本。
cv2.putText(
# 目标图像。
self.im,
# 要显示的文本内容。
label,
# 文本的起始位置为 (line_x - text_width // 2, line_y // 2 + text_height // 2) 。
(line_x - text_width // 2, line_y // 2 + text_height // 2),
# 字体类型。
cv2.FONT_HERSHEY_SIMPLEX,
# 字体缩放比例。
self.sf,
# 文本颜色。
txt_color,
# 字体厚度。
self.tf,
)
# 这个方法的主要功能是在图像上绘制一条扫描线,并在扫描线的中心位置绘制可选的文本标签。具体步骤如下。绘制扫描线:使用 OpenCV 的 line 方法绘制扫描线。获取文本尺寸:如果提供了标签,使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。绘制背景矩形:在扫描线的中心位置绘制背景矩形。绘制文本:在背景矩形上绘制文本。这种方法适用于在图像中标注扫描线和相关标签,尤其是在需要突出显示某些特定位置时。
# 这段代码定义了 Annotator 类中的一个方法 plot_distance_and_line ,用于在图像上绘制两个质心(centroids)之间的距离和连接线。
# 定义了一个方法 plot_distance_and_line ,接受以下参数 :
# 1.pixels_distance :两个质心之间的像素距离,通常是一个浮点数。
# 2.centroids :一个包含两个质心坐标的列表,每个质心是一个包含两个整数的元组 (x, y) 。
# 3.line_color :连接线的颜色,默认为 (104, 31, 17) 。
# 4.centroid_color :质心的颜色,默认为 (255, 0, 255) 。
def plot_distance_and_line(
self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
):
# 在框架上绘制距离和线。
"""
Plot the distance and line on frame.
Args:
pixels_distance (float): Pixels distance between two bbox centroids.
centroids (list): Bounding box centroids data.
line_color (tuple, optional): Distance line color.
centroid_color (tuple, optional): Bounding box centroid color.
"""
# Get the text size
# 格式化文本内容并获取文本的宽度和高度。
# text 是要显示的文本内容,格式为 "Pixels Distance: XX.XX" ,其中 XX.XX 是像素距离,保留两位小数。
text = f"Pixels Distance: {pixels_distance:.2f}" # 像素距离:{pixels_distance:.2f} 。
# 使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。 0 是字体类型( cv2.FONT_HERSHEY_SIMPLEX )。 self.sf 是字体缩放比例。 self.tf 是字体厚度。 text_width_m 和 text_height_m 是文本的宽度和高度。
(text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
# Define corners with 10-pixel margin and draw rectangle
# 绘制背景矩形。 背景矩形的左上角坐标为 (15, 25) 。 背景矩形的右下角坐标为 (15 + text_width_m + 20, 25 + text_height_m + 20) ,增加了 10 像素的边距。 line_color 是背景矩形的颜色。 -1 表示填充矩形。
cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
# Calculate the position for the text with a 10-pixel margin and draw text
# 绘制文本。
# text_position 是文本的起始位置,计算为 (25, 25 + text_height_m + 10) ,增加了 10 像素的边距。
text_position = (25, 25 + text_height_m + 10)
# 使用 OpenCV 的 putText 方法绘制文本。
cv2.putText(
# 目标图像。
self.im,
# 要显示的文本内容。
text,
# 文本的起始位置。
text_position,
# 字体类型( cv2.FONT_HERSHEY_SIMPLEX )。
0,
# 字体缩放比例。
self.sf,
# 文本颜色(白色)。
(255, 255, 255),
# 字体厚度。
self.tf,
# 表示使用抗锯齿线条。
cv2.LINE_AA,
)
# 绘制两个质心之间的连接线。 self.im 是目标图像。 centroids[0] 和 centroids[1] 是两个质心的坐标。 line_color 是连接线的颜色。 3 是连接线的厚度。
cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
# 绘制两个质心。
# self.im 是目标图像。
# centroids[0] 和 centroids[1] 是两个质心的坐标。
# 6 是圆圈的半径。
# centroid_color 是圆圈的颜色。
# -1 表示填充圆圈。
cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
cv2.circle(self.im, centroids[1], 6, centroid_color, -1)
# 这个方法的主要功能是在图像上绘制两个质心之间的距离和连接线。具体步骤如下。格式化文本内容:将像素距离格式化为字符串。获取文本尺寸:使用 OpenCV 的 getTextSize 方法获取文本的宽度和高度。绘制背景矩形:在文本周围绘制一个背景矩形,增加边距以提高可读性。绘制文本:在背景矩形上绘制文本。绘制连接线:在两个质心之间绘制一条连接线。绘制质心:在两个质心位置绘制圆圈。这种方法适用于在图像中标注两个点之间的距离和连接关系,尤其是在需要突出显示某些特定位置时。
# 这段代码定义了 Annotator 类中的一个方法 visioneye ,用于在图像上绘制一个边界框的中心点和一个指定的中心点,并在这两个点之间绘制一条连接线。这种方法通常用于可视化目标检测中的“视觉焦点”或“注视点”。
# 定义了一个方法 visioneye ,接受以下参数 :
# 1.box :边界框的坐标,格式为 (x1, y1, x2, y2) ,其中 (x1, y1) 是左上角坐标, (x2, y2) 是右下角坐标。
# 2.center_point :指定的中心点坐标,格式为 (x, y) 。
# 3.color :边界框中心点和连接线的颜色,默认为 (235, 219, 11) 。
# 4.pin_color :指定中心点的颜色,默认为 (255, 0, 255) 。
def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):
# 用于精确定位人眼视觉映射和绘图的函数。
"""
Function for pinpoint human-vision eye mapping and plotting.
Args:
box (list): Bounding box coordinates
center_point (tuple): center point for vision eye view
color (tuple): object centroid and line color value
pin_color (tuple): visioneye point color value
"""
# 计算边界框的中心点坐标 center_bbox 。
# int((box[0] + box[2]) / 2) 是边界框的中心 x 坐标。
# int((box[1] + box[3]) / 2) 是边界框的中心 y 坐标。
center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# 在指定的中心点 center_point 处绘制一个小圆圈。 self.im 是目标图像。 center_point 是圆圈的中心坐标。 self.tf * 2 是圆圈的半径。 pin_color 是圆圈的颜色。 -1 表示填充圆圈。
cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
# 在边界框的中心点 center_bbox 处绘制一个小圆圈。 self.im 是目标图像。 center_bbox 是圆圈的中心坐标。 self.tf * 2 是圆圈的半径。 color 是圆圈的颜色。 -1 表示填充圆圈。
cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
# 在指定的中心点 center_point 和边界框的中心点 center_bbox 之间绘制一条连接线。 self.im 是目标图像。 center_point 和 center_bbox 是连接线的起始点和结束点。 color 是连接线的颜色。 self.tf 是连接线的厚度。
cv2.line(self.im, center_point, center_bbox, color, self.tf)
# 这个方法的主要功能是在图像上绘制一个边界框的中心点和一个指定的中心点,并在这两个点之间绘制一条连接线。具体步骤如下。计算边界框中心点:根据边界框的坐标计算其中心点。绘制指定中心点:在指定的中心点处绘制一个小圆圈。绘制边界框中心点:在边界框的中心点处绘制一个小圆圈。绘制连接线:在两个中心点之间绘制一条连接线。这种方法适用于在图像中标注目标检测中的“视觉焦点”或“注视点”,尤其是在需要突出显示某些特定位置时。
# Annotator 类是一个功能丰富的图像标注工具类,广泛应用于计算机视觉任务中,特别是在目标检测、关键点标注和分割掩码可视化等场景。它提供了多种方法来绘制边界框、关键点、文本标签、分割掩码以及自定义图形,支持动态调整线条宽度、字体大小和颜色,能够根据输入图像的类型(PIL 图像或 NumPy 数组)自动选择合适的绘制库(PIL 或 OpenCV)。此外, Annotator 类还具备将标注结果保存为图像文件或直接显示的功能,极大地简化了图像标注流程,提高了开发效率。
4.def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
# 这段代码定义了一个名为 plot_labels 的函数,用于绘制训练标签,包括类别直方图和边界框统计信息。
# 使用 TryExcept 装饰器对函数进行包装。 TryExcept 用于捕获函数运行过程中可能出现的异常,并进行相应的处理(例如记录日志或返回默认值)。
# class TryExcept(contextlib.ContextDecorator):
# -> 这个类的作用是作为一个上下文管理器,用于捕获和处理异常,并可以选择性地打印异常信息。
# -> def __init__(self, msg="", verbose=True):
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
# 使用 plt_settings 装饰器对函数进行包装。 plt_settings 用于配置 Matplotlib 的绘图设置,例如全局样式、字体等。
# def plt_settings(rcparams=None, backend="Agg"): -> 的函数,它返回一个装饰器( decorator ),用于在执行被装饰的函数时临时设置 Matplotlib 的绘图参数( rcparams )和后端( backend )。返回 decorator 函数。这是 plt_settings 函数的返回值, decorator 函数将作为装饰器被使用。 -> return decorator
@plt_settings()
# 定义了函数 plot_labels ,它接收以下参数 :
# 1.boxes :边界框数据,二维数组,每行表示一个边界框的坐标(x, y, width, height)。
# 2.cls :类别标签,一维数组,每个标签对应一个边界框。
# 3.names :类别名称的字典,键为类别索引,值为类别名称。
# 4.save_dir :保存绘图结果的目录,默认为当前目录。
# 5.on_plot :可选的回调函数,用于在绘图完成后执行某些操作。
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
# 绘制训练标签,包括类别直方图和框统计。
"""Plot training labels including class histograms and box statistics."""
# 导入 pandas 模块,用于数据处理。注释表明这是为了加快 ultralytics 模块的导入速度。
import pandas # scope for faster 'import ultralytics'
# 导入 seaborn 模块,用于绘图。同样是为了加快 ultralytics 模块的导入速度。
import seaborn # scope for faster 'import ultralytics'
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical
# 使用 warnings.filterwarnings 忽略特定的警告。
# 忽略 Matplotlib 的 UserWarning 。
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") # 图形布局已更改为紧密。
# 忽略所有 FutureWarning ,这些警告通常与库的未来版本中的变化有关。
warnings.filterwarnings("ignore", category=FutureWarning)
# Plot dataset labels
# 使用日志记录器 LOGGER 记录一条信息,表明正在将标签绘制到指定路径。
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") # 将标签绘制到 {save_dir / 'labels.jpg'}...
# 计算 类别总数 nc ,通过取 cls 的最大值加1得到。
nc = int(cls.max() + 1) # number of classes
# 将边界框数据限制为最多100万个,以避免内存问题。
boxes = boxes[:1000000] # limit to 1M boxes
# 将边界框数据转换为 pandas.DataFrame ,列名为 x 、 y 、 width 和 height 。
x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
# Seaborn correlogram
# 使用 seaborn.pairplot 绘制边界框数据的散点图矩阵(correlogram),并保存为 labels_correlogram.jpg 。设置了绘图参数,例如直方图的柱数为50,最大点密度为0.9。
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
# 保存图像后关闭绘图窗口。
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
plt.close()
# 这段代码是 plot_labels 函数中的一部分,主要负责使用 Matplotlib 和 Seaborn 绘制训练数据的标签信息。
# Matplotlib labels
# 使用 plt.subplots 创建一个 2x2 的子图布局,图像大小为 8x8 英寸,并启用紧凑布局( tight_layout=True )。
# plt.subplots 返回一个元组,其中第一个元素是 Figure 对象,第二个元素是子图的轴数组。
# 使用 [1].ravel() 获取子图的轴数组,并将其展平为一维数组,方便后续操作。
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
# 在第一个子图( ax[0] )中绘制 类别直方图 。 cls 是类别标签数组。 使用 np.linspace(0, nc, nc + 1) - 0.5 生成类别索引的边界,确保每个类别标签之间有明确的分隔。 rwidth=0.8 设置柱状图的宽度为 0.8。
# ax[0].hist 返回一个 包含直方图数据的元组 ,存储在变量 y 中。
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# 遍历每个类别 i ,设置直方图柱状图的颜色。
for i in range(nc):
# y[2].patches[i] 是第 i 个柱状图。 使用 colors(i) 获取类别 i 对应的颜色,并将其从 RGB 格式转换为归一化值(除以 255)。
y[2].patches[i].set_color([x / 255 for x in colors(i)])
# 设置第一个子图的 y 轴标签为“instances”,表示每个类别的实例数量。
ax[0].set_ylabel("instances")
# 如果类别名称的数量在 1 到 30 之间。
if 0 < len(names) < 30:
# 将 x 轴刻度设置为类别索引。
ax[0].set_xticks(range(len(names)))
# 将 x 轴刻度标签设置为类别名称( names.values() ),并旋转 90 度显示,字体大小为 10。
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
# 如果类别数量超出范围,则将 x 轴标签设置为“classes”。
else:
ax[0].set_xlabel("classes")
# 在第三个子图( ax[2] )中绘制边界框中心点的二维直方图。 x 是包含边界框数据的 pandas.DataFrame 。 x="x" 和 y="y" 指定绘制 x 和 y 坐标。 bins=50 设置直方图的柱数。 pmax=0.9 设置最大点密度为 0.9。
seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
# 在第四个子图( ax[3] )中绘制边界框宽度和高度的二维直方图。 x="width" 和 y="height" 指定绘制宽度和高度。 其他参数与上一个直方图相同。
seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
# 这段代码的主要功能是。创建子图布局:使用 plt.subplots 创建一个 2x2 的子图布局。绘制类别直方图:在第一个子图中绘制类别直方图,显示每个类别的实例数量,并根据类别名称设置 x 轴标签。绘制边界框中心点分布:在第三个子图中绘制边界框中心点的二维直方图,展示边界框在图像中的分布情况。绘制边界框宽高分布:在第四个子图中绘制边界框宽度和高度的二维直方图,展示边界框的尺寸分布。通过这些可视化,可以直观地了解训练数据的类别分布和边界框的特征,有助于数据预处理和模型训练的优化。
# 这段代码是 plot_labels 函数的最后部分,主要负责绘制边界框的可视化图像,并将结果保存到文件中。
# Rectangles
# 将边界框的中心点坐标设置为 (0.5, 0.5) ,即图像的中心位置。 这里假设边界框的坐标是以 相对图像尺寸的比例 表示的,范围在 [0, 1] 之间。
boxes[:, 0:2] = 0.5 # center
# 使用 ops.xywh2xyxy 将边界框从中心宽高格式( x_center, y_center, width, height )转换为坐标格式( x1, y1, x2, y2 )。 将转换后的坐标乘以 1000,将边界框缩放到 1000x1000 的图像尺寸上。
boxes = ops.xywh2xyxy(boxes) * 1000
# 创建一个 1000x1000 的白色背景图像。 使用 np.ones((1000, 1000, 3), dtype=np.uint8) * 255 生成一个全白的三维数组。 使用 Image.fromarray 将数组转换为 PIL 图像。
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
# 遍历前 500 个类别标签和边界框。
for cls, box in zip(cls[:500], boxes[:500]):
# 使用 ImageDraw.Draw(img) 创建一个绘图对象。 使用 rectangle 方法在图像上绘制边界框,框的宽度为 1。 边界框的颜色由 colors(cls) 提供, cls 是类别索引。
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
# 在第二个子图( ax[1] )中显示绘制了边界框的图像。
ax[1].imshow(img)
# 关闭第二个子图的坐标轴,使图像更清晰。
ax[1].axis("off")
# 遍历所有子图( ax[0] 、 ax[1] 、 ax[2] 、 ax[3] ),隐藏每个子图的边框( spines )。
for a in [0, 1, 2, 3]:
for s in ["top", "right", "left", "bottom"]:
# spines[s].set_visible(False) 将边框设置为不可见。
ax[a].spines[s].set_visible(False)
# 定义保存图像的文件名路径,文件名为 labels.jpg ,保存在 save_dir 目录中。
fname = save_dir / "labels.jpg"
# 将绘制的图像保存到指定路径,分辨率为 200 DPI。
plt.savefig(fname, dpi=200)
# 关闭绘图窗口,释放资源。
plt.close()
# 如果提供了回调函数 on_plot ,则调用它,并将保存的图像路径作为参数。
if on_plot:
on_plot(fname)
# 这段代码的主要功能是。绘制边界框:将边界框的中心点设置为图像中心。将边界框从中心宽高格式转换为坐标格式,并缩放到 1000x1000 的图像上。在白色背景图像上绘制前 500 个边界框,每个边界框的颜色由类别索引决定。显示和保存图像:在第二个子图中显示绘制了边界框的图像,并关闭坐标轴。隐藏所有子图的边框,使图像更简洁。将最终图像保存到指定路径。回调函数:如果提供了回调函数 on_plot ,则在保存图像后调用该函数,将保存的图像路径作为参数。通过这些操作,可以直观地展示边界框在图像中的分布情况,帮助分析训练数据的质量和特征。
# 这段代码定义了一个函数 plot_labels ,用于绘制训练数据的标签信息,包括。类别直方图:显示每个类别的实例数量。边界框中心点分布:绘制边界框中心点的二维直方图。边界框宽高分布:绘制边界框宽度和高度的二维直方图。边界框可视化:在图像上绘制边界框,直观展示边界框的分布情况。通过这些可视化,可以更好地理解训练数据的分布特征,帮助调整模型训练策略或优化数据预处理步骤。
5.def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
# 这段代码定义了一个函数 save_one_box ,用于从图像中裁剪出一个指定的边界框区域,并将其保存为一个新的图像文件。
# 定义了一个函数 save_one_box ,接受以下参数 :
# 1.xyxy :边界框的坐标,格式为 (x1, y1, x2, y2) ,可以是一个列表或 PyTorch 张量。
# 2.im :原始图像,通常是一个 NumPy 数组。
# 3.file :保存裁剪图像的文件路径,默认为 Path("im.jpg") 。
# 4.gain :边界框尺寸的放大比例,默认为 1.02。
# 5.pad :边界框的填充像素,默认为 10。
# 6.square :是否将边界框转换为正方形,默认为 False 。
# 7.BGR :是否保存为 BGR 格式,默认为 False (保存为 RGB 格式)。
# 8.save :是否保存裁剪的图像,默认为 True 。
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
# 将图像裁剪保存为 {file},裁剪大小为 {gain} 和 {pad} 像素的倍数。保存和/或返回裁剪。
# 此函数接受边界框和图像,然后根据边界框保存裁剪的图像部分。裁剪可以是可选的平方,并且该函数允许对边界框进行增益和填充调整。
"""
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
This function takes a bounding box and an image, and then saves a cropped portion of the image according
to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
adjustments to the bounding box.
Args:
xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.
im (numpy.ndarray): The input image.
file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.
gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.
pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.
square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.
BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.
save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True.
Returns:
(numpy.ndarray): The cropped image.
Example:
```python
from ultralytics.utils.plotting import save_one_box
xyxy = [50, 50, 150, 150]
im = cv2.imread("image.jpg")
cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
```
"""
# 如果 xyxy 不是 PyTorch 张量,则将其转换为张量。这一步确保 xyxy 是一个张量,以便后续操作。
if not isinstance(xyxy, torch.Tensor): # may be list
xyxy = torch.stack(xyxy)
# 将边界框坐标从 (x1, y1, x2, y2) 转换为 (x_center, y_center, width, height) 格式。 xyxy.view(-1, 4) 确保输入是一个二维张量,每行表示一个边界框。
b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes
# 如果 square=True ,则将边界框转换为正方形。具体做法是取宽度和高度的最大值,使边界框的宽度和高度相等。
if square:
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
# 对边界框的宽度和高度进行放大,并添加填充像素。 gain 是放大比例, pad 是填充像素。
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
# 将边界框坐标从 (x_center, y_center, width, height) 转换回 (x1, y1, x2, y2) 格式,并将坐标转换为整数类型。
xyxy = ops.xywh2xyxy(b).long()
# 将裁剪的边界框坐标限制在原始图像的范围内,确保裁剪区域不会超出图像边界。
xyxy = ops.clip_boxes(xyxy, im.shape)
# 从原始图像中裁剪出指定的边界框区域。
# int(xyxy[0, 1]) : int(xyxy[0, 3]) 是裁剪区域的 y 范围。
# int(xyxy[0, 0]) : int(xyxy[0, 2]) 是裁剪区域的 x 范围。
# :: (1 if BGR else -1) 用于调整颜色通道顺序,如果 BGR=True ,则保持 BGR 格式;否则,转换为 RGB 格式。
crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
# 如果 save=True ,则保存裁剪的图像。
if save:
# 确保保存路径的目录存在。
file.parent.mkdir(parents=True, exist_ok=True) # make directory
# 生成一个唯一的文件路径,避免文件名冲突。
f = str(increment_path(file).with_suffix(".jpg"))
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
# 将裁剪的图像保存为 RGB 格式的 JPEG 文件。
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
# 返回裁剪的图像。
return crop
# 这个函数的主要功能是从图像中裁剪出一个指定的边界框区域,并将其保存为一个新的图像文件。具体步骤如下。转换边界框格式:将边界框坐标从 (x1, y1, x2, y2) 转换为 (x_center, y_center, width, height) 格式。调整边界框尺寸:如果需要,将边界框转换为正方形,并对宽度和高度进行放大,添加填充像素。裁剪图像:从原始图像中裁剪出指定的边界框区域。保存图像:如果指定保存,将裁剪的图像保存为一个新的文件。这种方法适用于在目标检测任务中裁剪和保存感兴趣的目标区域,尤其是在需要进一步处理或分析这些区域时。
6.def plot_images(images: Union[torch.Tensor, np.ndarray], batch_idx: Union[torch.Tensor, np.ndarray], cls: Union[torch.Tensor, np.ndarray], bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32), confs: Optional[Union[torch.Tensor, np.ndarray]] = None, masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8), kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), paths: Optional[List[str]] = None, fname: str = "images.jpg", names: Optional[Dict[int, str]] = None, on_plot: Optional[Callable] = None, max_size: int = 1920, max_subplots: int = 16, save: bool = True, conf_thres: float = 0.25, ) -> Optional[np.ndarray]:
# 这段代码定义了一个函数 plot_images ,用于将一批图像及其标注(边界框、类别、置信度、掩码和关键点)绘制到一个马赛克图像中,并保存为一个文件。
# def threaded(func): -> 用于将目标函数多线程化。根据传入的参数决定是否将目标函数运行在单独的线程中 。返回内部函数 wrapper ,这是装饰器的标准行为。当 @threaded 装饰器被应用到某个函数时,实际上会用 wrapper 函数替换目标函数。 -> return wrapper
@threaded
# 定义了一个函数 plot_images ,接受以下参数 :
# 1.images :一批图像,可以是 PyTorch 张量或 NumPy 数组。
# 2.batch_idx :每个标注对应的图像索引,可以是 PyTorch 张量或 NumPy 数组。
# 3.cls :每个标注的类别,可以是 PyTorch 张量或 NumPy 数组。
# 4.bboxes :边界框坐标,格式为 (x1, y1, x2, y2) ,默认为空数组。
# 5.confs :置信度,可选,默认为 None 。
# 6.masks :分割掩码,格式为 (n, h, w) ,默认为空数组。
# 7.kpts :关键点坐标,格式为 (n, 51) ,默认为空数组。
# 8.paths :图像路径列表,可选,默认为 None 。
# 9.fname :保存的文件名,默认为 "images.jpg" 。
# 10.names :类别名称字典,可选,默认为 None 。
# 11.on_plot :绘图完成后的回调函数,可选,默认为 None 。
# 12.max_size :马赛克图像的最大尺寸,默认为 1920。
# 13.max_subplots :马赛克图像中的最大子图数量,默认为 16。
# 14.save :是否保存马赛克图像,默认为 True 。
# 15.conf_thres :置信度阈值,默认为 0.25。
def plot_images(
images: Union[torch.Tensor, np.ndarray],
batch_idx: Union[torch.Tensor, np.ndarray],
cls: Union[torch.Tensor, np.ndarray],
bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
paths: Optional[List[str]] = None,
fname: str = "images.jpg",
names: Optional[Dict[int, str]] = None,
on_plot: Optional[Callable] = None,
max_size: int = 1920,
max_subplots: int = 16,
save: bool = True,
conf_thres: float = 0.25,
) -> Optional[np.ndarray]:
# 绘制带有标签、边界框、掩码和关键点的图像网格。
# 注意:
# 此函数支持张量和 numpy 数组输入。它会自动将张量输入转换为 numpy 数组进行处理。
"""
Plot image grid with labels, bounding boxes, masks, and keypoints.
Args:
images: Batch of images to plot. Shape: (batch_size, channels, height, width).
batch_idx: Batch indices for each detection. Shape: (num_detections,).
cls: Class labels for each detection. Shape: (num_detections,).
bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
confs: Confidence scores for each detection. Shape: (num_detections,).
masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
kpts: Keypoints for each detection. Shape: (num_detections, 51).
paths: List of file paths for each image in the batch.
fname: Output filename for the plotted image grid.
names: Dictionary mapping class indices to class names.
on_plot: Optional callback function to be called after saving the plot.
max_size: Maximum size of the output image grid.
max_subplots: Maximum number of subplots in the image grid.
save: Whether to save the plotted image grid to a file.
conf_thres: Confidence threshold for displaying detections.
Returns:
np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
Note:
This function supports both tensor and numpy array inputs. It will automatically
convert tensor inputs to numpy arrays for processing.
"""
# 这段代码的作用是将输入的 PyTorch 张量转换为 NumPy 数组,以便后续操作可以顺利进行。
# 检查 images 是否为 PyTorch 张量。
if isinstance(images, torch.Tensor):
# 如果是,将其移动到 CPU(如果不在 CPU 上),转换为浮点类型( float ),并最终转换为 NumPy 数组。
images = images.cpu().float().numpy()
# 检查 cls (类别标签)是否为 PyTorch 张量。
if isinstance(cls, torch.Tensor):
# 如果是,将其移动到 CPU 并转换为 NumPy 数组。
cls = cls.cpu().numpy()
# 检查 bboxes (边界框坐标)是否为 PyTorch 张量。
if isinstance(bboxes, torch.Tensor):
# 如果是,将其移动到 CPU 并转换为 NumPy 数组。
bboxes = bboxes.cpu().numpy()
# 检查 masks (分割掩码)是否为 PyTorch 张量。
if isinstance(masks, torch.Tensor):
# 如果是,将其移动到 CPU,转换为 NumPy 数组,并将数据类型转换为整数( int )。
masks = masks.cpu().numpy().astype(int)
# 检查 kpts (关键点坐标)是否为 PyTorch 张量。
if isinstance(kpts, torch.Tensor):
# 如果是,将其移动到 CPU 并转换为 NumPy 数组。
kpts = kpts.cpu().numpy()
# 检查 batch_idx (批处理索引)是否为 PyTorch 张量。
if isinstance(batch_idx, torch.Tensor):
# 如果是,将其移动到 CPU 并转换为 NumPy 数组。
batch_idx = batch_idx.cpu().numpy()
# 这段代码的主要功能是确保所有输入的 PyTorch 张量都被转换为 NumPy 数组。这一步是必要的,因为后续的图像处理和绘图操作通常需要 NumPy 数组作为输入。具体步骤如下。检查输入类型:逐个检查每个输入是否为 PyTorch 张量。转换为 NumPy 数组:如果输入是 PyTorch 张量,则将其移动到 CPU(如果不在 CPU 上)。将张量转换为 NumPy 数组。如果需要,将数据类型转换为适当的类型(例如,将掩码转换为整数类型)。这种方法确保了输入数据的一致性,使得后续操作可以顺利进行,尤其是在涉及 OpenCV 或其他基于 NumPy 的图像处理库时。
# 这段代码的作用是构建一个马赛克图像,将一批图像排列成一个网格,并根据需要调整图像的大小。
# 获取输入图像的 批量大小 bs 、 高度 h 和 宽度 w 。 _ 是一个占位符,表示忽略通道数(假设输入图像是 (batch_size, channels, height, width) 格式)。
bs, _, h, w = images.shape # batch size, _, height, width
# 将 批量大小 bs 限制为不超过 max_subplots ,以避免马赛克图像中子图过多。
bs = min(bs, max_subplots) # limit plot images
# 计算马赛克图像中的子图数量 ns ,使其为平方数(即每行和每列的子图数量相等)。 使用 np.ceil 确保 ns 是一个整数。
ns = np.ceil(bs**0.5) # number of subplots (square)
# 检查图像像素值是否在 [0, 1] 范围内。如果是,则将其转换为 [0, 255] 范围,以便后续操作。
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
# Build Image
# 初始化马赛克图像 mosaic ,大小为 (ns * h, ns * w, 3) ,填充为白色(RGB 值为 255)。 dtype=np.uint8 确保图像数据类型为无符号 8 位整数。
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
# 遍历每张图像。
for i in range(bs):
# 计算其在马赛克图像中的位置 (x, y) 。
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
# 将每张图像放置到马赛克图像的相应位置。 使用 transpose(1, 2, 0) 将图像的通道顺序从 (channels, height, width) 转换为 (height, width, channels) ,以满足 NumPy 数组的格式要求。
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
# 计算 缩放比例 scale ,确保马赛克图像的最大尺寸不超过 max_size 。 max_size / ns 是每行或每列的子图的最大尺寸。 max(h, w) 是子图的最大边长。
scale = max_size / ns / max(h, w)
# 如果需要缩小图像( scale < 1 ),则调整子图的高宽 h 和 w 。
if scale < 1:
h = math.ceil(scale * h)
w = math.ceil(scale * w)
# 使用 cv2.resize 方法将马赛克图像调整到新的大小。
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
# 这段代码的主要功能是构建一个马赛克图像,将一批图像排列成一个网格,并根据需要调整图像的大小。具体步骤如下。获取图像尺寸:获取输入图像的批量大小、高度和宽度。限制批量大小:将批量大小限制为不超过 max_subplots 。计算子图数量:计算马赛克图像中的子图数量,使其为平方数。调整像素值范围:如果图像像素值在 [0, 1] 范围内,则将其转换为 [0, 255] 范围。初始化马赛克图像:创建一个白色背景的马赛克图像。放置子图:将每张图像放置到马赛克图像的相应位置。调整马赛克图像大小:根据 max_size 缩放马赛克图像。这种方法适用于在目标检测任务中可视化一批图像及其标注,尤其是在需要将多个图像排列成一个网格时。
# 这段代码的作用是在构建的马赛克图像上进行标注,包括绘制边界框、显示文件名、类别标签等。
# Annotate
# 计算字体大小 fs ,基于图像的高度 h 、宽度 w 和子图数量 ns 。字体大小与图像尺寸成正比,确保标注在不同尺寸的图像上都能清晰可见。
fs = int((h + w) * ns * 0.01) # font size
# 创建一个 Annotator 对象,用于在马赛克图像上进行标注。
# mosaic 是马赛克图像。 line_width 是线条宽度,设置为字体大小的十分之一。 font_size 是字体大小。 pil=True 表示使用 PIL 进行绘制。 example=names 提供类别名称字典,用于生成示例文本。
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
# 遍历每张图像。
for i in range(bs):
# 计算其在马赛克图像中的起始位置 (x, y) 。 i // ns 和 i % ns 分别计算当前图像所在的行和列索引。
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
# 使用 Annotator 对象在每张图像的周围绘制白色边框。 边框的坐标为 [x, y, x + w, y + h] 。 边框颜色为白色 (255, 255, 255) 。 边框宽度为 2。
# def rectangle(self, xy, fill=None, outline=None, width=1): -> 用于在 PIL 模式下向图像中添加一个矩形。
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
# 如果提供了 路径列表 paths 。
if paths:
# 则在每张图像的左上角显示文件名。 文件名取路径的最后 40 个字符,以避免过长的文件名超出图像边界。 文本颜色为浅灰色 (220, 220, 220) 。
# def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): -> 用于在图像上绘制文本。该方法支持在 PIL 模式和 OpenCV 模式下绘制文本,并且可以选择是否为文本添加背景框。
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
# 如果提供了 类别 cls ,则进行以下操作
if len(cls) > 0:
# 找到 当前图像对应的标注索引 idx 。
idx = batch_idx == i
# 提取 当前图像的类别 classes 。
classes = cls[idx].astype("int")
# 检查是否提供了置信度 confs ,如果没有提供,则 labels=True ,表示这些是标签而不是预测结果。
labels = confs is None
# 这段代码的主要功能是在马赛克图像上进行标注,包括绘制边界框、显示文件名和类别标签。具体步骤如下。计算字体大小:根据图像尺寸计算合适的字体大小。创建标注器:初始化 Annotator 对象,用于在马赛克图像上进行标注。遍历每张图像:计算每张图像在马赛克图像中的位置。绘制边框:在每张图像的周围绘制白色边框。显示文件名:如果提供了路径列表 paths ,则在每张图像的左上角显示文件名。处理类别信息:如果提供了类别 cls ,则提取当前图像的类别信息,并检查是否提供了置信度。这种方法适用于在目标检测任务中可视化一批图像及其标注,尤其是在需要将多个图像排列成一个网格并显示详细信息时。
# 这段代码的作用是处理边界框( bboxes )并将其绘制到马赛克图像上。它支持处理归一化坐标和绝对坐标,并根据置信度阈值筛选边界框。
# 检查是否有边界框数据。如果 bboxes 为空,则跳过后续处理。
if len(bboxes):
# 根据 当前图像的索引 idx ,提取 对应的边界框数据 。
boxes = bboxes[idx]
# 如果提供了 置信度 confs ,则提取当前图像的置信度数据;如果没有提供,则将 conf 设置为 None 。这一步用于区分标签(无置信度)和预测结果(有置信度)。
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
# 如果提取到的边界框不为空,则继续处理。
if len(boxes):
# 检查边界框坐标是否为归一化坐标(最大值不超过 1.1,允许一定的容差)。如果是归一化坐标,则将其转换为像素坐标。
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
# 将 x 坐标乘以图像宽度 w 。
boxes[..., [0, 2]] *= w # scale to pixels
# 将 y 坐标乘以图像高度 h 。
boxes[..., [1, 3]] *= h
# 如果边界框坐标是绝对坐标,并且图像被缩放( scale < 1 ),则对边界框坐标进行缩放。
elif scale < 1: # absolute coords need scale if image scales
boxes[..., :4] *= scale
# 将边界框的 x 和 y 坐标偏移,使其在马赛克图像中的位置正确。
boxes[..., 0] += x
boxes[..., 1] += y
# 检查边界框是否为旋转边界框( xywhr 格式,包含 5 个值)。如果是,则 is_obb=True 。
is_obb = boxes.shape[-1] == 5 # xywhr
# 将边界框坐标从 xywhr (旋转边界框)或 xywh (普通边界框)格式转换为 xyxy 格式(左上角和右下角坐标)。 如果 is_obb=True ,则使用 ops.xywhr2xyxyxyxy 转换旋转边界框。 否则,使用 ops.xywh2xyxy 转换普通边界框。
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
# 遍历每个边界框,将其转换为整数列表。
for j, box in enumerate(boxes.astype(np.int64).tolist()):
# 提取当前边界框的 类别 c 。
c = classes[j]
# 获取类别对应的颜色 color 。
color = colors(c)
# 如果提供了类别名称字典 names ,则将类别索引转换为类别名称;否则,直接使用类别索引。
c = names.get(c, c) if names else c
# 如果当前边界框是标签( labels=True ),或者置信度大于阈值 conf_thres ,则进行绘制。
if labels or conf[j] > conf_thres:
# 格式化标签文本。
# 如果是标签,则只显示类别名称。
# 如果是预测结果,则显示类别名称和置信度(保留一位小数)。
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
# 使用 Annotator 对象在马赛克图像上绘制 边界框 和 标签 。 box 是边界框的坐标。 label 是标签文本。 color 是边界框和标签的颜色。 rotated=is_obb 指示是否绘制旋转边界框。
# def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): -> 用于在图像上绘制边界框(bounding box)和标签。该方法支持绘制普通的矩形框和旋转的多边形框,并且可以选择使用 PIL 或 OpenCV 进行绘制。
annotator.box_label(box, label, color=color, rotated=is_obb)
# 这段代码的主要功能是处理边界框数据,并将其绘制到马赛克图像上。具体步骤如下。提取边界框数据:根据当前图像的索引提取边界框和置信度数据。检查边界框坐标类型:判断边界框坐标是归一化坐标还是绝对坐标,并进行相应的转换。调整边界框位置:将边界框的坐标偏移,使其在马赛克图像中的位置正确。转换边界框格式:将边界框坐标从 xywhr 或 xywh 格式转换为 xyxy 格式。遍历每个边界框:对每个边界框进行处理,包括提取类别、获取颜色、格式化标签文本。绘制边界框和标签:使用 Annotator 对象在马赛克图像上绘制边界框和标签。这种方法适用于在目标检测任务中可视化边界框及其类别信息,尤其是在需要将多个图像排列成一个网格并显示详细标注时。
# 这段代码的作用是在马赛克图像上绘制类别标签,当没有边界框数据但有类别信息时。
# 如果 classes 列表不为空(即存在类别信息),但没有边界框数据,则进入以下逻辑。
elif len(classes):
# 遍历 classes 列表中的每个类别索引 c 。
for c in classes:
# 根据类别索引 c 获取对应的类别颜色。 colors 是一个函数或字典,用于将类别索引映射到颜色值。
color = colors(c)
# 如果提供了类别名称字典 names ,则将类别索引 c 转换为对应的类别名称。 如果没有提供 names ,则直接使用类别索引 c 。
c = names.get(c, c) if names else c
# 使用 Annotator 对象在马赛克图像上绘制类别标签。 (x, y) 是文本的起始位置,通常是当前子图的左上角。 f"{c}" 是要显示的文本内容,即类别名称或类别索引。 txt_color=color 是文本颜色,根据类别索引获取的颜色值。 box_style=True 表示为文本绘制一个背景框,以提高可读性。
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
# 这段代码的主要功能是在马赛克图像上绘制类别标签,适用于没有边界框数据但有类别信息的情况。具体步骤如下。检查类别信息:如果 classes 列表不为空,则进入绘制逻辑。遍历类别:遍历 classes 列表中的每个类别索引。获取类别颜色:根据类别索引获取对应的类别颜色。转换类别名称:如果提供了类别名称字典 names ,则将类别索引转换为类别名称;否则,直接使用类别索引。绘制文本:使用 Annotator 对象在马赛克图像上绘制类别标签,并为文本添加背景框。这种方法适用于在目标检测任务中可视化类别信息,尤其是在需要显示类别标签但没有边界框数据时。
# 这段代码的作用是在马赛克图像上绘制关键点( kpts ),并根据需要调整关键点的坐标。
# Plot keypoints
# 检查是否有关键点数据。如果 kpts 为空,则跳过后续处理。
if len(kpts):
# 根据当前图像的索引 idx ,提取对应的关鍵点数据,并创建一个副本以避免修改原始数据。
kpts_ = kpts[idx].copy()
# 如果提取到的关键点不为空,则继续处理。
if len(kpts_):
# 检查关键点的坐标是否为归一化坐标(最大值不超过 1.01,允许一定的容差)。如果是归一化坐标,则将其转换为像素坐标。
if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
# 将 x 坐标乘以图像宽度 w 。
kpts_[..., 0] *= w # scale to pixels
# 将 y 坐标乘以图像高度 h 。
kpts_[..., 1] *= h
# 如果关键点坐标是绝对坐标,并且图像被缩放( scale < 1 ),则对关键点坐标进行缩放。
elif scale < 1: # absolute coords need scale if image scales
kpts_ *= scale
# 将关键点的 x 和 y 坐标偏移,使其在马赛克图像中的位置正确。
kpts_[..., 0] += x
kpts_[..., 1] += y
# 遍历每个关键点。
for j in range(len(kpts_)):
# 如果当前关键点是标签( labels=True ),或者置信度大于阈值 conf_thres ,则进行绘制。
if labels or conf[j] > conf_thres:
# 使用 Annotator 对象在马赛克图像上绘制关键点。 kpts_[j] 是当前关键点的坐标。 conf_thres 是置信度阈值,用于过滤关键点。
# def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None): -> 用于在图像上绘制关键点(keypoints)和关键点之间的连接线(skeleton)。
annotator.kpts(kpts_[j], conf_thres=conf_thres)
# 这段代码的主要功能是在马赛克图像上绘制关键点,并根据置信度阈值筛选关键点。具体步骤如下。检查关键点数据:如果 kpts 为空,则跳过后续处理。提取关键点数据:根据当前图像的索引提取关键点数据。检查关键点坐标类型:判断关键点坐标是归一化坐标还是绝对坐标,并进行相应的转换。调整关键点位置:将关键点的坐标偏移,使其在马赛克图像中的位置正确。遍历每个关键点:对每个关键点进行处理,包括检查置信度是否大于阈值。绘制关键点:使用 Annotator 对象在马赛克图像上绘制关键点。这种方法适用于在目标检测任务中可视化关键点,尤其是在需要将多个图像排列成一个网格并显示关键点时。
# 这段代码的作用是在马赛克图像上绘制分割掩码( masks ),并根据置信度阈值筛选掩码。
# Plot masks
# 检查是否有分割掩码数据。如果 masks 为空,则跳过后续处理。
if len(masks):
# 如果 idx 的长度等于 masks 的长度,说明每个掩码对应一个边界框( overlap_masks=False ),直接提取当前图像的掩码。
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
image_masks = masks[idx]
# 如果 idx 的长度不等于 masks 的长度,说明所有掩码重叠( overlap_masks=True )。
else: # overlap_masks=True
# 提取当前图像的掩码(假设掩码是单个图像的掩码)。
image_masks = masks[[i]] # (1, 640, 640)
# 计算当前图像中掩码的数量 nl 。
nl = idx.sum()
# 创建一个索引数组 index ,用于区分不同的掩码。
index = np.arange(nl).reshape((nl, 1, 1)) + 1
# 使用 np.repeat 将掩码重复 nl 次。
image_masks = np.repeat(image_masks, nl, axis=0)
# 通过 np.where 将掩码值设置为 1.0 或 0.0。
image_masks = np.where(image_masks == index, 1.0, 0.0)
# 将当前的马赛克图像转换为 NumPy 数组,并创建一个副本以避免修改原始图像。
im = np.asarray(annotator.im).copy()
# 遍历每个掩码。
for j in range(len(image_masks)):
# 如果当前掩码是标签( labels=True ),或者置信度大于阈值 conf_thres ,则进行绘制。
if labels or conf[j] > conf_thres:
# 根据类别索引 classes[j] 获取对应的类别颜色。
color = colors(classes[j])
# 获取当前掩码 image_masks[j] 的高度 mh 和宽度 mw 。
mh, mw = image_masks[j].shape
# 检查掩码的尺寸是否与目标图像的尺寸不一致。如果掩码的尺寸与目标图像的尺寸不一致,则需要调整掩码的尺寸。
if mh != h or mw != w:
# 将掩码转换为 np.uint8 类型,以便后续使用 OpenCV 的 cv2.resize 方法。
mask = image_masks[j].astype(np.uint8)
# 使用 OpenCV 的 cv2.resize 方法将掩码调整为目标图像的尺寸 (w, h) 。
mask = cv2.resize(mask, (w, h))
# 将调整后的掩码转换为布尔类型,以便用作布尔索引。
mask = mask.astype(bool)
# 如果掩码的尺寸与目标图像的尺寸一致,则直接将掩码转换为布尔类型。
else:
mask = image_masks[j].astype(bool)
# 尝试将掩码绘制到马赛克图像的指定区域。
try:
# im[y : y + h, x : x + w, :] 是马赛克图像中当前子图的区域。
# mask 是布尔掩码,用于选择该区域内的像素。
im[y : y + h, x : x + w, :][mask] = (
# 将目标图像的像素值与掩码颜色混合,实现掩码的可视化。
# im[y : y + h, x : x + w, :][mask] * 0.4 :目标图像的像素值乘以 0.4(保留部分原始图像信息)。
# np.array(color) * 0.6 :掩码颜色乘以 0.6(覆盖部分原始图像信息)。
# 两者相加,实现半透明效果。
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
# 如果在绘制过程中发生异常(例如掩码尺寸不匹配),则跳过当前掩码的绘制。
except Exception:
pass
# 将修改后的图像重新转换为 Annotator 对象的内部图像格式。
# def fromarray(self, im): -> 用于将 NumPy 数组转换为 PIL 图像,并更新绘图对象。
annotator.fromarray(im)
# 这段代码的主要功能是在马赛克图像上绘制分割掩码,并根据置信度阈值筛选掩码。具体步骤如下。检查掩码数据:如果 masks 为空,则跳过后续处理。提取掩码数据:根据当前图像的索引提取掩码数据。处理掩码重叠情况:根据 overlap_masks 的值处理掩码数据。遍历每个掩码:对每个掩码进行处理,包括检查置信度是否大于阈值。调整掩码尺寸:如果掩码尺寸与目标图像尺寸不一致,则调整掩码尺寸。绘制掩码:将掩码与目标图像混合,实现掩码的可视化。更新图像:将修改后的图像重新转换为 Annotator 对象的内部图像格式。这种方法适用于在目标检测任务中可视化分割掩码,尤其是在需要将多个图像排列成一个网格并显示分割掩码时。
# 这段代码的作用是根据 save 参数的值决定是否保存绘制后的图像,并在必要时调用回调函数。
# 如果 save 参数为 False 。
if not save:
# 则不保存图像,而是直接返回绘制后的图像作为 NumPy 数组。 np.asarray(annotator.im) 将 Annotator 对象中的图像转换为 NumPy 数组。
return np.asarray(annotator.im)
# 如果 save 参数为 True ,则保存绘制后的图像到指定的文件路径 fname 。 annotator.im.save(fname) 使用 PIL 的 save 方法保存图像。
annotator.im.save(fname) # save
# 如果提供了回调函数 on_plot ,则调用该函数,并将保存的文件路径 fname 作为参数传递。 这通常用于在图像保存后执行一些额外的操作,例如显示图像或记录日志。
if on_plot:
on_plot(fname)
# 这段代码的主要功能是根据 save 参数的值决定是否保存绘制后的图像,并在必要时调用回调函数。具体步骤如下。检查是否保存图像:如果 save=False ,则返回绘制后的图像作为 NumPy 数组。如果 save=True ,则保存图像到指定的文件路径。调用回调函数:如果提供了回调函数 on_plot ,则调用该函数,并将保存的文件路径作为参数传递。这种方法适用于在图像处理任务中灵活处理图像的保存和后续操作,特别是在需要根据用户需求动态决定是否保存图像时。
# plot_images 函数是一个用于可视化目标检测结果的工具,能够将一批图像及其标注(包括边界框、类别、置信度、关键点和分割掩码)绘制到一个马赛克图像中。该函数支持处理归一化和绝对坐标,根据置信度阈值筛选标注,并在图像上绘制类别标签、边界框、关键点和分割掩码。此外,它还支持保存最终的马赛克图像,并在保存后调用回调函数以执行额外操作。
# def plot_images(images: Union[torch.Tensor, np.ndarray], batch_idx: Union[torch.Tensor, np.ndarray], cls: Union[torch.Tensor, np.ndarray], bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
# confs: Optional[Union[torch.Tensor, np.ndarray]] = None, masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8), kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), paths: Optional[List[str]] = None,
# fname: str = "images.jpg", names: Optional[Dict[int, str]] = None, on_plot: Optional[Callable] = None, max_size: int = 1920, max_subplots: int = 16, save: bool = True, conf_thres: float = 0.25, ) -> Optional[np.ndarray]:
7.def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
# 这段代码定义了一个名为 plot_results 的函数,用于从 CSV 文件中读取数据并生成图表,支持多种任务类型(分类、分割、姿态估计等),并可以将结果保存为图像文件。
# 一个装饰器,用于设置绘图相关的全局配置。
@plt_settings()
# 定义了一个函数 plot_results ,接受以下参数 :
# 1.file :默认值为 "path/to/results.csv" ,指定包含结果数据的 CSV 文件路径。
# 2.dir :默认为空字符串,指定包含结果文件的目录路径。
# 3.segment 、 4.pose 、 5.classify :布尔值参数,分别用于指定任务类型(分割、姿态估计、分类)。这些参数用于确定图表的布局和数据索引。
# 6.on_plot :可选参数,用于在绘图完成后调用的回调函数,接收保存的图像文件路径作为参数。
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
# 从结果 CSV 文件绘制训练结果。该函数支持各种类型的数据,包括分割、姿势估计和分类。绘图将以“results.png”的形式保存在 CSV 所在的目录中。
"""
Plot training results from a results CSV file. The function supports various types of data including segmentation,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
Args:
file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
Defaults to None.
Example:
```python
from ultralytics.utils.plotting import plot_results
plot_results("path/to/results.csv", segment=True)
```
"""
# 导入 Pandas 库,用于读取和处理 CSV 文件中的数据。注释提到这是为了加快 import ultralytics 的速度,但代码中并未使用 ultralytics ,可能是为其他部分预留的优化。
import pandas as pd # scope for faster 'import ultralytics'
# 从 scipy.ndimage 模块导入 gaussian_filter1d 函数,用于对数据进行平滑处理,生成平滑曲线。
from scipy.ndimage import gaussian_filter1d
# 使用 Path (来自 pathlib 模块)确定保存目录。如果指定了 file 参数,则保存目录为该文件的父目录;否则使用 dir 参数作为保存目录。
save_dir = Path(file).parent if file else Path(dir)
# 如果任务类型为 分类 ( classify=True )。
if classify:
# 则创建一个 2×2 的子图布局,总尺寸为 6×6 英寸,并启用紧凑布局。
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
# 定义 index 列表,用于指定 CSV 文件中哪些列的数据需要绘制。
index = [2, 5, 3, 4]
# 如果任务类型为 分割 ( segment=True )。
elif segment:
# 则创建一个 2×8 的子图布局,总尺寸为 18×6 英寸,并启用紧凑布局。
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
# 定义 index 列表,指定需要绘制的 CSV 列。
index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
# 如果任务类型为 姿态估计 ( pose=True )。
elif pose:
# 则创建一个 2×9 的子图布局,总尺寸为 21×6 英寸,并启用紧凑布局。
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
# 定义 index 列表,指定需要绘制的 CSV 列。
index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
# 如果任务类型未指定为分类、分割或姿态估计。
else:
# 则创建一个 2×5 的子图布局,总尺寸为 12×6 英寸,并启用紧凑布局。
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
# 定义 index 列表,指定需要绘制的 CSV 列。
index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
# np.ndarray.ravel()
# ravel() 函数是 NumPy 库中的一个方法,它将多维数组(ndarray)展平成一维数组。这个方法不会复制数组的数据,而是返回一个新的视图(view),这个视图与原始数组共享相同的数据。
# 参数 :
# order :可选参数,指定展平的顺序。默认是 'C',表示按行主序(C-style),即按列优先顺序展平。如果设置为 'F',则表示按列主序(Fortran-style),即按行优先顺序展平。如果设置为 'A',则保持数组的原始顺序。
# 返回值 :
# 返回一个展平后的一维数组。
# 将二维的 ax 数组展平为一维数组,方便后续通过索引访问每个子图。
ax = ax.ravel()
# 使用 glob 方法查找保存目录中所有以 "results" 开头的 CSV 文件,并将结果存储为列表。
files = list(save_dir.glob("results*.csv"))
# 断言检查是否找到了 CSV 文件。如果没有找到,抛出异常并提示用户。
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." # 在 {save_dir.resolve()} 中未找到 results.csv 文件,没有可绘制的内容。
# 遍历找到的 CSV 文件列表。
for f in files:
# 使用 try 块捕获可能的异常,确保即使某个文件处理失败,程序仍能继续处理其他文件。
try:
# 使用 Pandas 读取当前 CSV 文件的内容。
data = pd.read_csv(f)
# 获取 CSV 文件的列名,并去除首尾空格。
s = [x.strip() for x in data.columns]
# 提取 CSV 文件的第一列数据作为横坐标 x 。
x = data.values[:, 0]
# 遍历 index 列表,通过索引 j 获取 需要绘制的列 。
for i, j in enumerate(index):
# 提取第 j 列的数据,并将其转换为浮点类型。
y = data.values[:, j].astype("float")
# 这行代码被注释掉了,原本的作用是将值为 0 的数据替换为 NaN ,从而在绘图时不显示这些值。
# y[y == 0] = np.nan # don't show zero values
# 在第 i 个子图中绘制 原始数据曲线 ,使用点标记( marker="." ),并设置标签为文件名( f.stem )。
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
# 在第 i 个子图中绘制平滑后的曲线,使用虚线( ":" ),平滑通过高斯滤波实现,标准差为 3。
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
# 设置子图的标题为 CSV 文件的第 j 列的列名。
ax[i].set_title(s[j], fontsize=12)
# 这行代码被注释掉了,原本的作用是让某些子图共享 y 轴。
# if j in {8, 9, 10}: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
# 如果在处理某个文件时发生异常,记录警告信息,但不会中断程序。
except Exception as e:
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}") # 警告:{f} 的绘图错误:{e}。
# 在第二个子图中添加图例。
ax[1].legend()
# 定义保存图像文件的路径为保存目录下的 "results.png" 。
fname = save_dir / "results.png"
# 将整个图表保存为 PNG 文件,分辨率为 200 DPI。
fig.savefig(fname, dpi=200)
# 关闭当前图表,释放资源。
plt.close()
# 如果指定了 on_plot 回调函数,则调用该函数,并将保存的图像文件路径作为参数传递。
if on_plot:
on_plot(fname)
# 这段代码实现了一个通用的绘图函数,用于从 CSV 文件中读取数据并生成图表。它支持多种任务类型(分类、分割、姿态估计等),并根据任务类型动态调整图表布局和数据列的索引。代码还提供了平滑曲线的功能,并支持将结果保存为图像文件。此外,它还具备错误处理机制,能够记录警告信息并继续处理其他文件。
8.def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
# 这段代码定义了一个函数 plt_color_scatter ,用于根据数据点的密度绘制彩色散点图。
# 定义了一个函数 plt_color_scatter ,接受以下参数 :
# 1.v 和 2.f :两个一维数组,分别表示散点图的横坐标和纵坐标。
# 3.bins :默认值为 20,表示用于计算二维直方图的分箱数。
# 4.cmap :默认值为 "viridis" ,表示用于散点图的颜色映射表(colormap)。
# 5.alpha :默认值为 0.8,表示散点的透明度。
# 6.edgecolors :默认值为 "none" ,表示散点的边缘颜色。
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
# 绘制基于二维直方图的点着色散点图。
"""
Plots a scatter plot with points colored based on a 2D histogram.
Args:
v (array-like): Values for the x-axis.
f (array-like): Values for the y-axis.
bins (int, optional): Number of bins for the histogram. Defaults to 20.
cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
Examples:
>>> v = np.random.rand(100)
>>> f = np.random.rand(100)
>>> plt_color_scatter(v, f)
"""
# Calculate 2D histogram and corresponding colors
# 使用 np.histogram2d 计算二维直方图。 v 和 f 是输入数据,分别表示 x 和 y 的值。 bins 指定了直方图的分箱数。 返回值 :
# hist :二维直方图的值,表示每个分箱中的点数。
# xedges 和 yedges :分别是 x 和 y 方向的分箱边界。
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
# 计算每个数据点的颜色值。
colors = [
hist[
# np.digitize(v[i], xedges, right=True) 和 np.digitize(f[i], yedges, right=True) 分别计算每个数据点在 x 和 y 方向上的分箱索引。
# 由于分箱索引从 1 开始,因此需要减去 1 转换为从 0 开始的索引。
# 使用 min 函数确保索引不会超出 hist 的范围。
min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
]
# 对于每个数据点,通过索引从 hist 中获取对应的值(即该点所在的分箱中的点数),并将其作为颜色值。
for i in range(len(v))
]
# Scatter plot
# 使用 plt.scatter 绘制散点图。
# v 和 f 分别作为 x 和 y 坐标。
# c=colors :将计算得到的颜色值传递给 c 参数,用于根据密度着色。
# cmap=cmap :指定颜色映射表。
# alpha=alpha :设置散点的透明度。
# edgecolors=edgecolors :设置散点的边缘颜色。
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
# 这段代码实现了一个根据数据点密度着色的散点图绘制函数。它通过计算二维直方图来确定每个数据点所在的分箱,并根据分箱中的点数为每个点分配颜色。颜色值越高,表示该区域的点越密集。这种可视化方式特别适合展示数据的分布密度,同时保持了散点图的直观性。
9.def plot_tune_results(csv_file="tune_results.csv"):
# 这段代码定义了一个函数 plot_tune_results ,用于从 CSV 文件中读取超参数调优结果,并生成两种类型的图表。散点图:展示每个超参数与模型性能(fitness)的关系。性能随迭代次数变化的曲线图:展示模型性能(fitness)随迭代次数的变化趋势。
# 定义了一个函数 plot_tune_results ,接受一个参数。
# 1.csv_file :默认值为 "tune_results.csv" ,表示包含超参数调优结果的 CSV 文件路径。
def plot_tune_results(csv_file="tune_results.csv"):
# 绘制存储在“tune_results.csv”文件中的演化结果。该函数为 CSV 中的每个键生成一个散点图,并根据适应度得分进行颜色编码。性能最佳的配置会在图中突出显示。
"""
Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
Args:
csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
Examples:
>>> plot_tune_results("path/to/tune_results.csv")
"""
# 导入必要的库。
# pandas 用于读取 CSV 文件。
import pandas as pd # scope for faster 'import ultralytics'
# gaussian_filter1d 用于对数据进行平滑处理。
from scipy.ndimage import gaussian_filter1d
# 这段代码定义了一个辅助函数 _save_one_file ,用于保存当前 Matplotlib 图表到指定文件路径,并记录保存信息。
# 定义了一个函数 _save_one_file ,接受一个参数。
# 1.file :表示要保存的文件路径。
def _save_one_file(file):
# 将一个 matplotlib 图保存到“文件”。
"""Save one matplotlib plot to 'file'."""
# 使用 Matplotlib 的 savefig 方法将当前图表保存到指定路径 file 。 参数 dpi=200 指定了保存图像的分辨率,即每英寸 200 像素。较高的 DPI 值可以生成更清晰的图像。
plt.savefig(file, dpi=200)
# 调用 plt.close() 关闭当前图表。这是为了释放绘图资源,避免在绘制多个图表时占用过多内存。如果不关闭当前图表,可能会导致内存泄漏或资源不足的问题。
plt.close()
# 使用 LOGGER (是一个日志记录器)记录一条信息,提示用户图表已成功保存到指定路径。 f"Saved {file}" 是一个格式化字符串,将 file 的值插入到日志消息中。
LOGGER.info(f"Saved {file}")
# 这个函数的主要作用是封装 Matplotlib 图表的保存操作,同时记录保存信息。它的优点包括。代码复用:通过将保存操作封装到一个函数中,可以在需要保存图表时直接调用,避免重复代码。资源管理:通过调用 plt.close() ,确保每次保存后释放绘图资源,避免资源泄漏。日志记录:通过记录保存信息,方便用户跟踪图表的保存路径和状态。在实际使用中,这个函数通常与绘图代码结合,用于保存生成的图表到文件系统中。
# 这段代码的功能是生成超参数调优结果的散点图,每个超参数对应一个子图,展示其与模型性能(fitness)的关系。
# Scatter plots for each hyperparameter
# 将输入的 csv_file 路径转换为 Path 对象(来自 pathlib 模块)。
csv_file = Path(csv_file)
# 使用 Pandas 的 read_csv 方法读取 CSV 文件内容到 data 数据框中。
data = pd.read_csv(csv_file)
# 假设 CSV 文件的第一列是模型性能指标(fitness),其余列是超参数。
num_metrics_columns = 1
# keys 是 超参数的名称列表 ,通过去除第一列后获取,并去除列名中的多余空格。
keys = [x.strip() for x in data.columns][num_metrics_columns:]
# 将数据框 data 转换为 NumPy 数组 x 。
x = data.values
# 提取第一列作为性能指标(fitness)。
fitness = x[:, 0] # fitness
# 使用 np.argmax 找到性能最高的索引 j ,即性能最优的实验配置。
j = np.argmax(fitness) # max fitness index
# 计算 子图的行列数 n ,使得子图数量接近正方形( math.ceil(len(keys) ** 0.5) )。
n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
# 创建一个 10×10 英寸的图表,并启用紧凑布局( tight_layout=True ),以避免子图之间的标签重叠。
plt.figure(figsize=(10, 10), tight_layout=True)
# 遍历每个 超参数 k ,索引为 i 。
for i, k in enumerate(keys):
# 提取当前超参数的值 v (从数组 x 中获取)。
v = x[:, i + num_metrics_columns]
# 找到 性能最高时该超参数的值 mu (即索引为 j 的值)。
mu = v[j] # best single result
# 创建一个子图,位置为 (n, n, i + 1) ,即第 i + 1 个子图。
plt.subplot(n, n, i + 1)
# 调用 plt_color_scatter 函数绘制散点图。 v 是当前超参数的值。 fitness 是模型性能。 使用 viridis 颜色映射表(colormap),透明度为 0.8,散点无边框( edgecolors="none" )。 注意 : plt_color_scatter 是一个自定义函数,用于根据数据点的密度着色。
plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
# 在性能最高点处绘制一个黑色十字标记( "k+" ),标记大小为 15。 这表示当前超参数的最优值及其对应的性能。
plt.plot(mu, fitness.max(), "k+", markersize=15)
# 设置子图标题,显示 超参数名称 k 和 其最优值 mu ,保留三位有效数字( :.3g )。 使用字体大小为 9 的字体。
plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
# 设置 x 轴和 y 轴的标签字体大小为 8。
plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
# 如果当前子图不是第一列,则隐藏 y 轴刻度,以避免重复刻度并保持图表整洁。
if i % n != 0:
plt.yticks([])
# 使用 _save_one_file 函数保存整个图表为 "tune_scatter_plots.png" ,文件名基于原始 CSV 文件路径。 _save_one_file 函数会保存图表到文件,并记录保存信息。
_save_one_file(csv_file.with_name("tune_scatter_plots.png"))
# 这段代码实现了一个用于可视化超参数调优结果的散点图绘制功能。每个超参数对应一个子图,展示了超参数值与模型性能的关系,并标记了最优值。这种可视化方式可以帮助用户快速了解每个超参数对模型性能的影响,以及最优超参数配置。
# 这段代码的功能是绘制模型性能(fitness)随迭代次数变化的曲线图。它展示了原始性能数据以及经过平滑处理后的趋势线。
# Fitness vs iteration
# 创建一个范围为 [1, len(fitness)] 的序列 x ,表示迭代次数。这里假设 fitness 是一个一维数组,其长度等于迭代次数。
x = range(1, len(fitness) + 1)
# 创建一个新的 Matplotlib 图表,大小为 10×6 英寸,并启用紧凑布局( tight_layout=True ),以确保标题和标签不会重叠。
plt.figure(figsize=(10, 6), tight_layout=True)
# 使用 plt.plot 绘制原始性能数据。 x 是迭代次数。 fitness 是每个迭代的性能值。 设置标记为圆圈( marker="o" ),不显示连线( linestyle="none" )。 添加图例标签为 "fitness" 。
plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
# 使用 gaussian_filter1d 对性能数据进行平滑处理,标准差( sigma )设置为 3。 绘制平滑后的曲线: 使用虚线( ":" )表示平滑曲线。 设置线宽为 2( linewidth=2 )。 添加图例标签为 "smoothed" 。
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
# 设置图表标题为 "Fitness vs Iteration" 。
plt.title("Fitness vs Iteration")
# 设置 x 轴标签为 "Iteration" ,表示迭代次数。
plt.xlabel("Iteration")
# 设置 y 轴标签为 "Fitness" ,表示模型性能。
plt.ylabel("Fitness")
# 启用网格( plt.grid(True) ),以便更清晰地观察性能变化。
plt.grid(True)
# 添加图例( plt.legend() ),显示原始性能和平滑曲线的标签。
plt.legend()
# 使用 _save_one_file 函数保存图表为 "tune_fitness.png" ,文件名基于原始 CSV 文件路径。 _save_one_file 函数会保存图表到文件,并记录保存信息。
_save_one_file(csv_file.with_name("tune_fitness.png"))
# 这段代码生成了一个展示模型性能随迭代次数变化的图表,帮助用户观察调优过程中的性能趋势。它通过绘制原始性能数据和经过平滑处理的曲线,提供了以下功能。原始性能曲线:通过圆圈标记展示每个迭代的性能值,直观地显示性能的波动。平滑曲线:通过高斯滤波平滑性能数据,帮助用户观察整体趋势,避免因局部波动而误判。图表美化:通过设置标题、标签、网格和图例,使图表清晰易读。保存功能:将图表保存为图像文件,方便后续查看和报告。这种可视化方式特别适合分析超参数调优过程中的性能变化,帮助用户判断调优是否有效,以及何时达到性能稳定。
# 这段代码实现了一个用于可视化超参数调优结果的函数。它生成两种图表。散点图:展示每个超参数与模型性能的关系,颜色表示密度,最优值用十字标记。性能随迭代次数变化的曲线图:展示模型性能随迭代次数的变化趋势,并提供平滑曲线以观察整体趋势。这种可视化方式可以帮助用户快速了解超参数对模型性能的影响,以及调优过程中的性能变化。
10.def output_to_target(output, max_det=300):
# 这段代码定义了一个函数 output_to_target ,用于将模型的输出转换为目标格式 [batch_id, class_id, x, y, w, h, conf] ,以便用于绘图或其他处理。
# 定义了一个函数 output_to_target ,接受以下参数 :
# 1.output :模型的输出,通常是一个包含多个批次检测结果的张量列表。
# 2.max_det :默认值为 300,表示每个批次中最多保留的检测结果数量。
def output_to_target(output, max_det=300):
# 将模型输出转换为目标格式 [batch_id, class_id, x, y, w, h, conf] 以进行绘图。
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
# 初始化一个空列表 targets ,用于存储转换后的目标格式数据。
targets = []
# 遍历模型的输出 output ,其中 i 是批次索引, o 是每个批次的检测结果。
for i, o in enumerate(output):
# 从每个批次的检测结果 o 中提取前 max_det 个检测结果(最多 300 个)。
# 使用 .split((4, 1, 1), 1) 将检测结果分割为三个部分。
# box :检测框的坐标(4 个值,表示 x1, y1, x2, y2)。
# conf :置信度(1 个值)。
# cls :类别 ID(1 个值)。
# .cpu() 确保张量在 CPU 上,以便后续操作。
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
# 创建一个形状为 (conf.shape[0], 1) 的张量 j ,填充值为当前批次索引 i 。 这个张量将作为目标格式的第一个列( batch_id )。
j = torch.full((conf.shape[0], 1), i)
# 使用 torch.cat 将以下内容合并为一个张量。
# j :批次索引。
# cls :类别 ID。
# ops.xyxy2xywh(box) :将检测框从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式。
# conf :置信度。
# 将合并后的张量添加到 targets 列表中。
targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
# 使用 torch.cat 将 targets 列表中的所有张量合并为一个二维张量。 调用 .numpy() 将张量转换为 NumPy 数组,以便后续处理。
targets = torch.cat(targets, 0).numpy()
# 返回四个部分。 targets[:, 0] 批次索引( batch_id )。 targets[:, 1] 类别 ID( class_id )。 targets[:, 2:-1] 检测框的 (x, y, w, h) 。 targets[:, -1] 置信度( conf )。
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
# 这个函数的主要功能是将模型的输出(通常是检测框、置信度和类别 ID)转换为目标格式 [batch_id, class_id, x, y, w, h, conf] ,以便用于绘图或其他进一步处理。它实现了以下功能。限制检测数量:通过 max_det 参数限制每个批次的检测结果数量。坐标转换:将检测框从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式。添加批次索引:为每个检测结果添加批次索引,以便区分不同批次的结果。返回结构化数据:将结果拆分为批次索引、类别 ID、检测框坐标和置信度,方便后续处理。这种转换方式特别适用于目标检测任务,尤其是在需要将模型输出用于可视化(如绘制检测框)或评估(如计算 mAP)时。
11.def output_to_rotated_target(output, max_det=300):
# 这段代码定义了一个函数 output_to_rotated_target ,用于将模型的输出转换为目标格式 [batch_id, class_id, x, y, w, h, angle, conf] ,其中 angle 是新增的旋转角度信息。这种格式适用于处理旋转目标检测任务的输出,以便用于绘图或其他处理。
# 定义了一个函数 output_to_rotated_target ,接受以下参数 :
# 1.output :模型的输出,通常是一个包含多个批次检测结果的张量列表。
# 2.max_det :默认值为 300,表示每个批次中最多保留的检测结果数量。
def output_to_rotated_target(output, max_det=300):
# 将模型输出转换为目标格式 [batch_id, class_id, x, y, w, h, conf] 以进行绘图。
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
# 初始化一个空列表 targets ,用于存储转换后的目标格式数据。
targets = []
# 遍历模型的输出 output ,其中 i 是批次索引, o 是每个批次的检测结果。
for i, o in enumerate(output):
# 从每个批次的检测结果 o 中提取前 max_det 个检测结果(最多 300 个)。
# 使用 .split((4, 1, 1, 1), 1) 将检测结果分割为四个部分。
# box :检测框的坐标(4 个值,表示 x1, y1, x2, y2)。
# conf :置信度(1 个值)。
# cls :类别 ID(1 个值)。
# angle :旋转角度(1 个值)。
# .cpu() 确保张量在 CPU 上,以便后续操作。
box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
# 创建一个形状为 (conf.shape[0], 1) 的张量 j ,填充值为 当前批次索引 i 。 这个张量将作为目标格式的第一个列( batch_id )。
j = torch.full((conf.shape[0], 1), i)
# 使用 torch.cat 将以下内容合并为一个张量。
# j :批次索引。
# cls :类别 ID。
# box :检测框的坐标( x1, y1, x2, y2 )。
# angle :旋转角度。
# conf :置信度。
# 将合并后的张量添加到 targets 列表中。
targets.append(torch.cat((j, cls, box, angle, conf), 1))
# 使用 torch.cat 将 targets 列表中的所有张量合并为一个二维张量。 调用 .numpy() 将张量转换为 NumPy 数组,以便后续处理。
targets = torch.cat(targets, 0).numpy()
# 返回四个部分。 targets[:, 0] 批次索引( batch_id )。 targets[:, 1] 类别 ID( class_id )。 targets[:, 2:-1] 检测框的坐标和旋转角度( x1, y1, x2, y2, angle )。 targets[:, -1] 置信度( conf )。
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
# 这个函数的主要功能是将模型的输出(通常是检测框、置信度、类别 ID 和旋转角度)转换为目标格式 [batch_id, class_id, x1, y1, x2, y2, angle, conf] ,以便用于绘图或其他进一步处理。它实现了以下功能。限制检测数量:通过 max_det 参数限制每个批次的检测结果数量。添加批次索引:为每个检测结果添加批次索引,以便区分不同批次的结果。返回结构化数据:将结果拆分为批次索引、类别 ID、检测框坐标、旋转角度和置信度,方便后续处理。这种转换方式特别适用于旋转目标检测任务,尤其是在需要将模型输出用于可视化(如绘制旋转检测框)或评估(如计算 mAP)时。
12.def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
# 这段代码定义了一个名为 feature_visualization 的函数,用于可视化神经网络中间层的特征图。
# 定义了一个函数 feature_visualization ,接受以下参数 :
# 1.x :特征图的张量。
# 2.module_type :模块类型(例如网络层的名称)。
# 3.stage :当前阶段(例如网络的第几层)。
# 4.n :要可视化的特征图数量,默认为 32。
# 5.save_dir :保存可视化结果的目录,默认为 Path("runs/detect/exp") 。
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
# 在推理过程中可视化给定模型模块的特征图。
"""
Visualize feature maps of a given model module during inference.
Args:
x (torch.Tensor): Features to be visualized.
module_type (str): Module type.
stage (int): Module stage within the model.
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
"""
# 遍历一个包含模型头部名称的集合(如 Detect 、 Segment 等)。
for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads
# 如果 module_type 中包含这些头部名称之一。
if m in module_type:
# 则直接返回,不进行可视化。这是为了避免对模型的头部层进行可视化。
return
# 检查输入 x 是否为 torch.Tensor 类型。
if isinstance(x, torch.Tensor):
# 如果是,则提取张量的形状,分别获取 批次大小 ( _ )、 通道数 ( channels )、 高度 ( height )和 宽度 ( width )。
_, channels, height, width = x.shape # batch, channels, height, width
# 检查特征图的高度和宽度是否大于 1,以确保特征图具有一定的空间维度。
if height > 1 and width > 1:
# 如果满足条件,则构造保存可视化结果的文件名 f ,格式为 stage{stage}_{module_type}_features.png ,其中 module_type 取最后一个点( . )之后的部分。
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
# 从张量 x 中选择第一个批次( x[0] ),并将其移动到 CPU 上。 使用 torch.chunk 将特征图按通道分割成多个块,每个块包含一个通道的特征图。
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
# 将要可视化的特征图数量 n 限制为通道数 channels 的最小值,以避免超出通道数。
n = min(n, channels) # number of plots
# 使用 plt.subplots 创建一个子图网格,网格的行数为 math.ceil(n / 8) ,列数为 8,以容纳最多 n 个子图。 tight_layout=True 确保子图之间布局紧凑。
_, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
# 将子图数组 ax 展平为一维数组,以便后续可以通过索引访问每个子图。
ax = ax.ravel()
# 调整子图之间的间距,水平间距为 0.05,垂直间距为 0.05。
plt.subplots_adjust(wspace=0.05, hspace=0.05)
# 遍历前 n 个特征图块。
for i in range(n):
# 使用 imshow 将每个特征图显示在对应的子图上。使用 squeeze 去掉多余的维度。
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
# 关闭每个子图的坐标轴。
ax[i].axis("off")
# 使用全局日志记录器 LOGGER 记录保存文件的信息,包括文件名和可视化的特征图数量。
LOGGER.info(f"Saving {f}... ({n}/{channels})") # 正在保存 {f}... ({n}/{channels})。
# 将可视化结果保存为图像文件,分辨率为 300 DPI,边框紧密。
plt.savefig(f, dpi=300, bbox_inches="tight")
# 关闭当前的图像窗口,释放资源。
plt.close()
# 将第一个批次的特征图保存为 NumPy 文件,文件名与图像文件名相同,但扩展名为 .npy 。
np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save
# 这个函数的作用是。过滤特定模块:避免对模型的头部层(如 Detect 、 Segment 等)进行可视化。特征图提取:从输入张量中提取特征图,并按通道分割。可视化:将特征图可视化为图像,并保存到指定目录。日志记录:记录保存的文件信息。NumPy 文件保存:将特征图保存为 NumPy 文件,便于后续分析。这个函数主要用于调试和分析神经网络中间层的特征图,帮助开发者理解模型的行为和特征提取过程。