当前位置: 首页 > article >正文

【漫话机器学习系列】125.普拉托变换(Platt Scaling)

普拉托变换(Platt Scaling)详解

1. 什么是普拉托变换(Platt Scaling)?

普拉托变换(Platt Scaling)是一种用于二分类支持向量机(SVM)的后处理方法,它的作用是将 SVM 的输出分数转换为概率值。由于 SVM 本身并不直接输出概率,而是输出一个决策值(即 f(x)),因此我们需要使用某种方法将这个决策值映射到 [0,1] 范围的概率值。普拉托变换正是通过拟合一个 Sigmoid 函数 来完成这个任务。


2. 线性支持向量机的输出问题

在 SVM 训练完成后,它会为每个输入样本 x 计算一个 决策函数值

f(x) = w^T x + b

这个值表示样本相对于决策边界的距离。一般来说:

  • 如果 f(x) > 0,则分类为正类( y = 1 )。
  • 如果 f(x) < 0,则分类为负类( y = 0 )。

但这个值只是一个距离,而不是概率。例如,f(x) = 2 和 f(x) = 100 都表示正类,但哪个更有可能是正类呢?这个问题就是 SVM 无法直接提供概率输出的原因。因此,我们需要一个方法将 SVM 的输出转换为概率。


3. 普拉托变换的基本思想

普拉托变换的核心思想是 使用 Sigmoid 函数拟合 SVM 的输出,使其变成概率

P(y=1|x) = \frac{1}{1 + e^{A f(x) + B}}

其中:

  • f(x) 是 SVM 计算得到的决策值(即得分)。
  • A 和 B 是两个需要拟合的参数。

换句话说,我们希望通过训练找到 A 和 B,使得 SVM 的输出 f(x) 能够很好地拟合样本的真实概率分布。


4. 如何训练 A 和 B?

训练 A 和 B 的过程如下:

  1. 在数据集上训练 SVM,获取所有样本的 得分 f(x) 及其真实类别 y。
  2. 使用交叉验证 构建一个 逻辑回归模型,以 f(x) 为输入,y 为输出,拟合 Sigmoid 函数 的参数A 和 B。
  3. 计算损失函数
    • 普拉托变换使用 最大似然估计(MLE) 进行参数拟合。
    • 具体来说,最小化以下 交叉熵损失

      v\sum_{i=1}^{N} \left[ y_i \log P(y=1|x_i) + (1 - y_i) \log (1 - P(y=1|x_i)) \right]i=1
    • 这个公式与 逻辑回归的损失函数 一样,因此可以使用标准的优化方法(如梯度下降)来求解 A 和 B。

5. 为什么要使用普拉托变换?

普拉托变换是 SVM 最常用 的概率校准方法,原因如下:

  • 简单高效:只需要额外训练一个小的逻辑回归模型(拟合 Sigmoid),计算开销较小。
  • 适用于 SVM 及其他分类器:虽然最初是为 SVM 提出的,但 Platt Scaling 也可用于其他分类器(如神经网络、随机森林等)。
  • 提高模型可解释性:许多应用(如医学、金融等)需要输出概率,而不仅仅是分类结果。

6. Platt Scaling 的局限性

尽管普拉托变换在很多场景下都能有效校准概率输出,但它也有一些局限:

  1. 对样本不均衡较敏感:如果数据集的类别不均衡,拟合的 Sigmoid 可能会偏向多数类。
  2. 假设数据符合 Sigmoid 分布:Platt Scaling 假设 SVM 输出的决策值可以用 Sigmoid 函数拟合,但在某些复杂分布下,可能效果不好。
  3. 需要额外的数据:Platt Scaling 需要使用交叉验证数据来训练 A 和 B,这可能会浪费一部分训练数据。

如果数据集较大、类别不均衡或者 Sigmoid 拟合效果不好,通常可以考虑 Isotonic Regression(等温回归) 作为替代方法。


7. 代码实现(Python 示例)

在实际应用中,Scikit-learn 提供了 Platt Scaling 的实现,示例如下:

from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 训练 SVM(不带概率)
svm = SVC(kernel='linear', probability=False)
svm.fit(X_train, y_train)

# 使用 Platt Scaling 校准概率
svm_calibrated = CalibratedClassifierCV(svm, method='sigmoid', cv=5)
svm_calibrated.fit(X_train, y_train)

# 预测概率
probabilities = svm_calibrated.predict_proba(X_test)

print("前 5 个样本的预测概率:")
print(probabilities[:5])

运行结果 

前 5 个样本的预测概率:
[[0.39216467 0.60783533]
 [0.19236017 0.80763983]
 [0.51329367 0.48670633]
 [0.24076157 0.75923843]
 [0.0697871  0.9302129 ]]


8. 结论

普拉托变换(Platt Scaling)是一种 简单有效的概率校准方法,特别适用于 SVM 这种不直接输出概率的分类器。它通过 拟合 Sigmoid 函数,将分类得分映射到 [0,1] 范围,使得输出更具可解释性。

在实际应用中,我们可以:

  • 需要概率输出的任务(如医疗诊断、金融风控)中使用 Platt Scaling。
  • 如果数据 类别不均衡,可以尝试 Isotonic Regression 作为替代方案。

希望这篇文章能够帮助你更好地理解 Platt Scaling 的原理和应用!


http://www.kler.cn/a/582053.html

相关文章:

  • 【UNIAPP】获取视频的第一帧作为封面(基于视频URL,Canvas)复制即用
  • git文件过大导致gitea仓库镜像推送失败问题解决(push failed: context deadline exceeded)
  • 在VMware Workstation Pro上轻松部署CentOS7 Linux虚拟机
  • 使用Java爬虫根据关键词获取衣联网商品列表:实战指南
  • 《2025年软件测试工程师面试》MySQL面试题
  • Linux 环境变量快速上手指南
  • 《UI 设计:点亮大数据可视化的智慧之光》
  • 告别手抖烦恼,重拾生活稳 “态”
  • 【物联网-WIFI】
  • 【MySQL】增删改查进阶
  • 【测试框架篇】单元测试框架pytest(5):setup和teardown的详细使用
  • 2.4 基于Vitest的单元测试基础设施搭建
  • DeFi基石ERC4626标准实现一个金库合约
  • 聊一聊大型 Unity 游戏资源打包
  • redis趣味解读
  • chrome源码中非常巧妙、复杂或者不常见的技术手段
  • 天地图的使用
  • 【机械视觉】C#+VisionPro联合编程———【五、硬币检测小项目实现(C#+VisionPro联合编程和csv文件格式操作)】
  • Vue主流的状态保存框架对比
  • git无法提交解决方案--! [rejected] master -> master (non-fast-forward)