当前位置: 首页 > article >正文

【深度学习基础】一篇入门模型评估指标(分类篇)

🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀深度学习_十二月的猫的博客-CSDN博客

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光

目录

1. 前言

2. 模型评估综述

2.1 什么是模型评估

2.2 评估类型

2.3 模型泛化能力

2.4 过拟合与欠拟合

3. 常见的分类模型评估方式

3.1 准确率(Accuracy)

3.2 精确率(Precision)

3.3 召回率(Recall)

3.4 F1-score

3.5 ROC曲线及AUC值

3.6 PR曲线

4. 代码实现

5. 总结 


1. 前言

本篇针对的是刚刚接触机器学习的友友~~

在猫猫刚刚入门机器学习时,有位学长问我,召回率是什么?

猫猫那是一脸懵逼,然后他又来了三连问。准确率是什么?精准率是什么?有什么曲线评估模型性能吗?

在当时猫猫的脑海中,模型好坏不就是看他预测准了多少个样本,准确率多少吗?哪来那么多东西,哈哈哈哈哈哈,机器学习还是好玩的。

想到很多刚刚入门的友友也可能苦恼于模型评估指标,猫猫便写了这篇博客,希望能给大家带来点帮助。

2. 模型评估综述

2.1 什么是模型评估

       模型评估是指在机器学习中,对于一个具体方法输出的最终模型,使用一些指标和方法来评估它的泛化能力。这一步通常在模型训练和模型选择之后,正式部署模型之前进行。模型评估不针对模型本身,而是针对问题和数据,因此可以用来评价不同方法的模型的泛化能力,以此决定最终模型的选择。

模型评估:量化模型在解决目标问题上的能力

2.2 评估类型

       机器学习的基本任务大致分为三类,分别是分类(Classification)回归(Regression)聚类(Clustering),在本文仅介绍分类的模型评估。

2.3 模型泛化能力

       模型的泛化能力是机器学习的一个重要概念和指标。是指一个模型在训练集之外的未知数据上的表现能力,泛化能力强的模型能够正确学习到数据的普遍规律并将其运用到新的数据上从而做出准确的预测。

       简单地说,当模型在训练样本上表现良好,并且能在新的数据样本上保持相同的表现,我们就可以说这个模型的泛化能力强。

2.4 过拟合与欠拟合

       过拟合与欠拟合是机器学习中常见的两个概念,描述的是模型在训练数据和未知数据上表现的差异。下图中第一到第三的图分别是欠拟合、正好和过拟合三种状态。

过拟合:

       过拟合表现为在训练数据上表现优秀的那在新数据集上表现较差,通常是因为模型过于复杂,学习到了数据集中的细节和噪声而不是数据的真实分布,因而泛化能力差。

举个很经典的例子,当我们设计一个模型用于判断一个物品是否是树叶,而训练集中有几个样本的叶片边缘带有尖刺,模型作者希望他的模型能够符合贴近他的每一个样本,因而将带有“尖刺边缘”这一不是明显特征的特点纳入了参数中,这样一来,模型就能够完美贴合训练样本,但在应用时会发现模型容易钻牛角尖,会将不带尖刺边缘的样本排除,无法识别不带尖刺边缘的树叶,这就导致了模型的泛化能力差。

解决模型过拟合的问题,有以下几种方法:

  • 增加数据量以提供更多信息,减少噪声的影响
  • 简化模型,减少模型参数的数量
  • 通过交叉验证评估模型的泛化能力

欠拟合:

       欠拟合则与过拟合相反,由于模型过于简单,无法学习到数据的足够特征,无法正确捕捉数据的复杂性和变化,没有学习到数据的规律。这就导致模型不管是在训练样本还是未知样本上的表现都不佳。

常见的解决欠拟合的方法有:

  • 增加模型的复杂度,如增加更多的特征或使用更复杂的模型
  • 收集更多的数据,提供更加丰富的信息给模型
  • 增加训练时间或者调整模型的超参数

