当前位置: 首页 > article >正文

ML基础——分类模型的评估指标

以鸢尾花的分类分析为例:

首先我们对于3类分类做了个数字编码,012对应3个分类,

所以这就是3个分类下的预测结果指标:
分类为0的准确率、召回率、F1-score,以及support

因为前面我们数据集划分为了105 train:45 test,所以右边这里的support是每个类别都有多少样本,比如说0的support是17,就是说这45样本的test集中,实际标签为0的样本有17个。

F1-score是precision和recall的调和平均。
最后三行是各个指标计算结果的平均值,宏平均和加权平均(最后一列除外);

accuracy不用说了,看下文部分的定义,可以通过相关模块进行评估;

至于macro或weighted等指标计算,其实很简单,

可以参考:
https://mp.weixin.qq.com/s/2Sw_YpDZWGEi4bA2R0VsVw

解释如下:

一,分类:准确率、召回率、精确率和相关指标

真正例、假正例和假负例用于计算评估模型的几个实用指标。哪些评估指标最有意义,取决于具体模型和具体任务、不同错误分类的代价,以及数据集是平衡的还是不平衡的。

本部分中的所有指标均基于单个固定阈值计算得出,并且会随阈值的变化而变化。很多时候,用户会调整阈值以优化其中某个指标。

参考https://developers.google.com/machine-learning/crash-course/classification/accuracy-precision-recall?hl=zh-cn

1,准确率(Accuracy)

准确性是指所有分类(无论是正类还是负类)正确分类的比例。其数学定义为:

在垃圾邮件分类示例中,准确度衡量的是所有电子邮件正确分类所占的比例

完美的模型没有假正例和假负例,因此准确率为 1.0,即 100%。

由于精度包含混淆矩阵中的所有四种结果(TP、FP、TN、FN),因此在数据集平衡且两个类别中的示例数量相近的情况下,精度可以用作衡量模型质量的粗略指标。因此,它通常是执行通用或未指定任务的通用或未指定模型使用的默认评估指标。

不过,如果数据集不平衡,或者一种错误(假负例或假正例)的代价高于另一种错误(大多数实际应用中都是如此),则最好改为针对其他指标进行优化。

对于严重不均衡的数据集(其中一个类别出现的频率非常低,例如 1%),如果模型 100% 都预测为负类,则准确率得分为 99%,尽管该模型毫无用处。

注意:在机器学习 (ML) 中,recallprecisionaccuracy 等字词的数学定义可能与这些字词的常用含义不同,或更具体。

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(Y_test, pre_res)
print(f'Accuracy: {accuracy}')

2,召回率或真正例率/真阳性率(Recall)

真正例率 (TPR,真阳性),即所有实际正例(正确分类的就是TP,错误分类的就是FN)正确分类为正例的比例,也称为召回率。

在数学上,召回率的定义为:

假负例(假阴性)是指被误分类为负例的实际正例,因此会出现在分母中。在垃圾邮件分类示例中,召回率衡量的是被正确分类为垃圾邮件的垃圾邮件电子邮件的比例。因此,召回率的另一个名称是检测概率:它可以回答“此模型检测到垃圾邮件的比例是多少?”这一问题。

假设一个完美的模型,其假负例率(假阴性)为 0,因此召回率 (TPR) 为 1.0,也就是说,检测率为 100%。

在实际正例数量非常少(例如总共 1-2 个样本)的不均衡数据集中,召回率作为指标的意义不大,作用不大。

from sklearn.metrics import recall_score

recall = recall_score(Y_test, pre_res)
print(f'Recall: {recall}')

和下面的precision一样,对于二分类可以像上面那样处理,但是多分类的话需要加其他参数;

from sklearn.metrics import recall_score

recall_micro = recall_score(Y_test, pre_res, average='micro')
print(f'Recall: {recall_micro}')

recall_macro = recall_score(Y_test, pre_res, average='macro')
print(f'Recall: {recall_macro}')

recall_weighted = recall_score(Y_test, pre_res, average='weighted')
print(f'Recall: {recall_weighted}')

3,假正例率/假阳性率

假正例率 (FPR) 是指被错误地归类为正例的在所有实际负例所占的比例,在所有实际负样本中假阳性所占的比例,也称为误报概率。其数学定义为:

假正例是被错误分类的实际负例,因此会出现在分母中。在垃圾邮件分类示例中,FPR 用于衡量被错误分类为垃圾邮件的合法电子邮件的比例,或模型的误报率。

完美的模型没有假正例,因此 FPR 为 0.0,也就是说,假警报率为 0%。

在实际负例数量非常少(例如总共 1-2 个示例)的不平衡数据集中,FPR 作为一个指标就没有那么有意义和实用。

4,精确率(Precision)

