线性模型 - 多分类问题
前面的博文中,介绍了二分类问题,本文我们学习和掌握多分类问题。
多分类(Multi-class Classification)问题是指分类的类别数 𝐶 大于 2。多分类一般需要多个线性判别函数,但设计这些判别函数有很多种方式。
多分类问题是机器学习中的核心任务之一,其目标是将输入数据划分到三个或更多类别中。与二分类问题(仅区分两类)不同,多分类需要更复杂的模型设计和策略。以下从核心概念、实际场景、常用方法和注意事项四方面展开解释,帮助初学者系统理解。
一、核心概念
-
定义:
多分类问题(Multi-class Classification)指模型需从多个互斥的类别中,为输入样本选择一个最可能的类别标签。-
互斥性:每个样本只能属于一个类别(如动物分类为“猫、狗、鸟”)。
-
非互斥场景则属于多标签分类(Multi-label Classification,如一张图片可同时包含“猫”和“草地”)。
-
-
关键挑战:
-
类别数量增加导致模型复杂度上升。
-
类别间可能存在相似性(如区分“狼”和“哈士奇”)。
-
样本分布不均衡(某些类别数据极少)。
-
二、实际应用场景
三、多分类的常用方法
1. 基于二分类的扩展策略
-
One-vs-Rest (OvR):
为每个类别训练一个二分类器,判断样本是否属于该类别。
示例:区分猫、狗、鸟时:-
分类器1:猫 vs 非猫
-
分类器2:狗 vs 非狗
-
分类器3:鸟 vs 非鸟
最终结果:选择置信度最高的类别。
-
-
One-vs-One (OvO):
为每两个类别训练一个二分类器,通过投票确定最终类别。
示例:区分猫、狗、鸟时:-
分类器1:猫 vs 狗
-
分类器2:猫 vs 鸟
-
分类器3:狗 vs 鸟
最终结果:统计所有分类器的投票结果,选择票数最多的类别。
-
优缺点对比:
2.“argmax”方式(改进的“一对其余”方式,共需要 𝐶 个判 别函数)
在多分类问题中,我们通常为每个类别计算一个得分(或概率),然后采用“argmax”方式来做出预测。直观地讲,“argmax”操作就是找出哪个类别的得分最高,并把这个类别作为最终预测结果。
定义
“argmax”是“argument of the maximum”的缩写,它的意思是返回使某个函数取得最大值的那个自变量。例如,给定函数 f(k),
就表示在所有可能的 k 中,哪一个使 f(k)最大。
-
多分类中的应用
在多分类任务中,假设我们有 K 个类别,对每个输入 x,模型输出一个得分向量
其中 f_k(x) 表示输入 x 属于第 k 个类别的“得分”或“置信度”。
最终预测的类别为:也就是说,我们选择使得 f_k(x) 最大的那个类别 k。
-
举例说明
假设一个图像分类问题有三类:猫、狗和兔子。模型对某张图片的输出得分向量为
f(x)=[2.1, 3.5, 1.8],分别对应猫、狗、兔子的得分。
进行 argmax 操作,得到 argmax=2(假设类别索引从1开始,最大对应的3.5),说明第二个类别(狗)的得分最高,模型预测这张图片属于狗。 -
意义
- 简单直观:argmax 是一个非常直观的决策规则,只要选出得分最高的那个类别。
- 泛化性:这种方法适用于各种多分类模型,无论是基于概率输出(如 softmax 回归)还是直接输出得分的模型(如支持向量机多分类),都可以使用 argmax 来确定最终类别。
- 解释性:使用 argmax 操作时,输出的每个得分可以理解为模型对每个类别的信心,而 argmax 就是选出信心最高的那个。
在多分类问题中,“argmax”操作就是从模型输出的多个得分中选出最高得分对应的类别,它是将连续的得分转化为离散类别决策的关键步骤。这种方式简单、直观且适用于各种多分类模型。
3. 原生支持多分类的算法
-
决策树与随机森林:
通过信息增益(如基尼系数、熵)直接分割多类别数据。
示例:根据花瓣长度、宽度等特征直接分类鸢尾花(Setosa、Versicolor、Virginica)。 -
朴素贝叶斯:
计算每个类别的后验概率,选择概率最大的类别。
示例:根据邮件内容判断属于“广告、社交、工作”中的哪一类。 -
神经网络:
输出层使用Softmax激活函数,将输出转换为概率分布。
示例:ResNet模型对ImageNet数据集(1000类)进行分类。
以上这些模型,我们先暂时了解对应的概念,后续我们会挨个学习到。
四、关键技术细节
1. 损失函数
-
多分类交叉熵损失(Cross-Entropy Loss):
最常用的损失函数,结合Softmax输出:
-
-
C:类别总数
-
yi:真实标签的one-hot编码
-
pi:模型预测的第i类的概率
-
示例:
-
真实标签:狗(对应one-hot编码[0,1,0])
-
预测概率:[0.1, 0.7, 0.2]
-
损失值:−log(0.7)≈0.357
2. 评估指标
-
准确率(Accuracy):
Accuracy=正确预测数/总样本数
正确预测的样本比例,适用于类别均衡的场景。 -
混淆矩阵(Confusion Matrix):
直观展示每个类别的预测情况,帮助分析模型弱点。
宏平均F1(Macro-F1):
对每个类别的F1值取平均,适合类别不均衡场景。
关于宏平均F1的概念,可以参考博文:机器学习 - 机器学习模型的评价指标 -CSDN博客
五、实际应用注意事项
-
类别不平衡处理:
-
过采样少数类(如SMOTE算法)或欠采样多数类。
-
使用加权损失函数,增加少数类的惩罚权重。
-
-
特征工程:
-
选择区分性强的特征(如PCA降维后可视化观察类别可分性)。
-
对文本数据使用TF-IDF或词嵌入(Word2Vec)。
-
-
模型选择:
-
类别较少时:逻辑回归(OvR/OvO)或SVM。
-
高维数据(如图像):优先使用神经网络(CNN)。
-
需要可解释性:决策树或随机森林。
-
六、总结
多分类问题的核心在于有效区分多个类别间的差异,需根据数据特性选择合适的模型和策略:
-
类别较少且均衡 → OvR/OvO + 线性模型(如SVM)。
-
高维复杂数据 → 神经网络(Softmax输出)。--后面我们会有文章专门介绍Softmax
-
需解释性 → 决策树或规则模型。
-
样本不均衡 → 调整损失权重或采样方法。
通过理解数据分布、选择合适的评估指标,并针对性地优化模型,可以显著提升多分类任务的性能。
七、附加:如何理解多类线性可分?
“多类线性可分”是对多类别分类问题中数据分布的一种描述,其含义可以从以下几个角度来理解:
-
基本概念
- 在二分类问题中,我们说一个数据集线性可分,意思是存在一条直线(在二维中)或一个超平面(在高维中),能将正负两类完全分开。
- 对于多分类问题,数据集被分为三个或更多的类别。如果数据集“多类线性可分”,就意味着存在一组线性决策函数,使得对于任意一个样本,其对应的真实类别的决策函数值比其他类别的决策函数值高,从而能正确分类所有样本。
-
数学描述
假设我们有 K 个类别,为每个类别 k 定义一个线性函数:数据集是多类线性可分的,如果对于所有样本 x 属于类别 k 都满足:
这就意味着,利用“argmax”决策规则,
每个样本都可以被正确分类。
-
直观理解
- 决策边界:
对于多分类问题,多类线性可分意味着可以用一组线性边界(或超平面)将不同类别划分开。例如,在二维空间中,这可能表现为几条直线组合在一起,将平面划分成多个区域,每个区域对应一个类别。 - 分区方法:
一种常见的思路是“one-vs-rest”方法,对于每个类别训练一个二分类器,将该类别与其他所有类别分开。如果所有这些二分类器都能完美区分,那么整体问题就可以看作多类线性可分。
- 决策边界:
-
举例说明
例子:二维空间中的三分类问题
假设我们有三个类别的数据,分别用红色、蓝色和绿色表示。
- 如果这些数据分布在二维平面上,并且可以找到两条直线将平面划分成三个区域,每个区域内的数据都只属于某一个颜色,那么这个数据集就是多类线性可分的。
- 同理对其他类别也成立。这样,通过计算各函数值,并取最大者作为预测结果,就可以正确地对每个样本分类。
“多类线性可分”指的是在一个多类别问题中,存在一组线性决策函数(或超平面组合),能够将不同类别的数据完全分开,使得对任意样本,真实类别对应的函数值都高于其他类别。这是线性分类器(如 Softmax 回归、线性 SVM 在 one-vs-rest 策略下)的理想状态,也是判断模型在理论上是否有能力完美分类数据的一种标准。
可见,如果数据集是多类线性可分的,那么一定存在一个“argmax” 方式的线性分类器可以将它们正确分开.