当前位置: 首页 > article >正文

【机器学习】Softmax 函数

Softmax 是机器学习中常用的函数,广泛用于多分类问题的输出层。它可以将一组实数转换为一个概率分布,使得结果满足“非负”和“总和为1”的要求。在分类问题中,Softmax 让模型预测的每个类别概率都易于解释。本文将详细讲解 Softmax 的原理、公式推导、Numpy 实现及其在 Pytorch 中的实际应用。

Softmax 原理

给定一个类别集合 { y 1 , y 2 , … , y n } \{y_1, y_2, \dots, y_n\} {y1,y2,,yn},Softmax 将模型输出的每个数值(称为“得分”或“logits”)转换为概率值。假设模型输出 z i z_i zi 为第 i i i 类的得分,Softmax 将所有的得分映射到概率空间,使每个得分转化为该类的预测概率。

Softmax 函数的公式为:
P ( y i ) = e z i ∑ j = 1 n e z j P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} P(yi)=j=1nezjezi
其中 z i z_i zi 表示模型为第 i i i 类输出的得分, n n n 是类别的数量。通过对指数值的归一化处理,Softmax 函数输出的概率满足:

  1. 所有概率值都为非负数;
  2. 概率总和为 1。

Softmax 计算中的数值稳定性

在计算中,Softmax 可能会因为指数运算导致数值溢出,为了减小这种风险,可以对每个 (z_i) 值减去一个常数 max ⁡ ( z ) \max(z) max(z)
P ( y i ) = e z i − max ⁡ ( z ) ∑ j = 1 n e z j − max ⁡ ( z ) P(y_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^n e^{z_j - \max(z)}} P(yi)=j=1nezjmax(z)ezimax(z)
这种转换不会改变概率的分布,避免了指数函数产生的大数值溢出问题。

Numpy 实现 Softmax 函数

下面通过 Numpy 实现 Softmax,并进行数据可视化以更直观地理解 Softmax 对得分的转换过程。

import numpy as np
import matplotlib.pyplot as plt

# 定义 Softmax 函数
def softmax(logits):
    """
    使用数值稳定性的 Softmax 函数实现

    参数:
    - logits: 模型输出得分向量(shape: (n,),表示 n 个类别的得分)

    返回:
    - probs: 转换后的概率向量,shape: (n,)
    """
    exp_shifted = np.exp(logits - np.max(logits))  # 减去 max(logits) 以确保数值稳定性
    probs = exp_shifted / np.sum(exp_shifted)  # 归一化为概率
    return probs

# 示例输入的分类得分
logits = np.array([2.0, 1.0, 0.1])

# 使用 Softmax 函数计算各类别的概率
probs = softmax(logits)

# 输出各类的预测概率
print("分类得分:", logits)
print("预测概率:", probs)

Softmax 输出可视化

我们可以用图像展示 Softmax 如何将得分转化为概率,假设输入的分类得分范围为 -2 到 4。

# 生成模拟的分类得分范围
logit_range = np.linspace(-2, 4, 100)
all_probs = np.array([softmax([l, 1.0, 0.1]) for l in logit_range])

# 可视化不同类别的预测概率随得分变化的趋势
plt.plot(logit_range, all_probs[:, 0], label="类别 1")
plt.plot(logit_range, all_probs[:, 1], label="类别 2")
plt.plot(logit_range, all_probs[:, 2], label="类别 3")
plt.xlabel("得分 (logits)")
plt.ylabel("概率")
plt.title("Softmax 函数输出的概率分布")
plt.legend()
plt.show()

Softmax 损失函数:交叉熵损失

在多分类任务中,常用 交叉熵损失函数 来衡量模型预测概率分布与真实标签的匹配程度。对于单个样本,交叉熵损失定义为:
L = − ∑ i = 1 n y i ⋅ log ⁡ ( P ( y i ) ) L = -\sum_{i=1}^{n} y_i \cdot \log(P(y_i)) L=i=1nyilog(P(yi))

其中 (y_i) 是真实标签的 one-hot 编码,(P(y_i)) 是 Softmax 转换后的预测概率。

