当前位置: 首页 > article >正文

【小白学AI系列】NLP 核心知识点(六)Softmax函数介绍

Softmax 函数

Softmax 函数是一种常用的数学函数,广泛应用于机器学习中的分类问题,尤其是在神经网络的输出层。它的主要作用是将一个实数向量“压缩”成一个概率分布,使得所有输出的值在 0 到 1 之间,并且总和为 1。换句话说,Softmax 将模型的原始输出(logits)转化为概率,帮助我们做分类决策。

定义与公式

假设我们有一个向量 z = [ z 1 , z 2 , … , z n ] \mathbf{z} = [z_1, z_2, \dots, z_n] z=[z1,z2,,zn],其中每个元素 z i z_i zi 是模型的原始输出(logit)。Softmax 函数会将每个 z i z_i zi转换成一个概率 P ( y i ) P(y_i) P(yi),公式如下:

Softmax ( z i ) = e z i ∑ j = 1 n e z j \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}} Softmax(zi)=j=1nezjezi
其中:

  • e z i e^{z_i} ezi 是对每个元素 z i z_i zi 取指数。
  • ∑ j = 1 n e z j \sum_{j=1}^{n} e^{z_j} j=1nezj 是所有元素指数的和,确保所有输出的和为 1。

解释

  1. 指数函数的作用:Softmax 通过对每个元素进行指数运算,放大了较大值的影响,同时抑制了较小值。指数运算的特性使得较大的值会在最终的概率分布中占据主导地位。

  2. 归一化:分母部分确保了所有的概率之和为 1。这个归一化过程使得输出变成了一个有效的概率分布,能够进行分类。

  3. 概率分布:输出的每个 P ( y i ) P(y_i) P(yi) 都是一个 0 到 1 之间的数,且所有概率的和为 1,这使得 Softmax 适用于多类分类任务,最终能够为每一类分配一个概率。


Softmax 的应用场景

  1. 分类任务:Softmax 常用于多分类问题,模型的输出通常是一个概率分布,表示输入数据属于每个类别的概率。例如,在图像分类任务中,模型的输出层通常是 Softmax 激活函数,用于预测图像属于不同类别的概率。

  2. 神经网络输出层:在神经网络的最后一层,Softmax 函数通常用于将原始的 logits 转换为概率,以便进行预测或决策。

  3. 自然语言处理(NLP):在文本生成(例如 GPT 模型)或机器翻译中,Softmax 用来决定每个时间步的输出词汇。根据每个词汇的概率分布,模型选择一个概率最高的词作为输出。


Softmax 的例子

假设我们有一个模型的输出向量:( \mathbf{z} = [2.0, 1.0, 0.1] ),表示模型对三个类别的预测值(logits)。我们希望将这些值转换为概率分布。

步骤:

  1. 计算每个元素的指数:

[
e^{2.0} = 7.389, \quad e^{1.0} = 2.718, \quad e^{0.1} = 1.105
]

  1. 计算所有指数的和:

sum = e 2.0 + e 1.0 + e 0.1 = 7.389 + 2.718 + 1.105 = 11.212 \text{sum} = e^{2.0} + e^{1.0} + e^{0.1} = 7.389 + 2.718 + 1.105 = 11.212 sum=e2.0+e1.0+e0.1=7.389+2.718+1.105=11.212

  1. 计算每个元素的 Softmax 输出:

Softmax ( 2.0 ) = 7.389 11.212 = 0.659 \text{Softmax}(2.0) = \frac{7.389}{11.212} = 0.659 Softmax(2.0)=11.2127.389=0.659
Softmax ( 1.0 ) = 2.718 11.212 = 0.243 \text{Softmax}(1.0) = \frac{2.718}{11.212} = 0.243 Softmax(1.0)=11.2122.718=0.243
Softmax ( 0.1 ) = 1.105 11.212 = 0.098 \text{Softmax}(0.1) = \frac{1.105}{11.212} = 0.098 Softmax(0.1)=11.2121.105=0.098

因此,Softmax 转换后的概率分布为:

P ( class 1 ) = 0.659 , P ( class 2 ) = 0.243 , P ( class 3 ) = 0.098 P(\text{class 1}) = 0.659, \quad P(\text{class 2}) = 0.243, \quad P(\text{class 3}) = 0.098 P(class 1)=0.659,P(class 2)=0.243,P(class 3)=0.098