精确率是指模型所有正分类分类中实际正分类的比例,所有被分成阳性的样本中真阳性的比例。在数学上,其定义为:

在垃圾邮件分类示例中,精确率衡量的是被归类为垃圾邮件的电子邮件中实际是垃圾邮件的比例

假设有一个完美的模型,则该模型没有假正例(假阳性率),因此精确率为 1.0。

在实际正例数量非常少(例如总共 1-2 个示例)的不平衡数据集中,精确率作为指标的意义和实用性较低。

from sklearn.metrics import precision_score
precision = precision_score(Y_test, pre_res)
print(f'Precision: {precision}')

如果是2分类可以参照上面的方法进行操作,否则:

多分类问题:

precision_micro = precision_score(Y_test,pre_res,average='micro')
print(f'Precision (micro): {precision_micro}')

precision_macro = precision_score(Y_test, pre_res, average='macro')
print(f'Precision (macro): {precision_macro}')

precision_weighted = precision_score(Y_test, pre_res, average='weighted')
print(f'Precision (weighted): {precision_weighted}')

5,关于精确率precision和召回率recall:

随着假正例的减少,精确率会提高;随着假负例的减少,召回率会提高。

但正如前面所述,提高分类阈值往往会减少假正例的数量并增加假负例的数量,而降低阈值则会产生相反的效果。

所以,改变分类阈值——》影响FP、FN数目——》共轭影响precision和recall,且两个是呈反函数关系

因此,精确率和召回率通常呈反函数关系,其中一个提高improve会恶化另一个worsen

(1)可视化总结

(2)指标的选择和权衡

在评估模型和选择阈值时,您选择优先处理的指标取决于特定问题的费用、收益和风险。在垃圾邮件分类示例中,通常最好优先考虑召回率(抓取所有垃圾邮件)或准确率(尝试确保被标记为垃圾邮件的电子邮件实际上是垃圾邮件),或者在达到某个最低准确性水平的情况下,兼顾这两者。

指标指南
准确率作为平衡数据集的模型训练进度/收敛情况的粗略指标。
对于模型效果,请仅与其他指标搭配使用。
避免使用不平衡的数据集。考虑使用其他指标。
召回率 (真正例率)在假负例比假正例开销更高时使用。
假正例率在假正例比假负例开销更高时使用。
精确率当正例预测的准确性非常重要时,请使用此方法。

(3)F1-score

调和平均值大家中学学均值不等式应该都了解,上链接:
https://zh.wikipedia.org/wiki/%E8%B0%83%E5%92%8C%E5%B9%B3%E5%9D%87%E6%95%B0

所以F1-score就是recall和precision的二数的调和平均数,也就是鱼与熊掌兼得

同样的,需要区分是2分类还是多分类任务:否则会报错

from sklearn.metrics import f1_score
f1 = f1_score(Y_test, pre_res, average='macro')
print("F1-score:", f1)

from sklearn.metrics import f1_score

f1_micro = f1_score(Y_test, pre_res, average='micro')
print("F1-score:", f1_micro)

f1_macro = f1_score(Y_test, pre_res, average='macro')
print("F1-score:", f1_macro)

f1_weighted = f1_score(Y_test, pre_res, average='weighted')
print("F1-score:", f1_weighted)

一些练习用来掌握回顾:

二,wiki中的分类模型评估指标扩展

1,ROC曲线

完美的模型是没有Fx率的(也就是没有F开头的比率的),所以套进去公式看看,加上一点简单的数学

参考https://zh.wikipedia.org/wiki/ROC%E6%9B%B2%E7%BA%BF

受试者操作特征曲线,ROC曲线

这里就不得不提混淆矩阵了,基本上本科低年级学过数理统计的都应该倒背如流相应的概念

ROC曲线:假阳x,真阳y

x轴是假阳性率(FPR,假阳,那就是从阴性样本来考虑——》所有实际为阴性的样本中,被错误归类为阳性的样本FP比例)

y轴是真阳性率(TPR,真阳,那就是从阳性样本来考虑——》所有实际为阳性的样本中,被正确归类为阳性的样本TP比例)

然后在训练过程中,模型会迭代,

对于过程中的某一个模型+某一个阈值,我们都能计算出它的混淆矩阵等所有分类评估指标,那自然就能算出x轴以及y轴代表的模型指标,也就是说每个过程中的模型都能计算出1个坐标点(x,y);

然后完美的模型是没有F开头的比率的(就是完美的模型是不会产出假阳性假阴性假xxx的),所以结合x轴以及y轴的定义:

完美的模型就是TPR(y轴)=1,FPR(x轴)=0,也就是完美的模型的结果坐标点是(0,1),也就是左上角的点

然后非常差的模型就是没有T开头的比例,就是没有真阳真阴,都是假阳假阴,只有F开头的比例,