# 定义交叉熵损失函数
def cross_entropy_loss(probs, y_true):
    """
    计算交叉熵损失
    
    参数:
    - probs: Softmax 预测概率 (shape: (n,))
    - y_true: 实际标签 (shape: (n,)),one-hot 编码

    返回:
    - loss: 交叉熵损失
    """
    loss = -np.sum(y_true * np.log(probs + 1e-10))  # 加1e-10防止 log(0)
    return loss

# 示例计算
y_true = np.array([1, 0, 0])  # 假设类别 1 为正确类别
loss = cross_entropy_loss(probs, y_true)
print("交叉熵损失:", loss)

在 PyTorch 中使用 Softmax

在 PyTorch 中,我们可以直接调用 torch.nn.functional.softmax 来实现 Softmax。此外,PyTorch 提供的 torch.nn.CrossEntropyLoss 函数在内部自动包含了 Softmax 和交叉熵的计算,无需显式计算。

import torch
import torch.nn.functional as F

# 示例:在 PyTorch 中实现 Softmax 和交叉熵损失
logits_torch = torch.tensor([2.0, 1.0, 0.1])

# 使用 PyTorch 的 Softmax 函数
probs_torch = F.softmax(logits_torch, dim=0)
print("PyTorch 预测概率:", probs_torch.numpy())

# 使用交叉熵损失函数
y_true_index = torch.tensor([0])  # 假设第一个类别为正确类别
loss_fn = torch.nn.CrossEntropyLoss()
loss_torch = loss_fn(logits_torch.unsqueeze(0), y_true_index)
print("PyTorch 交叉熵损失:", loss_torch.item())

在 PyTorch 中,torch.nn.CrossEntropyLoss 在传入 logits 后自动应用 Softmax 和交叉熵计算,为多分类问题提供了便捷的计算方式。

总结

本文介绍了 Softmax 的原理、公式、Numpy 实现、可视化以及在 PyTorch 中的使用。Softmax 是将得分转化为概率分布的关键函数,尤其适用于多分类任务。我们还探讨了数值稳定性的处理以及交叉熵损失在多分类中的作用,理解并实现 Softmax 有助于构建更稳定且易解释的分类模型。


http://www.kler.cn/a/371553.html

相关文章:

  • Linux第一个系统程序---进度条
  • 计算机网络 笔记 数据链路层 2
  • 【GESP】C++二级练习 luogu-B2079, 求出 e 的值
  • 2025年第三届“华数杯”国际赛B题解题思路与代码(Matlab版)
  • 设计模式 行为型 责任链模式(Chain of Responsibility Pattern)与 常见技术框架应用 解析
  • 数据库高安全—角色权限:权限管理权限检查
  • GraphQL系列 - 第1讲 GraphQL语法入门
  • 计算机毕业设计——ssm基于HTML5的互动游戏新闻网站的设计与实现录像演示2021
  • R_机器学习——常用函数方法汇总
  • Java进阶篇设计模式之四 -----适配器模式和桥接模式
  • 会议录音转文字怎么转?有这6款视频语音转文字工具就够了!
  • 微信小程序时间弹窗——年月日时分
  • G2 基于生成对抗网络(GAN)人脸图像生成
  • Pytorch学习--神经网络--非线性激活
  • 【Unity基础】初识UI Toolkit - 运行时UI
  • 【项目复现】——DDoS-SDN Detection Project
  • Nginx + Lua + Redis:打造智能 IP 黑名单系统
  • 「Mac畅玩鸿蒙与硬件14」鸿蒙UI组件篇4 - Toggle 和 Checkbox 组件
  • Conditional DETR论文笔记
  • Java基础(4)之正则,异常与文件IO流
  • KAN原作论文github阅读(readme)
  • 深度解读GaussDB逻辑解码技术原理
  • 使用 OpenCV 进行人眼检测
  • springboot天气预报推送小程序-计算机毕业设计源码41533
  • 青少年编程与数学 02-002 Sql Server 数据库应用 15课题、备份与还原
  • C++初阶(七)--类和对象(4)