Softmax 的特点

  1. 输出是概率:Softmax 将每个元素映射到 0 和 1 之间,并且保证输出的总和为 1,适合用来表示概率分布。

  2. 关注较大值:Softmax 通过指数运算放大了较大值的影响,这使得 Softmax 在分类时能突出最有可能的类别。

  3. 数值稳定性:由于 Softmax 函数涉及到指数计算,容易出现数值溢出(当输入值过大时)。为了避免这种情况,通常在实际计算时,先对所有输入减去最大值,这样可以确保数值稳定性,公式变为:
    Softmax ( z i ) = e z i − max ⁡ ( z ) ∑ j = 1 n e z j − max ⁡ ( z ) \text{Softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^{n} e^{z_j - \max(z)}} Softmax(zi)=j=1nezjmax(z)ezimax(z)
    这样做不会改变输出的结果,但能避免数值溢出问题。


Softmax 与其他函数的比较

  • Sigmoid 函数:Sigmoid 函数将一个实数映射到 0 到 1 之间的值,通常用于二分类任务。而 Softmax 函数可以处理多分类问题,输出一个概率分布,适用于多于两个类别的分类任务。

  • ReLU 函数:ReLU(修正线性单元)常用于隐藏层,它将负值转化为零,正值保持不变,主要用于激活函数。而 Softmax 主要用于输出层,将模型的 logits 转化为概率。


这是一个非常好的问题!让我们一起深入探讨一下为什么 Softmax 函数选择使用 指数 而不是直接使用原始的向量值(logits)作为分数。

** 为什么要使用指数函数而不是原始值?**

放大差异

Softmax 函数使用 指数 的一个关键原因是它能有效地 放大 数值之间的差异。直接使用原始的值作为分数,可能会导致数值较小的分数几乎没有影响,而数值较大的分数会主导整个输出。通过使用指数函数,Softmax 可以更好地突出较大值的影响,这样模型在做决策时,具有更强的“选择性”。

  • 举个例子:
    假设你有两个类的 logits: z 1 = 2.0 z_1 = 2.0 z1=2.0 z 2 = 1.0 z_2 = 1.0 z2=1.0。如果你直接使用这两个值作为概率分数,它们的比例是 2.0 / 1.0 = 2.0 2.0 / 1.0 = 2.0 2.0/1.0=2.0,这可能给你一个稍微偏向第一个类的输出。

    但是,如果你对这两个值应用指数:
    e 2.0 = 7.389 , e 1.0 = 2.718 e^{2.0} = 7.389, \quad e^{1.0} = 2.718 e2.0=7.389,e1.0=2.718
    你会发现,经过指数转换后,差异被极大地放大了,比例变成了:
    7.389 2.718 ≈ 2.72 \frac{7.389}{2.718} \approx 2.72 2.7187.3892.72
    这样,较大的分数(( z_1 = 2.0 ))的影响被放大,使得模型的决策更加突出“正确的类别”。

避免“平坦”概率分布

如果我们直接使用原始的值,可能会导致一些类别之间的差异非常微弱,从而得到一个非常“平坦”的概率分布,所有类别的概率都差不多。例如,在一个分类任务中,如果 logits 之间的差异非常小,Softmax 可能会将每个类别的概率都分配得差不多,这样模型的决策就不够明确了。

通过指数运算,Softmax 使得即使 logits 之间的差异非常小,也能够显著区分它们的权重,从而避免“平坦”的概率分布。


2. 为何指数运算能够更有效地表示概率?

概率与相对关系

在分类问题中,我们关心的并不是每个类别的绝对值,而是 类别之间的相对关系。指数函数的使用本质上是一种 相对关系的量化。即使 logits 的绝对值很大或很小,Softmax 的输出只关心它们之间的相对关系。

  • 举个例子:
    假设有两个类别,logits 分别是 ( 10 ) 和 ( 8 ),它们的差异是 2。若用指数处理,则:
    e 10 = 22026.465 , e 8 = 2980.958 e^{10} = 22026.465, \quad e^{8} = 2980.958 e10=22026.465,e8=2980.958
    它们之间的比例是:
    22026.465 2980.958 ≈ 7.4 \frac{22026.465}{2980.958} \approx 7.4 2980.95822026.4657.4
    这个比例表明类别 1 比类别 2 更有可能,概率差异非常明显。

    如果不使用指数,而是直接比较原始值 ( 10 ) 和 ( 8 ),那么概率分布可能就没有那么明显,导致决策不够“坚定”。