那就是(1,0),就是右下角;

如果是随机的预测,那就是55开,也就是说混淆矩阵中每一列(注意是列)下面的值应该是55开,

那再结合x轴以及y轴定义,那就是TP/FN 55开,FP/TN 55开,那就是(0.5,0.5),

准确来说,如果是无识别的话,那么TPR=FPR(这里需要意译理解,联合上面的0.5),

也就是对角线;

如果模型分类效果好,比随机好,那就是TPR会好点>FPR(T开头的比率比F开头的比率好),也就是y>x,在对角线上,反之下;

然后上面是指模型+阈值定下来之后绘制的单个点,考虑不同模型不同阈值,会有各种可能性;

如果我们只考虑某种模型,然后将阈值进行修改(一般高于阈值是P阳性,低于阈值是N阴性)
那某个模型在阈值改动之后绘制出来的(x,y)模型评估指标坐标,就是该模型的ROC曲线

上面的这个分布图很好理解,上面蓝色的是所有实际上为阴性的样本,下面红色的是所有实际为阳性的样本;

然后竖线是阈值,阈值往上、右边是阳性(预测、评估为阳性),阈值左边为阴性;

那对于上面的蓝色图,在阈值竖线左边的就是实际为阴性,然后预测为阴性的,那就是真阴性;阈值竖线右边的就是实际为阴性,但是根据阈值预测为阳性的,也就是假阳性。

下面红色同样分类,竖线左侧是实际为阳性,被分类、被预测为阴性,假阴性;竖线右侧是实际上为阳性也被分类为阳性,就是真阳性。

下面的图也好理解:
阈值改动之后,如何理解ROC曲线中的(x,y)模型预测坐标如何移动;

如果阈值很高时,比如说是A点,或者是往阈值很高方向移动时

我们分别看竖线阈值左右两侧的数目比例,TP会变小,FP会变大,TN增大,FP减小,

然后总的分母FP+TN=all N不变(就是蓝色的总面积),同样红色的总面积,对应分母TP+FN=all P不变。

那么分母不变,我们看分子,TP和FP都会变小,所以会向(0,0)移动;

也就是说阈值变高,ROC曲线中的坐标点会偏向(0,0)左下角移动;

阈值变低,同样分析趋势,就是偏向(1,1)右上角移动;

====》

AUC,其实就是ROC曲线下的面积,严谨来说应该是AUC ROC

就是将面积积分起来

具体如何从数学意义上去理解AUC,可以在后续博客中深入

如果是2分类的话,绘制ROC曲线可以是:

import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt

# 假设 y_true 是真实标签,y_prob 是预测的正类概率
y_true = np.array([0, 1, 1, 0, 1, 0])
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])

# 计算ROC曲线
fpr, tpr, thresholds = roc_curve(y_true, y_prob)

# 计算AUC值
auc = roc_auc_score(y_true, y_prob)

# 绘制ROC曲线
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()

print(f'AUC: {auc}')

我们现在主要是看多分类:
多分类的话需要从multiclass中导入OneVsRestClassifier函数,就是搞1 vs rest,强行凑2分类

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.multiclass import OneVsRestClassifier
import matplotlib.pyplot as plt

# 使用OneVsRestClassifier处理多分类问题
classifier = OneVsRestClassifier(ldc)
y_prob = classifier.fit(X_train_spca, Y_train).predict_proba(X_test_spca)

# 计算每个类别的ROC曲线和AUC
for i in range(y_prob.shape[1]):
	fpr, tpr, _ = roc_curve(Y_test == i, y_prob[:, i])
	auc = roc_auc_score(Y_test == i, y_prob[:, i])
	plt.plot(fpr, tpr, label='Class {} (AUC = {:.2f})'.format(i, auc))

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()

获取元祖的1索引,也就是列数

然后每一列,首先roc_curve提供真实y标签(test set)以及评估

本质上绘制的就是散点图,可以参考我的上一篇博客:
https://blog.csdn.net/weixin_62528784/article/details/145360599?spm=1001.2014.3001.5501

总之代码解释如下:

2,其他指标:

阳性 (P, positive)

阴性 (N, Negative)

真阳性 (TP, true positive)

正确的肯定。又称:命中 (hit)

真阴性 (TN, true negative)

正确的否定。又称:正确拒绝 (correct rejection)

伪阳性 (FP, false positive)

错误的肯定,又称:假警报 (false alarm),第一型错误

伪阴性 (FN, false negative)

错误的否定,又称:未命中 (miss),第二型错误

真阳性率 (TPR, true positive rate)

又称:命中率 (hit rate)、敏感度(sensitivity)

TPR = TP / P = TP / (TP+FN)

伪阳性率(FPR, false positive rate)