3. 常见的分类模型评估方式

       混淆矩阵是分类模型巩固的一个重要工具,可以直观展示模型的预测结果和实际结果之间的关系,通常由以下四个部分构成:

  • 真正类 (True Positives, TP): 模型正确地预测正类的数量。
  • 假负类 (False Negatives, FN): 模型错误地将正类预测为负类的数量。
  • 假正类 (False Positives, FP): 模型错误地将负类预测为正类的数量。
  • 真负类 (True Negatives, TN): 模型正确地预测负类的数量。

 二元混淆矩阵格式如下:

多元分类矩阵格式如下:

3.1 准确率(Accuracy)

准确率是指模型正确预测的样本总数占总样本总数的比例,其计算公式为:

Accuracy=\frac{TN+TF}{TN+FN+TP+FP}=\frac T{T+F}

3.2 精确率(Precision)

       精确率的概念比较容易与准确率的概念混淆,准确率的目标是所有样本,计算的是所有分类正确样本占总样本的多少,而精确率是指在所有被模型预测为正类的样本中,实际为正类的样本的比例,关注的是被模型分为此类的数据中有多少是正确的。

       当我们使用精确率(下使用Precision代替)作为考量时,优点是不容易出现假正类,但是,当precision值过高时,容易出现模型偏向某一类别的情况,因为此时模型会倾向于预测多数类别而忽略了少数类别的预测。

       举一个例子,假设在一个疾病诊断问题中,疾病发生的实际情况(正类)非常罕见。如果模型仅仅通过预测大多数人都是健康的(负类)来提高Precision,那么它可能会忽略真正的病例,因为这些病例在数据中占比很小。这样的模型虽然精确度高,但其实用性非常有限,因为它未能有效识别和预测少数但重要的正类样本。

       在混淆矩阵中,这表现为某一行的T除以本行上所有数字的和,如下列表格标注了颜色的A行就是红色的TA值除以TA加FA的和。

其公式为:

Precision=\frac{TP}{TP+FP}

3.3 召回率(Recall)

       召回率(下使用Recall替代)衡量的是模型正确识别为正类的实例(真正类)占所有实际正类实例的比例。当recall值高时,模型会更容易捕捉到正类,但也会导致假正类出现的比例增加的情况。

       举个例子说明,假设我们有一个用于检测信用卡欺诈的模型,其中正类(欺诈)非常罕见。在10,000个交易中,可能只有100个是欺诈性的。如果我们只关注召回率,模型可能会被调整为将更多的交易标记为欺诈,以确保它不会错过那些真正的欺诈案例。例如,模型可能会将1,000个交易标记为欺诈,其中包括所有100个真正的欺诈案例和900个实际上是合法的交易。在这种情况下,召回率是100%,因为所有的欺诈交易都被正确地识别了。然而,这样做的代价是产生了很多假正类(False Positives)—那些被错误标记为欺诈的合法交易。这会导致很多不必要的麻烦,比如客户满意度下降和增加的客户服务成本。

       在混淆矩阵中,这表现为某一列中的T除以本列上左右数字的和,如下列表格,计算recall即红色字体的值除以红色和绿色字体的值的和,计算设计的数据包括正确分类的本类型样本的值和被错误分类至本类型的样本的值。

其公式为:

Recall=\frac{TP}{TP+FN}

3.4 F1-score

       正如上面所说的,当我们单独地使Precision值或者Recall值增高,都会使模型走向极端,因而我们引入了F1值,即Recall和Precision的调和平均数,,因为F1值综合考虑了Recall和Precision,因而其尤其适合在数据不平衡(不同类别的样本数据量差异很大)的情况时进行使用。当分数更高时,说明模型再识别少数类方面的能力更强,同时保持了较高的Recall和Precision平衡。其公式如下:

F1=2\times\frac{Precision\times Recall}{Percision+Recall}

       值得一提的是,这是在Recall和Precision的比重相同的同属情况下使用的,如果你认为其中某个值更加重要,你可以使用Fβ-score,其公式为:

F_\beta=(1+\beta^2)\cdot\frac{Precision\cdot Recall}{(\beta^2\cdot Precision)+Recall}

       其中的β是Recall和Precision的比值,当β大于1时,该分数会给予Recall更高的权重。

3.5 ROC曲线及AUC值

       在二分类的过程中,我们通常会设置一个阈值(取值为0到1之间),大于阈值的会被归于正类,小于阈值的会被归为负类,当我们降低阈值时,样本会更容易地被归为正类,但也会更容易出现假正类,反之则更容易出现遗漏的情况,而随着阈值的变化,混淆矩阵也会出现变化,为了直观地体现这种变化,我们引入了ROC曲线。

       在ROC曲线中,x和y轴分别为假阳性率(TPR)真阳性率(FPR),其中,假阳性率表示在所有阴性(即Negative)样本中,被错误地预测为阳性(即Positive)的比例,计算公式为:

FPR=\frac{FP}{FP+TN}

       而真阳性率又称召回率,表示所有实际阳性样本被正确预测为阳性的比例,计算公式为:

TPR=\frac{TP}{TP+FN}

       通过ROC曲线,我们可以明确地直观地看出模型的好坏,为了模型准确率更高,我们自然而然地希望真阳性率更高而假阳性率更低,因此,当曲线越靠近左上角,我们会认定这条曲线所代表的模型判断准确率更高,如下图所示

       这是通过KNN算法和决策树算法对sklearn库内置的乳腺癌库进行训练预测的ROC曲线结果图,如图所示,由于代表KNN算法的折线更靠近左上角,因而我们可以说在这个数据集上,使用了KNN算法的模型表现更好。

       而AUC(Area Under the Curve)值就是指曲线下的面积,当AUC值越接近1,可说明模型的分类性能更好。如下图涂黄的部分就是该曲线的AUC值。

3.6 PR曲线

       PR曲线,即精确率-召回率曲线,它是以召回率(Recall)为x轴,精确率(Precision)为y轴的曲线,在机器学习中,尤其在再不平衡数据集中非常有用。

       当我们在改变模型的分类阈值时,TP、FP和FN等都会发生变化,从而导致Recall和Precision发生变化,PR曲线展示了这种变化关系,可以帮助我们理解模型在不同阈值下的性能表现。通常来说,曲线越靠近右上方,代表模型的表现越好。

       其AUC值同样是一个重要的性能指标,反映了模型在所有可能的分类阈值上的平均效果,当AUC值越大表示模型性能越好。

       如下图所示的是KNN算法和决策树算法对sklearn库内置的乳腺癌库进行训练预测的PR曲线的结果图:

4. 代码实现

使用sklearn库中内置的算法和数据集进行实操,仅仅作为一个补充练习:

from sklearn.datasets import load_breast_cancer  # 导入乳腺癌数据集
from sklearn.model_selection import train_test_split  # 导入数据集分割工具
from sklearn.neighbors import KNeighborsClassifier  # 导入K近邻分类器
from sklearn.tree import DecisionTreeClassifier  # 导入决策树分类器
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score  # 导入性能评估工具
import matplotlib.pyplot as plt  # 导入绘图库

# 加载乳腺癌数据集
data = load_breast_cancer()
X = data.data  # 特征数据
y = data.target  # 标签数据

# 将数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 将训练集和测试集按8:2的比例拆分

# 初始化并训练K近邻分类器
knn_classifier = KNeighborsClassifier(n_neighbors=5)
knn_classifier.fit(X_train, y_train)

# 初始化并训练决策树分类器
tree_classifier = DecisionTreeClassifier(random_state=42)
tree_classifier.fit(X_train, y_train)