避免极端的数值(溢出问题)

指数运算的另一个原因是它可以帮助避免非常大的原始值引起数值溢出的问题。实际上,在使用指数计算时,如果不做适当的缩放(例如减去最大值),logits 过大可能导致计算溢出(例如 ( e^{1000} ) 太大而无法计算)。但通过将所有值统一进行缩放,Softmax 能确保数值的稳定性。


3. 直观解释:指数如何让分类决策更“明确”

  • 假设我们有一个模型输出的 logits 向量 ( \mathbf{z} = [1.0, 2.0, 0.1] ),我们可以通过指数计算得到它们的概率。
    • 对每个元素进行指数操作:
      e 1.0 = 2.718 , e 2.0 = 7.389 , e 0.1 = 1.105 e^{1.0} = 2.718, \quad e^{2.0} = 7.389, \quad e^{0.1} = 1.105 e1.0=2.718,e2.0=7.389,e0.1=1.105
    • 计算总和:
      2.718 + 7.389 + 1.105 = 11.212 2.718 + 7.389 + 1.105 = 11.212 2.718+7.389+1.105=11.212
    • 计算每个类别的概率:
      P ( class 1 ) = 2.718 11.212 ≈ 0.242 P(\text{class 1}) = \frac{2.718}{11.212} \approx 0.242 P(class 1)=11.2122.7180.242
      P ( class 2 ) = 7.389 11.212 ≈ 0.659 P(\text{class 2}) = \frac{7.389}{11.212} \approx 0.659 P(class 2)=11.2127.3890.659
      P ( class 3 ) = 1.105 11.212 ≈ 0.099 P(\text{class 3}) = \frac{1.105}{11.212} \approx 0.099 P(class 3)=11.2121.1050.099
      在这种情况下,类别 2 的概率明显高于其他类别,且各类别之间的概率差异较大,决策也变得更为明确。

4. 总结:为什么用指数而不是原始值?

  1. 增强差异:指数运算可以显著放大 logits 之间的差异,使得模型能够更明显地分辨不同类别。

  2. 凸显相对关系:Softmax 关注的是各类别之间的 相对概率,而非绝对值。指数函数更好地反映了这种相对关系。

  3. 避免平坦分布:如果 logits 值差异不大,Softmax 通过指数运算能够更好地区分各个类别,避免概率分布过于平坦。

  4. 数值稳定性:指数操作可以避免原始值导致的数值溢出问题,尤其是在处理大规模数据时。


http://www.kler.cn/a/527581.html

相关文章:

  • 天融信 NGFW2.3 mibs
  • SAP SD学习笔记27 - 请求计划(开票计划)之1 - 定期请求
  • Leetcode:350
  • 【Rust自学】15.0. 智能指针(序):什么是智能指针及Rust智能指针的特性
  • 【Unity3D】实现2D角色/怪物死亡消散粒子效果
  • 本地部署deepseek模型步骤
  • 如何优化轮式移动机器人的运动稳定性?
  • 仿真设计|基于51单片机的低频信号控制系统仿真
  • PostgreSQL图插件AGE
  • DeepSeek-R1 论文解读 —— 强化学习大语言模型新时代来临?
  • Java 泛型<? extends Object>
  • 小程序-基础加强
  • 最新Java开发进阶!Java进阶面试资料无偿分享_java面试最新资料
  • SpringBoot入门:快速构建第一个Web应用
  • 需求分析应该从哪些方面来着手做?
  • 高低频混合组网系统中基于地理位置信息的信道测量算法matlab仿真
  • 手摸手系列之 DeepSeek-R1 开源大模型私有化部署解决方案
  • Linux_线程同步生产者消费者模型
  • 适合超多氛围灯节点应用的新选择
  • springboot 2.7.6 security mysql redis jwt配置例子
  • 【股票数据API接口36】如何获取股票当天逐笔大单交易数据之Python、Java等多种主流语言实例代码演示通过股票数据接口获取数据
  • 仿真设计|基于51单片机的温室环境监测调节系统
  • C++实现状态模式
  • 如何选择Spring AOP的动态代理?JDK与CGLIB的适用场景
  • python 语音识别
  • 如何在 Kafka 中实现自定义分区器