又称:错误命中率,假警报率 (false alarm rate)

FPR = FP / N = FP / (FP + TN)

准确度 ****(ACC, accuracy)

ACC = (TP + TN) / (P + N)

即:(真阳性+真阴性) / 总样本数

真阴性率 (TNR)

又称:特异度 (SPC, specificity)

SPC = TN / N = TN / (FP + TN) = 1 - FPR

阳性预测值 (PPV)

PPV = TP / (TP + FP)

阴性预测值 (NPV)

NPV = TN / (TN + FN)

假发现率 (FDR)

FDR = FP / (FP + TP)

Matthews相关系数 (MCC),即 Phi相关系数

F1评分

F1 = 2TP/(P+P’)

交叉熵损失:这个在DL深度学习多分类中就很常见了,softmax之后用cross entropy

如果是二分类,类似

from sklearn.metrics import log_loss
y_true = [0, 1, 1, 0, 1, 0]
y_prob = [0.2, 0.8, 0.9, 0.4, 0.7, 0.3]

loss = log_loss(y_true, y_prob)
print(f'Log Loss: {loss}')

多分类:

from sklearn.metrics import log_loss

# 计算交叉熵损失
loss = log_loss(Y_test, pd.get_dummies(pre_res), labels=[0, 1, 2])
print(f'Log Loss: {loss}')

3,上述指标的实现

(1)混淆矩阵:

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(Y_test, pre_res)
print(cm)

这个混淆矩阵大家都很熟悉了,三分类的

当然只是个array难看,最常见的就是可视化,用seaborn

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix


# 计算混淆矩阵
cm = confusion_matrix(Y_test, pre_res)
print(cm)

# 可视化混淆矩阵
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Class 0', 'Class 1', 'Class 2'], yticklabels=['Class 0', 'Class 1', 'Class 2'])
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

(2)AUC/ROC的可视化见前面

(3)最基本的统计指标的分析也见前面

总结

指标适用场景特点
准确率(Accuracy)适用于类别平衡数据集简单直观,但在类别不平衡时不可靠
精确率(Precision)偏重减少假正例(FP)的场景,如垃圾邮件检测关注正类预测的精确性
召回率(Recall)偏重减少假负例(FN)的场景,如疾病诊断关注正类样本的识别率
F1值(F1 Score)当精确率和召回率同等重要时平衡精确率和召回率的权衡
AUC-ROC适合类别不平衡的数据集,评估模型整体性能显示模型的分类能力,尤其适用于不平衡数据集
混淆矩阵直观展示分类结果,适用于二分类和多分类问题展示每个类的分类情况,特别适合多分类问题
交叉熵损失(Cross Entropy Loss)主要用于多分类任务衡量概率分布差异,通常在深度学习中使用

每个分类评估指标都有特定的应用场景,选择合适的指标有助于更好地理解模型性能并针对问题进行优化。

参考:
https://developers.google.com/machine-learning/crash-course/classification/accuracy-precision-recall?hl=zh-cn

https://zh.wikipedia.org/wiki/ROC%E6%9B%B2%E7%BA%BF

https://mp.weixin.qq.com/s/YyzU67doaB_btbnUE6Q_Bg


http://www.kler.cn/a/523054.html

相关文章:

  • 数字化转型-工具变量(2024.1更新)-社科数据
  • 【C++题解】1393. 与7无关的数?
  • Java面试题2025-并发编程进阶(线程池和并发容器类)
  • 神经网络|(七)概率论基础知识-贝叶斯公式
  • 【实践案例】使用Dify构建文章生成工作流【在线搜索+封面图片生成+内容标题生成】
  • Ubuntu二进制部署K8S 1.29.2
  • STM32 TIM定时器配置
  • 虚幻基础08:组件接口
  • 在ubuntu下一键安装 Open WebUI
  • 能够对设备的历史数据进行学习与分析,通过与设备当前状态的比对,识别潜在故障并做出预判的名厨亮灶开源了。
  • 宝塔安装完redis 如何访问
  • 信息学奥赛一本通 1396:病毒(virus)
  • c++多态
  • JavaScript逆向高阶指南:突破基础,掌握核心逆向技术
  • Nginx 开发总结
  • 《网络数据安全管理条例》施行,企业如何推进未成年人个人信息保护(上)
  • 深入探索C++17的std::any:类型擦除与泛型编程的利器
  • STM32 LED呼吸灯
  • pycharm(2)
  • noteboolm 使用笔记
  • 面向对象编程简史
  • Facebook如何应对全球范围内的隐私保护挑战
  • Python vLLM 实战应用指南
  • OpenCV:图像轮廓
  • 二叉树的最大深度(遍历思想+分解思想)
  • 算法随笔_28:最大宽度坡_方法2