YOLOv7-0.1部分代码阅读笔记-metrics.py
metrics.py
utils\metrics.py
目录
metrics.py
1.所需的库和模块
2.def fitness(x):
3.def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
4.def compute_ap(recall, precision):
5.class ConfusionMatrix:
6.def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
7.def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
1.所需的库和模块
# Model validation metrics
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from . import general
2.def fitness(x):
# 这段代码定义了一个名为 fitness 的函数,它用于计算模型的适应度(fitness),作为一个加权组合的度量指标。这个函数通常用于机器学习中的进化算法,特别是在超参数优化的过程中,用于评估不同超参数组合的性能。
# 定义一个函数 fitness ,它接受一个参数。
# 1.x :这是一个包含多个度量指标的数组。
def fitness(x):
# 模型适应度作为指标的加权组合。
# Model fitness as a weighted combination of metrics
# 定义一个权重列表 w ,包含四个元素,分别对应于度量指标 [P, R, mAP@0.5, mAP@0.5:0.95] 的权重。
# 这里的 P 和 R 分别代表精确度(Precision)和召回率(Recall),mAP@0.5 是指在 IoU 阈值为 0.5 时的平均精度,mAP@0.5:0.95 是指在 IoU 阈值从 0.5 到 0.95 范围内的平均精度。
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
# 计算适应度值。首先, x[:, :4] 从输入数组 x 中选择所有行的前四列,即对应于度量指标的列。然后,将这些度量指标与相应的权重相乘,最后沿数组的第一个轴(即行)求和,得到每个样本的适应度值。
return (x[:, :4] * w).sum(1)
# 这个函数的返回值是一个数组,其中每个元素代表对应样本的适应度值。在超参数优化中,通常会选择适应度值最高的样本作为最优解。
3.def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
# 这段代码是一个用于计算每个类别的平均精度(Average Precision, AP)的函数,它是目标检测任务中常用的性能评估指标之一。这个函数还包含了绘制精确度-召回率曲线(Precision-Recall curve)和F1分数曲线的功能。
# 定义了一个函数 ap_per_class ,它接受四个主要参数。
# 1.tp (true positives) :真正例。
# 2.conf (confidence scores) :置信度分数。
# 3.pred_cls (predicted classes) :预测类别。
# 4.target_cls (target classes) :目标类别。
# 此外,还有三个可选参数 :
# 5.plot :是否绘图。
# 6.save_dir :保存目录。
# 7.names :类别名称。
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
conf: Objectness value from 0-1 (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
plot: Plot precision-recall curve at mAP@0.5
save_dir: Plot save directory
# Returns
The average precision as computed in py-faster-rcnn.
"""
# Sort by objectness
# 排序。这行代码对置信度分数 conf 进行降序排序,并返回排序后的索引数组 i 。 -conf 确保了置信度最高的预测排在前面。
i = np.argsort(-conf)
# 重新排列。使用上面得到的索引数组 i ,对 真正例 tp 、 置信度分数 conf 和 预测类别 pred_cls 进行重新排列,确保它们按照置信度从高到低的顺序排列。
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
# Find unique classes
# 寻找唯一类别。这行代码找出目标类别 target_cls 中所有唯一的类别,并将它们存储在 unique_classes 数组中。
unique_classes = np.unique(target_cls)
# 计算类别数量。计算唯一类别的数量,即 unique_classes 数组的长度,这个值表示类别的数量。
nc = unique_classes.shape[0] # number of classes, number of detections
# Create Precision-Recall curve and compute AP for each class
# 创建Precision-Recall曲线并计算AP。
# 创建一个从0到1的1000个点的数组 px ,用于绘制Precision-Recall曲线。 py 是一个空列表,用于存储绘制曲线时的Precision值。
px, py = np.linspace(0, 1, 1000), [] # for plotting
# 初始化三个数组,分别用于存储每个类别的 平均精度 ap 、 Precision值 p 和 Recall值 r 。
# 这些数组的形状是 (nc, tp.shape[1]) 和 (nc, 1000) ,其中 nc 是类别的数量, tp.shape[1] 是每个类别的预测数量,1000是用于绘制Precision-Recall曲线的点的数量。
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
# 遍历每个类别。这个循环遍历所有唯一的类别, ci 是类别的索引, c 是类别的值。
for ci, c in enumerate(unique_classes):
# 筛选特定类别的预测。创建一个布尔数组,其中 True 表示预测类别 pred_cls 等于当前类别 c 。
i = pred_cls == c
# 计算标签和预测的数量。
# 计算目标类别 target_cls 中等于当前类别 c 的标签数量。
n_l = (target_cls == c).sum() # number of labels
# 计算预测类别 pred_cls 中等于当前类别 c 的预测数量。
n_p = i.sum() # number of predictions
# 跳过没有预测或标签的类别。
if n_p == 0 or n_l == 0:
# 如果没有预测或标签,跳过当前类别,不进行后续计算。
continue
else:
# numpy.cumsum(a, axis=None, dtype=None, out=None, *, where=True)
# np.cumsum() 是 NumPy 库中的一个函数,用于计算沿指定轴的元素累积和(cumulative sum)。这个函数会返回一个新的数组,其中每个元素是原始数组中该位置及之前所有元素的和。
# 参数说明 :
# a :输入数组。
# axis :沿哪个轴计算累积和。如果为 None ,则数组会被展平后再计算累积和。默认为 None 。
# dtype :输出数组的类型。如果没有指定,则输出数组的类型与输入数组的类型相同。
# out :输出数组。如果指定,则计算结果会被存储在这个数组中。
# where :布尔数组,与 a 形状相同,用于选择性地计算累积和。只有在 a 的相应位置为 True 时,才会在该位置计算累积和。
# 返回值 :
# 累积和数组,与输入数组 a 形状相同,但数据类型可能不同,取决于 dtype 参数。
# Accumulate FPs and TPs
# 累积FP(假正例)和TP(真正例)。
# 计算累积的FP数量。
fpc = (1 - tp[i]).cumsum(0)
# 计算累积的TP数量。
tpc = tp[i].cumsum(0)
# Recall
# 计算Recall(召回率)。
# 计算Recall曲线, 1e-16 是为了防止除以零。
recall = tpc / (n_l + 1e-16) # recall curve
# numpy.interp(x, xp, fp, left=None, right=None, period=None)
# np.interp 是 NumPy 库中的一个函数,用于一维线性插值。给定一组数据点 x 和相应的值 xp ,以及一个新的查询点 x , np.interp 函数会找到 xp 中 x 值所在的区间,并使用线性插值来估计 x 对应的值。
# 参数说明 :
# x :查询点,即你想要插值的点。
# xp :数据点,一个一维数组,包含数据点的横坐标。
# fp : xp 对应的值,一个一维数组,包含数据点的纵坐标。
# left :可选参数,如果 x 中的值小于 xp 中的最小值,则使用这个值作为插值结果。
# right :可选参数,如果 x 中的值大于 xp 中的最大值,则使用这个值作为插值结果。
# period :可选参数,表示周期性,如果指定, xp 将被视为周期性的。
# 返回值 :
# 插值结果,一个与 x 形状相同的数组。
# 使用插值方法,根据置信度分数计算在不同阈值下的Recall值。
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
# Precision
# 计算Precision(精确率)。
# 计算Precision曲线。
precision = tpc / (tpc + fpc) # precision curve
# 使用插值方法,根据置信度分数计算在不同阈值下的Precision值。
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
# AP from recall-precision curve
# 计算AP(平均精度)。
# 遍历每个IoU阈值(交并比)。
for j in range(tp.shape[1]):
# 调用 compute_ap 函数计算当前类别和IoU阈值下的AP值。
# def compute_ap(recall, precision):
# -> 它用于计算给定召回率(recall)和精确率(precision)曲线的平均精度(Average Precision, AP)。返回计算得到的平均精度 ap ,以及 精确率包络 mpre 和 召回率数组 mrec 。
# -> return ap, mpre, mrec
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
if plot and j == 0:
# 如果需要绘图,并且在计算第一个IoU阈值下的AP时,将插值后的Precision值添加到 py 列表中,用于绘制Precision-Recall曲线。
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
# Compute F1 (harmonic mean of precision and recall)
# 计算 F1 分数。它是精确率(precision, p )和召回率(recall, r )的调和平均数。调和平均数的计算公式是 2 * (p * r) / (p + r) 。这里加上 1e-16 是为了避免分母为零的情况。
f1 = 2 * p * r / (p + r + 1e-16)
# 绘制曲线图。如果 plot 参数为 True ,则执行以下绘图操作。
if plot:
# 绘制 Precision-Recall 曲线,并保存为 PNG 文件。 px 和 py 是曲线的 x 和 y 值, ap 是平均精度, Path(save_dir) / 'PR_curve.png' 指定了保存路径和文件名, names 可能是类别名称或其他标签。
# def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): -> 它用于绘制精确率-召回率(Precision-Recall, PR)曲线,并保存为图像文件。
plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
# 绘制 F1 分数曲线,并保存为 PNG 文件。 f1 是计算得到的 F1 分数数组。
# def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'): -> 它用于绘制度量-置信度(Metric-Confidence, MC)曲线,并保存为图像文件。
plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
# 绘制精确率曲线,并保存为 PNG 文件。
plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
# 绘制召回率曲线,并保存为 PNG 文件。
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
# 找到 F1 分数沿第一个轴(通常是类别轴)的平均值最大的索引 i 。
i = f1.mean(0).argmax() # max F1 index
# 返回在索引 i 处的 精确率 、 召回率 、 平均精度 和 F1 分数 ,以及 转换为整数类型的类别索引 。
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
# 这个函数是目标检测模型评估流程的一部分,它可以帮助我们了解模型在不同类别上的性能,并可以通过绘图直观地展示这些性能指标。
4.def compute_ap(recall, precision):
# 这段代码定义了一个名为 compute_ap 的函数,它用于计算给定召回率(recall)和精确率(precision)曲线的平均精度(Average Precision, AP)。平均精度是评估分类模型性能的一个重要指标,特别是在目标检测和信息检索领域。
# 函数接收两个参数。
# 1.recall :召回率数组。
# 2.precision :精确率数组。
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
# Arguments
recall: The recall curve (list)
precision: The precision curve (list)
# Returns
Average precision, precision curve, recall curve
"""
# Append sentinel values to beginning and end
# 添加哨兵值。
# 在召回率曲线的开始和结束添加哨兵值,确保曲线从0开始,到1结束,并且在最后增加一个稍微大于1的值,以确保曲线闭合。
mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
# 在精确率曲线的开始添加1,表示完美的精确率,结束添加0,表示精确率降为0。
mpre = np.concatenate(([1.], precision, [0.]))
# Compute the precision envelope
# np.maximum.accumulate(array, axis=None, dtype=None, out=None)
# np.maximum.accumulate 是 NumPy 库中的一个函数,它用于计算输入数组中元素的累积最大值。这个函数接受一个输入数组,然后返回一个累积最大值数组。
# 1.参数 :
# array :输入的 NumPy 数组。
# axis :指定沿哪个轴计算累积最大值。如果为 None ,则在扁平化后的数组上操作。
# dtype :指定输出数组的数据类型,如果没有指定,则会根据输入数组的数据类型自动决定。
# out :一个可选的输出数组,用于存储结果。
# 2. 返回值 :
# 返回一个新的数组,其中每个元素是输入数组中该位置及之前所有元素的最大值。
# 3. 工作原理 :
# 函数从数组的第一个元素开始,逐个比较元素,并保留到当前位置为止的最大值。
# 对于每个位置,函数比较当前元素和之前累积的最大值,然后更新累积最大值。
# 计算精确率包络。这一行代码计算精确率的累积最大值,用于创建精确率包络。 np.flip 将数组反转, np.maximum.accumulate 计算累积最大值,然后再 np.flip 回来。
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
# Integrate area under curve
# 积分曲线下面积。
# 选择积分方法,这里使用插值方法。
method = 'interp' # methods: 'continuous', 'interp'
# 如果选择插值方法( 'interp' )。
if method == 'interp':
# 创建一个从0到1的101个点的数组,用于插值。
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
# np.trapz(y, x=None, dx=1.0, axis=-1)
# np.trapz 是 NumPy 库中的一个函数,它使用复合梯形规则来计算数值积分。这个函数可以近似地计算定积分,即曲线下的面积。
# 1.参数 :
# y :数组,表示被积函数的值。
# x :数组,可选参数,表示自变量的取值。如果 x 为 None ,则假设样本点均匀分布,间距为 dx 。
# dx :标量,可选参数,当 x 为 None 时,样本点之间的间距。
# axis :整数,可选参数,指定沿哪个轴进行积分。
# 2. 返回值 :
# 返回由梯形规则近似的定积分值。如果 y 是一维数组,则结果为浮点数。如果 y 是多维数组,则结果是一个 n-1 维数组。
# 3. 工作原理 :
# 函数通过将曲线下的区域划分为多个梯形,并计算每个梯形的面积来近似积分。这种方法称为梯形规则,是一种简单的数值积分方法。
# 使用梯形法则( np.trapz )来积分插值后的精确率曲线。
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
# 如果选择连续方法( 'continuous' )。
else: # 'continuous'
# 找到召回率变化的点。
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
# 计算这些点构成的梯形面积之和,即曲线下的面积。
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
# 返回计算得到的平均精度 ap ,以及 精确率包络 mpre 和 召回率数组 mrec 。
return ap, mpre, mrec
# 这个函数通过计算召回率和精确率曲线下的面积来评估模型的平均精度。这是衡量模型在不同阈值下性能的一种方法,特别是在目标检测任务中,它可以帮助我们理解模型在不同置信度水平上的表现。
5.class ConfusionMatrix:
# 这段代码定义了一个名为 ConfusionMatrix 的类,用于在目标检测任务中生成混淆矩阵。混淆矩阵是一个表格,用于描述分类模型的性能,特别是模型预测的类别与真实类别之间的关系。
class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
# 构造函数 __init__ 。
# 1.self :类的实例。
# 2.nc :一个整数,表示类别的数量。
# 3.conf :一个浮点数,表示置信度阈值,默认值为0.25。只有当模型预测的置信度高于这个阈值时,预测才会被考虑。
# 4.iou_thres :一个浮点数,表示交并比(Intersection over Union, IoU)阈值,默认值为0.45。只有当预测框与真实框的IoU高于这个阈值时,预测才被认为是正确的。
def __init__(self, nc, conf=0.25, iou_thres=0.45):
# 一个二维数组,用于存储混淆矩阵。数组的大小为 (nc + 1) x (nc + 1) ,其中 nc 是类别的数量。额外的一行和一列用于存储背景类(通常表示没有检测到任何对象的情况)。
self.matrix = np.zeros((nc + 1, nc + 1))
# 存储传入的类别数量 nc 。
self.nc = nc # number of classes
# 存储置信度阈值 conf 。
self.conf = conf
# 存储IoU阈值 iou_thres 。
self.iou_thres = iou_thres
# 这段代码是 ConfusionMatrix 类的一个方法,名为 process_batch ,它用于处理一批目标检测的结果,并更新混淆矩阵。
# 这个方法接收两个参数。
# detections :模型的检测结果。一个形状为 [N, 6] 的数组,其中每一行包含一个检测结果,格式为 x1, y1, x2, y2, confidence, class 。
# labels :代表模型的真实标签。一个形状为 [M, 5] 的数组,其中每一行包含一个真实标签,格式为 class, x1, y1, x2, y2 。方法逻辑:
def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
"""
# 筛选置信度高于阈值的检测结果。只保留置信度大于 self.conf 的检测结果。
detections = detections[detections[:, 4] > self.conf]
# 提取类别信息。
# 提取真实标签的类别信息。
gt_classes = labels[:, 0].int()
# 提取检测结果的类别信息。
detection_classes = detections[:, 5].int()
# 计算交并比(IoU)。计算真实标签和检测结果的IoU。
# def box_iou(box1, box2):
# -> 用于计算两个边界框集合之间的交并比(Intersection over Union, IoU)。这个函数返回一个形状为 (N, M) 的张量,其中包含了 box1 中每个边界框与 box2 中每个边界框之间的 IoU 值。
# -> return inter / (area1[:, None] + area2 - inter)
iou = general.box_iou(labels[:, 1:], detections[:, :4])
# torch.where(condition, x, y)
# torch.where 是 PyTorch 库中的一个函数,它返回一个新的张量,其中元素来自输入张量的元素,这些元素满足给定的条件。这个函数类似于 NumPy 中的 np.where 函数。
# 参数说明 :
# condition :一个布尔张量,指定条件。只有当 condition 中的元素为 True 时, x 中对应的元素才会被选中;否则, y 中对应的元素会被选中。
# x :一个张量,当 condition 为 True 时, x 中的元素会被返回。
# y :一个张量,当 condition 为 False 时, y 中的元素会被返回。
# 如果 x 和 y 是张量,它们必须具有相同的形状,或者能够广播到相同的形状。如果 x 和 y 是标量,则它们会被广播为与 condition 相同的形状。
# 返回值 :
# 一个张量,包含根据 condition 选择的 x 和 y 的元素。
# 找到IoU高于阈值的匹配。找到IoU大于 self.iou_thres 的索引。
x = torch.where(iou > self.iou_thres)
# 处理匹配结果。如果存在多个匹配,按照IoU降序排序,并去除重复的匹配。
if x[0].shape[0]:
# 将匹配的 索引 和 IoU值 合并为一个数组,并转换为NumPy数组。
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
# 按照IoU值降序排序。
matches = matches[matches[:, 2].argsort()[::-1]]
# 去除重复的ground truth匹配。
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# 再次按照IoU值降序排序。
matches = matches[matches[:, 2].argsort()[::-1]]
# 去除重复的检测结果匹配。
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
# 处理没有匹配的情况。
else:
# 如果没有匹配, matches 将是一个形状为 (0, 3) 的零数组。
matches = np.zeros((0, 3))
# 检查是否有匹配。检查 matches 数组是否非空,即是否有至少一个匹配。
n = matches.shape[0] > 0
# 提取匹配信息。从 matches 数组中提取匹配的索引和IoU值,并将它们转置和转换为 np.int16 类型。 m0 代表真实标签的索引, m1 代表检测结果的索引。
m0, m1, _ = matches.transpose().astype(np.int16)
# 更新正确匹配。
# 遍历每个真实标签的类别 gc (ground truth classes)。
for i, gc in enumerate(gt_classes):
# 找到与当前真实标签索引相匹配的检测结果。
j = m0 == i
# 如果 n 为真且 sum(j) 为1,说明有一个且只有一个检测结果与当前真实标签匹配。
if n and sum(j) == 1:
# 在混淆矩阵中增加正确匹配的数量。
self.matrix[gc, detection_classes[m1[j]]] += 1 # correct
# 更新背景假正例。
# 如果没有检测结果与某个真实标签匹配,增加背景假正例的数量。
else:
# 在混淆矩阵中增加背景假正例的数量。
self.matrix[self.nc, gc] += 1 # background FP
# 更新背景假负例。
# 如果 n 为真,遍历每个检测结果的类别 dc 。
if n:
for i, dc in enumerate(detection_classes):
# 检查是否有任何真实标签与当前检测结果匹配。
if not any(m1 == i):
# 如果没有,增加背景假负例的数量。在混淆矩阵中增加背景假负例的数量。
self.matrix[dc, self.nc] += 1 # background FN
# 这段代码定义了一个名为 matrix 的方法,它是 ConfusionMatrix 类的一个成员函数。这个方法非常简单,它没有参数,并且返回类实例中的 self.matrix 属性。
# 定义了一个名为 matrix 的方法,它接受 1.self 作为参数, self 代表类的实例本身。
def matrix(self):
# 方法返回 ConfusionMatrix 类实例中的 matrix 属性,这是一个二维数组,存储了混淆矩阵的值。
return self.matrix
# 假设你已经创建了一个 ConfusionMatrix 类的实例,并且已经通过调用 process_batch 方法更新了混淆矩阵,你可以使用 matrix 方法来获取当前的混淆矩阵。
# 这段代码是 ConfusionMatrix 类的一个方法,名为 plot ,它用于绘制并保存归一化的混淆矩阵的热图。
# 1.self :类的实例。
# 2.save_dir :一个字符串,指定保存图像的目录。默认为空字符串,意味着图像将保存在当前工作目录。
# 3.names :一个元组,包含类别名称,用于热图的 x 轴和 y 轴标签。
def plot(self, save_dir='', names=()):
try:
# 导入依赖库。尝试导入 seaborn 库,这是一个基于 Matplotlib 的统计图表库,常用于绘制热图。
import seaborn as sn
# 归一化混淆矩阵。将混淆矩阵的每一行除以其行和,以归一化矩阵。添加 1E-6 防止除以零。
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
# 处理低值。将小于 0.005 的值设置为 NaN,这些值在热图中不会显示注释。
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
# 创建绘图。创建一个图形对象,并设置图形大小。
fig = plt.figure(figsize=(12, 9), tight_layout=True)
# 在 Seaborn 库中, sn.set() 函数已经被弃用。取而代之的是 sns.set_theme() 或者 plt.rcParams.update() 。如果你想要设置 Seaborn 的全局样式,可以使用 sns.set_theme() ,并且可以通过传递参数来调整字体大小等样式设置。
# 以下是如何修改 sn.set(font_scale=1.0 if self.nc < 50 else 0.8) 代码的示例:
# import seaborn as sns
# # 使用 sns.set_theme() 设置 Seaborn 的全局样式
# if self.nc < 50:
# sns.set_theme(style="white", font_scale=1.0)
# else:
# sns.set_theme(style="white", font_scale=0.8)
# 或者,如果你想要更细粒度地控制样式,可以直接更新 Matplotlib 的配置参数:
# import matplotlib.pyplot as plt
# # 使用 plt.rcParams.update() 更新 Matplotlib 的配置参数
# if self.nc < 50:
# plt.rcParams.update({'font.size': 12}) # 假设 1.0 对应的字体大小是 12
# else:
# plt.rcParams.update({'font.size': 8}) # 假设 0.8 对应的字体大小是 8
# 设置 Seaborn 参数 根据类别数量调整字体大小。
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
# seaborn.heatmap(data, mask=None, cmap=None, annot=False, fmt=".2f", annot_kws={}, cbar=True, cbar_kws={}, cbar_ax=None, square=False, xticklabels=False, yticklabels=False, **kwargs)
# sns.heatmap 是 Seaborn 库中的一个函数,用于绘制热图(heatmap),它基于 Matplotlib 进行绘图。热图是一种数据可视化技术,用于显示矩阵数据,其中矩阵的值用颜色编码。
# 参数说明 :
# data :要绘制的数据,通常是一个二维数组或类似矩阵的数据结构。
# mask :一个与 data 形状相同的数组,用于隐藏热图中的某些值。
# cmap :一个颜色映射对象或颜色映射名称,用于指定热图中的颜色。
# annot :布尔值或数组,指定是否在热图的每个单元格中显示数值注释。
# fmt :字符串,指定注释的格式,例如 ".2f" 表示浮点数保留两位小数。
# annot_kws :一个字典,包含传递给注释文本的关键字参数,例如字体大小。
# cbar :布尔值,指定是否显示颜色条。
# cbar_kws :一个字典,包含传递给颜色条的关键字参数。
# cbar_ax :一个轴对象,用于指定颜色条的位置。
# square :布尔值,指定是否将热图的每个单元格绘制为正方形。
# xticklabels :布尔值或序列,指定是否显示 x 轴的刻度标签,或指定刻度标签。
# yticklabels :布尔值或序列,指定是否显示 y 轴的刻度标签,或指定刻度标签。
# **kwargs :其他关键字参数,用于传递给 Axes 对象。
# 返回值 :
# 返回一个 Axes 对象,代表绘制热图的轴。
# Axes.set_facecolor(color)
# set_facecolor 是 Matplotlib 中 Axes 对象的一个方法,用于设置特定 Axes (绘图区域)的背景颜色。
# 参数说明 :
# color :一个颜色值,可以是颜色名称(如 'red' )、RGB 元组(如 (1, 0, 0) )、RGBA 元组(如 (1, 0, 0, 0.5) )或十六进制颜色代码(如 '#FF0000' )。
# 返回值 :无返回值,该方法直接修改 Axes 对象的背景颜色。
# 绘制热图。使用 sn.heatmap 绘制归一化的混淆矩阵热图。
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
# 设置轴标签。设置 x 轴和 y 轴的标签为 'True' 和 'Predicted'。
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
# 保存图像。使用 fig.savefig 将热图保存为 PNG 文件,文件名为 'confusion_matrix.png',指定 dpi 为 250。
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
except Exception as e:
# 异常处理。如果在绘图过程中发生任何异常,异常将被捕捉,但不会进行任何处理( pass )。
pass
# 这段代码定义了 ConfusionMatrix 类的一个方法,名为 print ,用于打印当前的混淆矩阵。这个方法遍历混淆矩阵的每一行,并将其打印出来。
# 定义了一个名为 print 的方法,它接受 1.self 作为参数, self 代表类的实例本身。
def print(self):
# 使用一个 for 循环遍历混淆矩阵的所有行。 self.nc 是类别的数量,因此 self.nc + 1 包括了所有类别 加上背景类 (如果没有检测到任何对象的情况)。
for i in range(self.nc + 1):
# 对于每一行,使用 map(str, self.matrix[i]) 将该行的所有元素转换为字符串,然后使用 join 方法将它们连接成一个由空格分隔的字符串,并打印出来。
print(' '.join(map(str, self.matrix[i])))
# 假设你已经创建了一个 ConfusionMatrix 类的实例,并且已经通过调用 process_batch 方法更新了混淆矩阵,你可以使用 print 方法来打印当前的混淆矩阵
6.def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
# Plots ----------------------------------------------------------------------------------------------------------------
# 这段代码定义了一个名为 plot_pr_curve 的函数,它用于绘制精确率-召回率(Precision-Recall, PR)曲线,并保存为图像文件。这个函数是目标检测和分类任务中常用的工具,用于可视化模型的性能。
# 1.px :召回率的值,通常是一个从0到1的数组。
# 2.py :对应的精确率值,可以是单一组精确率值,也可以是每类的精确率值的数组。 ap 参数应该是一个二维数组,其中每一行对应一个类别,第一列是该类别的AP值。
# 3.ap :平均精度(Average Precision),可以是单值或每类的AP值的数组。
# 4.save_dir :保存图像的路径,默认为 'pr_curve.png'。
# 5.names :类别名称的元组,用于图例标签。
def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
# Precision-recall curve
# 创建绘图环境。使用 plt.subplots 创建一个图形和轴对象,设置图形大小为9x6英寸,并启用紧凑布局。
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
# 堆叠精确率值。使用 np.stack 将 py 堆叠,以便可以处理多组数据。
py = np.stack(py, axis=1)
# 绘制PR曲线。
# 如果 names 元组中的类别数量在1到20之间,为每个类别绘制PR曲线,并添加图例。
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
# matplotlib.pyplot.plot(*args, scalex=True, scaley=True, data=None, **kwargs)
# plt.plot 是 Matplotlib 库中的一个函数,用于绘制二维图形中的线或标记。这个函数非常灵活,可以用于绘制简单的线条图,也可以通过各种参数定制复杂的图表。
# 参数说明 :
# *args :可变数量的位置参数,用于指定要绘制的数据。通常,第一个参数是x坐标,第二个参数是y坐标。也可以传递多个参数,用于绘制多条线。
# scalex 和 scaley :布尔值,指定是否自动缩放x轴和y轴。默认为 True 。
# data :一个可选的关键字参数,用于指定传递给绘图函数的数据结构。
# **kwargs :关键字参数,用于定制线条的样式、颜色、标记等。 kwargs 可以接受多种参数,以下是一些常用的 :
# x :x坐标数据。
# y :y坐标数据。
# fmt :字符串或字符串序列,指定绘图的格式(如 'b-' 表示蓝色实线)。
# color 或 c :颜色代码或名称。
# marker :标记类型(如 'o' 表示圆圈)。
# linestyle 或 ls :线型(如 '--' 表示虚线)。
# linewidth 或 lw :线宽。
# markersize 或 ms :标记大小。
# label :图例标签。
# 返回值 :
# 一个包含 Line2D 对象的列表,这些对象代表图表中绘制的线。
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
# 如果类别数量不在这个范围内,则不显示每个类别的图例,只绘制所有类别的PR曲线。
else:
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
# 绘制平均PR曲线。
# 计算所有类别的平均精确率,并绘制平均PR曲线,设置线宽为3,颜色为蓝色,并添加标签显示所有类别的平均精度(mAP@0.5)。
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
# matplotlib.axes.Axes.set_xlabel(label, fontdict=None, labelpad=None, **kwargs)
# set_xlabel 是 Matplotlib 库中 Axes 对象的一个方法,用于设置图表的 x 轴标签。
# 参数说明 :
# label :一个字符串,表示要设置的 x 轴标签文本。
# fontdict :一个字典,用于指定标签的字体属性,如大小、重量、颜色等。
# labelpad :一个浮点数,表示标签与轴之间的距离。
# **kwargs :其他关键字参数,可以用于设置字体大小、颜色等属性。
# 返回值 :无返回值,该方法直接修改 Axes 对象的属性。
# 设置坐标轴标签和范围。
# 设置x轴标签为 'Recall',y轴标签为 'Precision'。
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
# matplotlib.axes.Axes.set_xlim(xmin=None, xmax=None, emit=True, auto=False, **kwargs)
# set_xlim 是 Matplotlib 库中 Axes 对象的一个方法,用于设置图表的 x 轴显示范围。
# 参数说明 :
# xmin :一个数值,指定 x 轴的最小值。如果不设置(或为 None ),则自动确定最小值。
# xmax :一个数值,指定 x 轴的最大值。如果不设置(或为 None ),则自动确定最大值。
# emit :一个布尔值,指定是否发出 xlim_changed 事件。默认为 True 。
# auto :一个布尔值,指定是否自动调整 x 轴的显示范围。默认为 False 。如果设置为 True ,则 Matplotlib 会根据数据自动调整 x 轴的范围。
# **kwargs :其他关键字参数,用于传递给定位器(locator)对象,例如设置次刻度的数量。
# 返回值 :无返回值,该方法直接修改 Axes 对象的属性。
# 设置x轴和y轴的范围都是从0到1。
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
# matplotlib.pyplot.legend(*args, **kwargs)
# plt.legend 是 Matplotlib 库中的一个函数,用于在图表中添加图例(legend),以标识不同的数据系列。
# 参数说明 :
# *args :可变参数,通常用于传递一个或多个 Axes 或 Patch 对象,这些对象将被包含在图例中。
# **kwargs :关键字参数,用于定制图例的外观和行为。 kwargs 可以接受多种参数,以下是一些常用的 :
# loc :字符串或数字,指定图例的位置。例如, 'upper right' 、 'lower left' 、 1 (最佳位置)等。
# bbox_to_anchor :元组,用于指定图例的确切位置,例如 (1.05, 1) 。
# labels :标签列表,用于覆盖默认的图例标签。
# title :字符串,图例的标题。
# fontsize 或 fontproperties :用于设置图例中字体的大小或属性。
# frameon :布尔值,指定是否显示图例边框。
# shadow :布尔值,指定图例是否显示阴影。
# fancybox :布尔值,指定图例边框是否为圆角。
# ncol :整数,指定图例中的列数。
# markerscale :浮点数,用于调整图例中标记的大小。
# handlelength 、 handletextpad 、 borderpad 、 labelspacing :用于调整图例中手柄(线条或标记)的长度、文本间距、边框间距和标签间距。
# 返回值 :
# 返回一个 Legend 对象,可以用于进一步定制图例。
# 添加图例。使用 plt.legend 添加图例,设置位置在图形的右上角。
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
# matplotlib.figure.Figure.savefig(fname, dpi=None, facecolor='w', edgecolor='w', orientation='portrait', papertype=None, format=None, bbox_inches=None, pad_inches=0.1, metadata=None, **kwargs)
# savefig 是 Matplotlib 库中 Figure 对象的一个方法,用于将当前的图形保存为文件。
# 参数说明 :
# fname :字符串或文件对象,指定保存文件的名称或路径。可以包括文件扩展名,如 'image.png' 或 'figure.pdf'。
# dpi :数字,指定图像的分辨率,以点每英寸(DPI)为单位。如果为 None ,则使用 Figure 对象的 dpi 属性值。
# facecolor :字符串或颜色代码,指定图像的背景颜色。
# edgecolor :字符串或颜色代码,指定图像边缘的颜色。
# orientation :字符串,指定页面方向,可以是 'portrait' 或 'landscape'。
# papertype :字符串,指定页面大小类型,如 'letter'、'A4' 等。需要与后端支持的页面大小兼容。
# format :字符串,指定文件格式。如果为 None ,则从文件扩展名推断。
# bbox_inches :字符串或 Bbox 对象,指定哪些内容应该被裁剪或包含在保存的文件中。
# pad_inches :数字,指定保存文件时在图像边缘的填充量。
# metadata :字典,包含额外的元数据,这些元数据将被保存在文件中。
# **kwargs :其他关键字参数,用于后端特定的文件保存选项。
# 返回值 :无返回值,该方法直接将图形保存为文件。
# 保存图像。使用 fig.savefig 将绘制的PR曲线保存为图像文件,指定路径和dpi(点每英寸)。
fig.savefig(Path(save_dir), dpi=250)
7.def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
# 这段代码定义了一个名为 plot_mc_curve 的函数,它用于绘制度量-置信度(Metric-Confidence, MC)曲线,并保存为图像文件。这个函数使用 matplotlib 库进行绘图,通常用于可视化模型在不同置信度水平下的性能度量,如精确度、召回率或 F1 分数等。
# 1.px :一个数组,包含置信度的值,通常范围从0到1。
# 2.py :一个数组或列表,包含对应的度量值,可以是每个类别的度量值。
# 3.save_dir :保存图像文件的路径,默认为 'mc_curve.png'。
# 4.names :一个元组,包含类别名称,用于图例标签。
# 5.xlabel :x轴的标签,默认为 'Confidence'。
# 6.ylabel :y轴的标签,默认为 'Metric'。
def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
# 创建绘图环境。使用 plt.subplots 创建一个图形和轴对象,设置图形大小为9x6英寸,并启用紧凑布局。
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
# 绘制MC曲线。
# 如果 names 元组中的类别数量在1到20之间,为每个类别绘制MC曲线,并添加图例。
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
else:
# 如果类别数量不在这个范围内,则不显示每个类别的图例,只绘制所有类别的MC曲线。
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
# 绘制平均MC曲线。
# 计算所有类别的平均度量,并绘制平均MC曲线,设置线宽为3,颜色为蓝色,并添加标签显示所有类别的最大度量值及其对应的置信度。
y = py.mean(0)
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
# 设置坐标轴标签和范围。
# 使用 ax.set_xlabel 和 ax.set_ylabel 设置x轴和y轴的标签。
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
# 设置x轴和y轴的范围都是从0到1。
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
# 添加图例。使用 plt.legend 添加图例,设置位置在图形的右上角。
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
# 保存图像。使用 fig.savefig 将绘制的MC曲线保存为图像文件,指定路径和dpi(点每英寸)。
fig.savefig(Path(save_dir), dpi=250)