# 预测测试集的概率
y_scores_knn = knn_classifier.predict_proba(X_test)[:, 1]
y_scores_tree = tree_classifier.predict_proba(X_test)[:, 1]

# 计算KNN的ROC曲线和AUC值
fpr_knn, tpr_knn, _ = roc_curve(y_test, y_scores_knn)
roc_auc_knn = auc(fpr_knn, tpr_knn)

# 计算决策树的ROC曲线和AUC值
fpr_tree, tpr_tree, _ = roc_curve(y_test, y_scores_tree)
roc_auc_tree = auc(fpr_tree, tpr_tree)

# 计算KNN的精确度-召回率曲线和平均精确度
precision_knn, recall_knn, _ = precision_recall_curve(y_test, y_scores_knn)
average_precision_knn = average_precision_score(y_test, y_scores_knn)

# 计算决策树的精确度-召回率曲线和平均精确度
precision_tree, recall_tree, _ = precision_recall_curve(y_test, y_scores_tree)
average_precision_tree = average_precision_score(y_test, y_scores_tree)

# 绘制ROC曲线
plt.figure(figsize=(14, 6))

plt.subplot(1, 2, 1)
plt.plot(fpr_knn, tpr_knn, color='darkorange', lw=2, label='KNN (AUC = %0.2f)' % roc_auc_knn)
plt.plot(fpr_tree, tpr_tree, color='green', lw=2, label='Decision Tree (AUC = %0.2f)' % roc_auc_tree)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")

# 绘制精确度-召回率曲线
plt.subplot(1, 2, 2)
plt.plot(recall_knn, precision_knn, color='blue', lw=2, label='KNN (AP = %0.2f)' % average_precision_knn)
plt.plot(recall_tree, precision_tree, color='purple', lw=2, label='Decision Tree (AP = %0.2f)' % average_precision_tree)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('P-R Curve')
plt.legend(loc="lower left")

plt.tight_layout()
plt.show()

5. 总结 

如果想要学习更多深度学习知识,大家可以点个关注并订阅,持续学习、天天进步

你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~

本篇文章转载自 机器学习——常见模型评估指标 - Dronnnnn - 博客园


http://www.kler.cn/a/412652.html

相关文章:

  • JAVA篇05 —— 内部类(Local、Anonymous、Member、Static)
  • [创业之路-155] :《领先的密码-BLM方法论全面解读与应用指南》- 综合管理框架
  • ctfshow
  • 极狐GitLab 17.6 正式发布几十项与 DevSecOps 相关的功能【一】
  • git: 修改gitlab仓库提交地址
  • NLP论文速读(剑桥大学出品)|分解和利用专家模型中的偏好进行改进视觉模型的可信度
  • Linux 时间属性
  • SurfaceFlinger学习之一:概览
  • 大模型专栏--大模型开发框架
  • Spring | (七)AOP概念及工作流程
  • 【速通GO】数据类型与变量和常量
  • 丹摩 | 基于PyTorch的CIFAR-10图像分类实现
  • 第三方数据库连接免费使用和安装
  • 白光干涉仪:表面粗糙度形貌台阶高测量解决方案
  • Flutter 共性元素动画
  • 工业网络安全 智能电网,SCADA和其他工业控制系统等关键基础设施的网络安全(总结)
  • 无法通过外网连接访问mysql问题排查
  • 如何通过终端连接无线网
  • echarts使用示例
  • laravel官方升级引起的报错问题解决
  • 原子类、AtomicLong、AtomicReference、AtomicIntegerFieldUpdater、LongAdder
  • [python]poetry安装和使用
  • Vue前端面试进阶(五)
  • day29|leetcode 134. 加油站 , 135. 分发糖果 ,860.柠檬水找零 , 406.根据身高重建队列
  • 模型压缩理论简介及剪枝与稀疏化在 征程 5 上实践
  • 检测到“runtimelibrary”的不匹配项: 值“mtd_staticdebug”不匹配值“mdd_dynamic”