MSE分类时梯度消失的问题详解和交叉熵损失的梯度推导
下面是MSE不适合分类任务的解释,包含梯度推导。以及交叉熵的梯度推导。
前文请移步笔者的另一篇博客:大模型训练为什么选择交叉熵损失(Cross-Entropy Loss):均方误差(MSE)和交叉熵损失的深入对比
MSE分类时梯度消失的问题详解
我们深入探讨 MSE(均方误差)的梯度特性,结合公式推导和分析,解释为什么在预测值接近 0 或 1 时梯度趋于 0,以及这背后的含义。我会尽量保持清晰且严谨,适合高理论水平的深度学习研究者。
MSE 的梯度推导与特性
均方误差(Mean Squared Error, MSE)在分类任务中的定义通常是针对模型输出 ( y ^ \hat{y} y^ )(预测值)和真实标签 ( y y y ) 之间的平方差。对于单个样本,MSE 损失函数可以写为:
L M S E = 1 2 ( y − y ^ ) 2 L_{MSE} = \frac{1}{2} (y - \hat{y})^2 LMSE=21(y−y^)2
这里的 ( 1 2 \frac{1}{2} 21 ) 是为了简化梯度计算的常数项(在求导时会消掉),在实际实现中可能省略,但不影响分析。为了与交叉熵的场景保持一致,我们假设 ( y y y ) 是分类任务中的真实标签(例如二分类中 ( y ∈ { 0 , 1 } y \in \{0, 1\} y∈{0,1} )),而 ( y ^ \hat{y} y^ ) 是模型的输出,通常经过 sigmoid 或 softmax 处理,范围在 ( [ 0 , 1 ] [0, 1] [0,1] ) 内,表示概率。
1. MSE 梯度的计算
为了计算梯度,我们需要对 ( L M S E L_{MSE} LMSE ) 关于模型输出 ( y ^ \hat{y} y^ ) 求偏导数:
∂ L M S E ∂ y ^ = ∂ ∂ y ^ [ 1 2 ( y − y ^ ) 2 ] \frac{\partial L_{MSE}}{\partial \hat{y}} = \frac{\partial}{\partial \hat{y}} \left[ \frac{1}{2} (y - \hat{y})^2 \right] ∂y^∂LMSE=∂y^∂[21(y−y^)2]
根据链式法则,求导过程如下:
∂ L M S E ∂ y ^ = 1 2 ⋅ 2 ( y − y ^ ) ⋅ ( − 1 ) = − ( y − y ^ ) \frac{\partial L_{MSE}}{\partial \hat{y}} = \frac{1}{2} \cdot 2 (y - \hat{y}) \cdot (-1) = -(y - \hat{y}) ∂y^∂LMSE=21⋅2(y−y^)⋅(−1)=−(y−y^)
所以,MSE 的梯度为:
∂ L M S E ∂ y ^ = y ^ − y \frac{\partial L_{MSE}}{\partial \hat{y}} = \hat{y} - y ∂y^∂LMSE=y^−y
这表明 MSE 的梯度与预测值 ( y ^ \hat{y} y^ ) 和真实值 ( y y y ) 之间的误差 ( y ^ − y \hat{y} - y y^−y ) 成正比,是一个线性关系。
2. 考虑激活函数(如 sigmoid)
在实际分类任务中,模型的原始输出(称为 logits,记为 ( z z z ))通常会通过激活函数(如 sigmoid)转化为概率 ( y ^ \hat{y} y^ )。对于二分类,sigmoid 函数定义为:
y ^ = σ ( z ) = 1 1 + e − z \hat{y} = \sigma(z) = \frac{1}{1 + e^{-z}} y^=σ(z)=1+e−z1
因此,损失函数变为关于 ( z z z ) 的函数:
L M S E = 1 2 ( y − σ ( z ) ) 2 L_{MSE} = \frac{1}{2} (y - \sigma(z))^2 LMSE=21(y−σ(z))2
我们需要计算梯度 ( ∂ L M S E ∂ z \frac{\partial L_{MSE}}{\partial z} ∂z∂LMSE ),使用链式法则:
∂ L M S E ∂ z = ∂ L M S E ∂ y ^ ⋅ ∂ y ^ ∂ z \frac{\partial L_{MSE}}{\partial z} = \frac{\partial L_{MSE}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial z} ∂z∂LMSE=∂y^∂LMSE⋅∂z∂y^
-
第一部分:
∂ L M S E ∂ y ^ = y ^ − y \frac{\partial L_{MSE}}{\partial \hat{y}} = \hat{y} - y ∂y^∂LMSE=y^−y -
第二部分(sigmoid 的导数):
∂ y ^ ∂ z = ∂ ∂ z ( 1 1 + e − z ) = σ ( z ) ( 1 − σ ( z ) ) = y ^ ( 1 − y ^ ) \frac{\partial \hat{y}}{\partial z} = \frac{\partial}{\partial z} \left( \frac{1}{1 + e^{-z}} \right) = \sigma(z) (1 - \sigma(z)) = \hat{y} (1 - \hat{y}) ∂z∂y^=∂z∂(1+e−z1)=σ(z)(1−σ(z))=y^(1−y^)
所以,完整的梯度为:
∂ L M S E ∂ z = ( y ^ − y ) ⋅ y ^ ( 1 − y ^ ) \frac{\partial L_{MSE}}{\partial z} = (\hat{y} - y) \cdot \hat{y} (1 - \hat{y}) ∂z∂LMSE=(y^−y)⋅y^(1−y^)
这个梯度会通过反向传播进一步影响网络参数的更新。
为什么梯度趋于 0?
现在我们分析为什么 MSE 的梯度在 ( y ^ \hat{y} y^ ) 接近 0 或 1 时会趋于 0,以及这对分类任务意味着什么。
1. 梯度公式中的关键项:( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^) )
sigmoid 的导数 ( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^) ) 是一个二次函数:
- 当 ( y ^ = 0 \hat{y} = 0 y^=0 ) 时,( y ^ ( 1 − y ^ ) = 0 ⋅ ( 1 − 0 ) = 0 \hat{y} (1 - \hat{y}) = 0 \cdot (1 - 0) = 0 y^(1−y^)=0⋅(1−0)=0 );
- 当 ( y ^ = 1 \hat{y} = 1 y^=1 ) 时,( y ^ ( 1 − y ^ ) = 1 ⋅ ( 1 − 1 ) = 0 \hat{y} (1 - \hat{y}) = 1 \cdot (1 - 1) = 0 y^(1−y^)=1⋅(1−1)=0 );
- 当 ( y ^ = 0.5 \hat{y} = 0.5 y^=0.5 ) 时,( y ^ ( 1 − y ^ ) = 0.5 ⋅ 0.5 = 0.25 \hat{y} (1 - \hat{y}) = 0.5 \cdot 0.5 = 0.25 y^(1−y^)=0.5⋅0.5=0.25 )(最大值)。
这意味着,当模型的预测 ( y ^ \hat{y} y^ ) 趋向于边界(0 或 1)时,梯度中的 ( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^) ) 项会迅速减小到 0。
2. 具体场景分析
-
情况 1:预测错误且置信度高
假设 ( y = 1 y = 1 y=1 ),但 ( y ^ → 0 \hat{y} \to 0 y^→0 )(模型错误且非常确信):
∂ L M S E ∂ z = ( y ^ − y ) ⋅ y ^ ( 1 − y ^ ) ≈ ( 0 − 1 ) ⋅ 0 ⋅ ( 1 − 0 ) = 0 \frac{\partial L_{MSE}}{\partial z} = (\hat{y} - y) \cdot \hat{y} (1 - \hat{y}) \approx (0 - 1) \cdot 0 \cdot (1 - 0) = 0 ∂z∂LMSE=(y^−y)⋅y^(1−y^)≈(0−1)⋅0⋅(1−0)=0
梯度趋于 0,尽管误差 ( y ^ − y = − 1 \hat{y} - y = -1 y^−y=−1 ) 很大。 -
情况 2:预测正确且置信度高
假设 ( y = 1 y = 1 y=1 ),且 ( y ^ → 1 \hat{y} \to 1 y^→1 )(模型正确且确信):
∂ L M S E ∂ z = ( y ^ − y ) ⋅ y ^ ( 1 − y ^ ) ≈ ( 1 − 1 ) ⋅ 1 ⋅ ( 1 − 1 ) = 0 \frac{\partial L_{MSE}}{\partial z} = (\hat{y} - y) \cdot \hat{y} (1 - \hat{y}) \approx (1 - 1) \cdot 1 \cdot (1 - 1) = 0 ∂z∂LMSE=(y^−y)⋅y^(1−y^)≈(1−1)⋅1⋅(1−1)=0
梯度同样趋于 0,这次是因为误差 ( y ^ − y = 0 \hat{y} - y = 0 y^−y=0 )。
在第一种情况下,模型需要快速修正错误,但梯度接近 0,导致更新几乎停止,这就是所谓的“梯度消失”问题。
3. 与交叉熵对比
交叉熵(以二元交叉熵为例)的损失为:
L B C E = − [ y log ( y ^ ) + ( 1 − y ) log ( 1 − y ^ ) ] L_{BCE} = -[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})] LBCE=−[ylog(y^)+(1−y)log(1−y^)]
梯度为:
∂ L B C E ∂ z = y ^ − y \frac{\partial L_{BCE}}{\partial z} = \hat{y} - y ∂z∂LBCE=y^−y
注意,这里没有 ( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^) ) 这样的乘积项。当 ( y = 1 y = 1 y=1 ) 且 ( y ^ → 0 \hat{y} \to 0 y^→0 ) 时,梯度 ( y ^ − y ≈ 0 − 1 = − 1 \hat{y} - y \approx 0 - 1 = -1 y^−y≈0−1=−1 ),大小显著且方向明确,推动模型快速修正。
梯度为 0 的含义
-
学习停滞
当 ( y ^ \hat{y} y^ ) 接近 0 或 1 时,MSE 的梯度趋于 0,意味着模型参数的更新量极小。即使预测完全错误,模型也无法有效学习。这在分类任务中尤为致命,因为分类需要明确的决策边界,而非模糊的中间状态。 -
饱和效应
sigmoid 的“饱和”特性(输出接近边界时导数趋于 0)与 MSE 的线性误差项结合,加剧了梯度消失。这种效应在深层网络中会被放大,导致早期层几乎无法更新。 -
不适合分类的根本原因
分类任务需要模型对错误预测(尤其是高置信度错误)产生强反馈,而 MSE 在这种情况下“麻木”,无法提供足够的梯度信号。相比之下,交叉熵通过对数形式放大了错误预测的惩罚,确保梯度始终有效。
总结
MSE 的梯度公式 ( ∂ L M S E ∂ z = ( y ^ − y ) ⋅ y ^ ( 1 − y ^ ) \frac{\partial L_{MSE}}{\partial z} = (\hat{y} - y) \cdot \hat{y} (1 - \hat{y}) ∂z∂LMSE=(y^−y)⋅y^(1−y^) ) 显示,其大小受 ( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^) ) 控制。当预测值 ( y ^ \hat{y} y^ ) 接近 0 或 1 时,这一项趋于 0,导致梯度消失。无论是正确的高置信预测还是错误的极端预测,MSE 都无法提供足够的优化动力。这解释了为什么 MSE 在分类任务中表现不佳,而交叉熵凭借其非线性惩罚和稳定梯度成为更优选择。
多分类场景:Softmax + MSE 梯度推导
我们将深入探讨多分类场景下,使用 softmax 激活函数结合 MSE(均方误差)损失的梯度推导。这部分内容会非常详细,包含完整的数学推导,并分析梯度特性,尤其是在预测值接近边界时的行为。我们假设读者具备较高的理论水平,能够理解矩阵运算和链式法则的复杂应用。
多分类场景:Softmax + MSE 的定义
在多分类任务中,假设有 ( C C C ) 个类别,模型的原始输出(logits)为向量 ( z = [ z 1 , z 2 , … , z C ] \mathbf{z} = [z_1, z_2, \dots, z_C] z=[z1,z2,…,zC] )。通过 softmax 函数,logits 被转化为概率分布 ( y ^ = [ y ^ 1 , y ^ 2 , … , y ^ C ] \hat{\mathbf{y}} = [\hat{y}_1, \hat{y}_2, \dots, \hat{y}_C] y^=[y^1,y^2,…,y^C] ):
y ^ i = softmax ( z i ) = e z i ∑ j = 1 C e z j \hat{y}_i = \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^C e^{z_j}} y^i=softmax(zi)=∑j=1Cezjezi
真实标签 ( y = [ y 1 , y 2 , … , y C ] \mathbf{y} = [y_1, y_2, \dots, y_C] y=[y1,y2,…,yC] ) 通常采用 one-hot 编码,例如对于第 ( k k k ) 类,( y k = 1 y_k = 1 yk=1 ),其他 ( y i = 0 ( i ≠ k ) y_i = 0 (i \neq k) yi=0(i=k) )。
MSE 损失函数定义为预测概率 ( y ^ \hat{\mathbf{y}} y^ ) 与真实标签 ( y \mathbf{y} y ) 之间的平方差平均值(为简化推导,我们省略常数因子 ( 1 C \frac{1}{C} C1 ) 的影响,加入 ( 1 2 \frac{1}{2} 21 ) 方便求导):
L M S E = 1 2 ∑ i = 1 C ( y i − y ^ i ) 2 L_{MSE} = \frac{1}{2} \sum_{i=1}^C (y_i - \hat{y}_i)^2 LMSE=21i=1∑C(yi−y^i)2
目标是计算梯度 ( ∂ L M S E ∂ z j \frac{\partial L_{MSE}}{\partial z_j} ∂zj∂LMSE )(对每个 logit ( z j z_j zj ) 的偏导),以理解其在多分类中的行为。
梯度推导
由于 ( y ^ i \hat{y}_i y^i ) 是 ( z \mathbf{z} z ) 的函数,我们需要使用链式法则分两步计算:
∂ L M S E ∂ z j = ∑ i = 1 C ∂ L M S E ∂ y ^ i ⋅ ∂ y ^ i ∂ z j \frac{\partial L_{MSE}}{\partial z_j} = \sum_{i=1}^C \frac{\partial L_{MSE}}{\partial \hat{y}_i} \cdot \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂LMSE=i=1∑C∂y^i∂LMSE⋅∂zj∂y^i
1. 计算 ( ∂ L M S E ∂ y ^ i \frac{\partial L_{MSE}}{\partial \hat{y}_i} ∂y^i∂LMSE )
首先,对 ( y ^ i \hat{y}_i y^i ) 求偏导:
∂ L M S E ∂ y ^ i = ∂ ∂ y ^ i [ 1 2 ∑ k = 1 C ( y k − y ^ k ) 2 ] \frac{\partial L_{MSE}}{\partial \hat{y}_i} = \frac{\partial}{\partial \hat{y}_i} \left[ \frac{1}{2} \sum_{k=1}^C (y_k - \hat{y}_k)^2 \right] ∂y^i∂LMSE=∂y^i∂[21k=1∑C(yk−y^k)2]
由于平方和中只有 ( k = i k = i k=i ) 的项与 ( y ^ i \hat{y}_i y^i ) 有关,其他项的导数为 0,因此:
∂ L M S E ∂ y ^ i = 1 2 ⋅ 2 ( y i − y ^ i ) ⋅ ( − 1 ) = − ( y i − y ^ i ) \frac{\partial L_{MSE}}{\partial \hat{y}_i} = \frac{1}{2} \cdot 2 (y_i - \hat{y}_i) \cdot (-1) = -(y_i - \hat{y}_i) ∂y^i∂LMSE=21⋅2(yi−y^i)⋅(−1)=−(yi−y^i)
即:
∂ L M S E ∂ y ^ i = y ^ i − y i \frac{\partial L_{MSE}}{\partial \hat{y}_i} = \hat{y}_i - y_i ∂y^i∂LMSE=y^i−yi
这与二分类的 MSE 梯度形式一致,表明误差方向取决于预测值与真实值的偏差。
2. 计算 ( ∂ y ^ i ∂ z j \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂y^i )(Softmax 的导数)
Softmax 的导数稍微复杂,因为 ( y ^ i \hat{y}_i y^i ) 不仅依赖于 ( z i z_i zi ),还通过归一化分母依赖于所有 ( z j z_j zj)。我们需要分情况讨论:
-
当 ( i = j i = j i=j ) 时(对自己的偏导):
y ^ i = e z i ∑ k = 1 C e z k \hat{y}_i = \frac{e^{z_i}}{\sum_{k=1}^C e^{z_k}} y^i=∑k=1Cezkezi
使用商法则求导:
∂ y ^ i ∂ z i = e z i ⋅ ∑ k = 1 C e z k − e z i ⋅ e z i ( ∑ k = 1 C e z k ) 2 = e z i ( ∑ k = 1 C e z k − e z i ) ( ∑ k = 1 C e z k ) 2 \frac{\partial \hat{y}_i}{\partial z_i} = \frac{e^{z_i} \cdot \sum_{k=1}^C e^{z_k} - e^{z_i} \cdot e^{z_i}}{(\sum_{k=1}^C e^{z_k})^2} = \frac{e^{z_i} (\sum_{k=1}^C e^{z_k} - e^{z_i})}{(\sum_{k=1}^C e^{z_k})^2} ∂zi∂y^i=(∑k=1Cezk)2ezi⋅∑k=1Cezk−ezi⋅ezi=(∑k=1Cezk)2ezi(∑k=1Cezk−ezi)
代入 ( y ^ i = e z i ∑ k = 1 C e z k \hat{y}_i = \frac{e^{z_i}}{\sum_{k=1}^C e^{z_k}} y^i=∑k=1Cezkezi ):
∂ y ^ i ∂ z i = y ^ i ⋅ ∑ k = 1 C e z k − e z i ∑ k = 1 C e z k = y ^ i ( 1 − y ^ i ) \frac{\partial \hat{y}_i}{\partial z_i} = \hat{y}_i \cdot \frac{\sum_{k=1}^C e^{z_k} - e^{z_i}}{\sum_{k=1}^C e^{z_k}} = \hat{y}_i (1 - \hat{y}_i) ∂zi∂y^i=y^i⋅∑k=1Cezk∑k=1Cezk−ezi=y^i(1−y^i) -
当 ( i ≠ j i \neq j i=j ) 时(交叉项):
∂ y ^ i ∂ z j = 0 ⋅ ∑ k = 1 C e z k − e z i ⋅ e z j ( ∑ k = 1 C e z k ) 2 = − e z i e z j ( ∑ k = 1 C e z k ) 2 = − y ^ i y ^ j \frac{\partial \hat{y}_i}{\partial z_j} = \frac{0 \cdot \sum_{k=1}^C e^{z_k} - e^{z_i} \cdot e^{z_j}}{(\sum_{k=1}^C e^{z_k})^2} = -\frac{e^{z_i} e^{z_j}}{(\sum_{k=1}^C e^{z_k})^2} = -\hat{y}_i \hat{y}_j ∂zj∂y^i=(∑k=1Cezk)20⋅∑k=1Cezk−ezi⋅ezj=−(∑k=1Cezk)2eziezj=−y^iy^j
总结 Softmax 的导数:
∂
y
^
i
∂
z
j
=
{
y
^
i
(
1
−
y
^
i
)
,
if
i
=
j
−
y
^
i
y
^
j
,
if
i
≠
j
\frac{\partial \hat{y}_i}{\partial z_j} = \begin{cases} \hat{y}_i (1 - \hat{y}_i), & \text{if } i = j \\ -\hat{y}_i \hat{y}_j, & \text{if } i \neq j \end{cases}
∂zj∂y^i={y^i(1−y^i),−y^iy^j,if i=jif i=j
这表明 Softmax 的输出是相互耦合的,改变一个 ( z j z_j zj ) 会影响所有 ( y ^ i \hat{y}_i y^i )。
3. 合并计算完整梯度
将两部分合并:
∂ L M S E ∂ z j = ∑ i = 1 C ( y ^ i − y i ) ⋅ ∂ y ^ i ∂ z j \frac{\partial L_{MSE}}{\partial z_j} = \sum_{i=1}^C (\hat{y}_i - y_i) \cdot \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂LMSE=i=1∑C(y^i−yi)⋅∂zj∂y^i
代入 Softmax 导数,分开 ( i = j i = j i=j ) 和 ( i ≠ j i \neq j i=j ) 的项:
∂ L M S E ∂ z j = ( y ^ j − y j ) ⋅ ∂ y ^ j ∂ z j + ∑ i ≠ j ( y ^ i − y i ) ⋅ ∂ y ^ i ∂ z j \frac{\partial L_{MSE}}{\partial z_j} = (\hat{y}_j - y_j) \cdot \frac{\partial \hat{y}_j}{\partial z_j} + \sum_{i \neq j} (\hat{y}_i - y_i) \cdot \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂LMSE=(y^j−yj)⋅∂zj∂y^j+i=j∑(y^i−yi)⋅∂zj∂y^i
-
第一项(( i = j i = j i=j )):
( y ^ j − y j ) ⋅ y ^ j ( 1 − y ^ j ) (\hat{y}_j - y_j) \cdot \hat{y}_j (1 - \hat{y}_j) (y^j−yj)⋅y^j(1−y^j) -
第二项(( i ≠ j i \neq j i=j ) 的总和):
∑ i ≠ j ( y ^ i − y i ) ⋅ ( − y ^ i y ^ j ) = − y ^ j ∑ i ≠ j ( y ^ i − y i ) y ^ i \sum_{i \neq j} (\hat{y}_i - y_i) \cdot (-\hat{y}_i \hat{y}_j) = -\hat{y}_j \sum_{i \neq j} (\hat{y}_i - y_i) \hat{y}_i i=j∑(y^i−yi)⋅(−y^iy^j)=−y^ji=j∑(y^i−yi)y^i
完整梯度为:
∂ L M S E ∂ z j = ( y ^ j − y j ) y ^ j ( 1 − y ^ j ) − y ^ j ∑ i ≠ j ( y ^ i − y i ) y ^ i \frac{\partial L_{MSE}}{\partial z_j} = (\hat{y}_j - y_j) \hat{y}_j (1 - \hat{y}_j) - \hat{y}_j \sum_{i \neq j} (\hat{y}_i - y_i) \hat{y}_i ∂zj∂LMSE=(y^j−yj)y^j(1−y^j)−y^ji=j∑(y^i−yi)y^i
为了更直观,我们可以进一步整理。注意到 ( ∑ i = 1 C y i = 1 \sum_{i=1}^C y_i = 1 ∑i=1Cyi=1 )(one-hot 标签),( ∑ i = 1 C y ^ i = 1 \sum_{i=1}^C \hat{y}_i = 1 ∑i=1Cy^i=1 )(softmax 性质),我们可以分析特定情况。
4. 简化(以 one-hot 标签为例)
假设真实标签 ( y k = 1 y_k = 1 yk=1 ),其他 ( y i = 0 ( i ≠ k ) y_i = 0 (i \neq k) yi=0(i=k) )。梯度为:
∂ L M S E ∂ z j = ( y ^ j − y j ) y ^ j ( 1 − y ^ j ) − y ^ j ∑ i ≠ j ( y ^ i − y i ) y ^ i \frac{\partial L_{MSE}}{\partial z_j} = (\hat{y}_j - y_j) \hat{y}_j (1 - \hat{y}_j) - \hat{y}_j \sum_{i \neq j} (\hat{y}_i - y_i) \hat{y}_i ∂zj∂LMSE=(y^j−yj)y^j(1−y^j)−y^ji=j∑(y^i−yi)y^i
-
若 ( j = k j = k j=k )(正确类别):
( y j = 1 y_j = 1 yj=1 ),( y i = 0 ( i ≠ j ) y_i = 0 (i \neq j) yi=0(i=j) )
第一项:( ( y ^ j − 1 ) y ^ j ( 1 − y ^ j ) (\hat{y}_j - 1) \hat{y}_j (1 - \hat{y}_j) (y^j−1)y^j(1−y^j) )
第二项:( − y ^ j ∑ i ≠ j ( y ^ i − 0 ) y ^ i = − y ^ j ∑ i ≠ j y ^ i 2 -\hat{y}_j \sum_{i \neq j} (\hat{y}_i - 0) \hat{y}_i = -\hat{y}_j \sum_{i \neq j} \hat{y}_i^2 −y^j∑i=j(y^i−0)y^i=−y^j∑i=jy^i2 )
总梯度:
∂ L M S E ∂ z k = ( y ^ k − 1 ) y ^ k ( 1 − y ^ k ) − y ^ k ∑ i ≠ k y ^ i 2 \frac{\partial L_{MSE}}{\partial z_k} = (\hat{y}_k - 1) \hat{y}_k (1 - \hat{y}_k) - \hat{y}_k \sum_{i \neq k} \hat{y}_i^2 ∂zk∂LMSE=(y^k−1)y^k(1−y^k)−y^ki=k∑y^i2 -
若 ( j ≠ k j \neq k j=k )(错误类别):
( y j = 0 y_j = 0 yj=0 ),第一项:( y ^ j y ^ j ( 1 − y ^ j ) = y ^ j 2 ( 1 − y ^ j ) \hat{y}_j \hat{y}_j (1 - \hat{y}_j) = \hat{y}_j^2 (1 - \hat{y}_j) y^jy^j(1−y^j)=y^j2(1−y^j) )
第二项:( − y ^ j [ ( y ^ k − 1 ) y ^ k + ∑ i ≠ j , i ≠ k y ^ i 2 ] -\hat{y}_j [(\hat{y}_k - 1) \hat{y}_k + \sum_{i \neq j, i \neq k} \hat{y}_i^2] −y^j[(y^k−1)y^k+∑i=j,i=ky^i2] )
总梯度更复杂,但核心依赖于 ( y ^ j \hat{y}_j y^j ) 的值。
梯度趋于 0 的分析
现在分析当 ( y ^ j \hat{y}_j y^j ) 接近边界(0 或 1)时的行为:
-
正确类别 ( j = k j = k j=k ),( y ^ k → 1 \hat{y}_k \to 1 y^k→1 )
- ( y ^ k − 1 → 0 \hat{y}_k - 1 \to 0 y^k−1→0 )
- ( 1 − y ^ k → 0 1 - \hat{y}_k \to 0 1−y^k→0 )
- (
∑
i
≠
k
y
^
i
2
→
0
\sum_{i \neq k} \hat{y}_i^2 \to 0
∑i=ky^i2→0 )(因为 (
∑
i
=
1
C
y
^
i
=
1
\sum_{i=1}^C \hat{y}_i = 1
∑i=1Cy^i=1 ))
∂ L M S E ∂ z k ≈ 0 ⋅ y ^ k ⋅ 0 − y ^ k ⋅ 0 = 0 \frac{\partial L_{MSE}}{\partial z_k} \approx 0 \cdot \hat{y}_k \cdot 0 - \hat{y}_k \cdot 0 = 0 ∂zk∂LMSE≈0⋅y^k⋅0−y^k⋅0=0
梯度趋于 0,这在正确预测时是合理的。
-
错误类别 ( j ≠ k j \neq k j=k ),( y ^ j → 1 \hat{y}_j \to 1 y^j→1 ),( y ^ k → 0 \hat{y}_k \to 0 y^k→0 )
- 第一项:( y ^ j 2 ( 1 − y ^ j ) → 1 ⋅ ( 1 − 1 ) = 0 \hat{y}_j^2 (1 - \hat{y}_j) \to 1 \cdot (1 - 1) = 0 y^j2(1−y^j)→1⋅(1−1)=0 )
- 第二项:(
−
y
^
j
(
y
^
k
−
1
)
y
^
k
≈
−
1
⋅
(
0
−
1
)
⋅
0
=
0
-\hat{y}_j (\hat{y}_k - 1) \hat{y}_k \approx -1 \cdot (0 - 1) \cdot 0 = 0
−y^j(y^k−1)y^k≈−1⋅(0−1)⋅0=0 )
梯度仍趋于 0,尽管预测完全错误。
这表明,当 softmax 输出“饱和”(某类概率接近 1,其他接近 0)时,MSE 的梯度会消失,模型无法有效纠正错误。
与交叉熵对比
交叉熵加 softmax 的梯度为:
∂ L C E ∂ z j = y ^ j − y j \frac{\partial L_{CE}}{\partial z_j} = \hat{y}_j - y_j ∂zj∂LCE=y^j−yj
当 ( y ^ k → 0 \hat{y}_k \to 0 y^k→0 )(正确类别概率很低)时,( ∂ L C E ∂ z k = 0 − 1 = − 1 \frac{\partial L_{CE}}{\partial z_k} = 0 - 1 = -1 ∂zk∂LCE=0−1=−1 ),梯度显著,推动修正。
总结
Softmax + MSE 的梯度推导揭示了其复杂性与局限性。梯度中的 ( y ^ j ( 1 − y ^ j ) \hat{y}_j (1 - \hat{y}_j) y^j(1−y^j) ) 和交叉项导致在预测极端时(接近 0 或 1)梯度趋于 0,限制了模型对错误的高效修正。这进一步验证了 MSE 在多分类任务中的不足,而交叉熵的简洁性和动态性使其更优。
交叉熵损失的梯度推导
我们将深入探讨交叉熵损失(包括二元交叉熵和通用多分类交叉熵)的梯度推导,并分析其梯度特性,特别是为什么在预测错误时(如 ( y ^ → 0 \hat{y} \to 0 y^→0 ) 而 ( y = 1 y = 1 y=1 ))梯度较大,以及这对模型优化的意义。我们会保持数学严谨,适合高理论水平的深度学习研究者。
一、二元交叉熵(Binary Cross-Entropy)的梯度推导
定义
二元交叉熵损失用于二分类任务,假设真实标签 ( y ∈ { 0 , 1 } y \in \{0, 1\} y∈{0,1} ),模型预测概率 ( y ^ ∈ [ 0 , 1 ] \hat{y} \in [0, 1] y^∈[0,1] )(通常通过 sigmoid 函数从 logit ( z z z ) 得到)。损失函数为:
L B C E = − [ y log ( y ^ ) + ( 1 − y ) log ( 1 − y ^ ) ] L_{BCE} = -[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})] LBCE=−[ylog(y^)+(1−y)log(1−y^)]
其中,( y ^ = σ ( z ) = 1 1 + e − z \hat{y} = \sigma(z) = \frac{1}{1 + e^{-z}} y^=σ(z)=1+e−z1 )。
梯度推导
我们需要计算损失对 logit ( z z z ) 的梯度 ( ∂ L B C E ∂ z \frac{\partial L_{BCE}}{\partial z} ∂z∂LBCE ),使用链式法则:
∂ L B C E ∂ z = ∂ L B C E ∂ y ^ ⋅ ∂ y ^ ∂ z \frac{\partial L_{BCE}}{\partial z} = \frac{\partial L_{BCE}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial z} ∂z∂LBCE=∂y^∂LBCE⋅∂z∂y^
-
计算 ( ∂ L B C E ∂ y ^ \frac{\partial L_{BCE}}{\partial \hat{y}} ∂y^∂LBCE )
对 ( y ^ \hat{y} y^ ) 求偏导:
∂ L B C E ∂ y ^ = ∂ ∂ y ^ [ − y log ( y ^ ) − ( 1 − y ) log ( 1 − y ^ ) ] \frac{\partial L_{BCE}}{\partial \hat{y}} = \frac{\partial}{\partial \hat{y}} \left[ -y \log(\hat{y}) - (1 - y) \log(1 - \hat{y}) \right] ∂y^∂LBCE=∂y^∂[−ylog(y^)−(1−y)log(1−y^)]- 第一项:( ∂ ∂ y ^ [ − y log ( y ^ ) ] = − y ⋅ 1 y ^ \frac{\partial}{\partial \hat{y}} [-y \log(\hat{y})] = -y \cdot \frac{1}{\hat{y}} ∂y^∂[−ylog(y^)]=−y⋅y^1 )
- 第二项:(
∂
∂
y
^
[
−
(
1
−
y
)
log
(
1
−
y
^
)
]
=
−
(
1
−
y
)
⋅
−
1
1
−
y
^
=
1
−
y
1
−
y
^
\frac{\partial}{\partial \hat{y}} [-(1 - y) \log(1 - \hat{y})] = -(1 - y) \cdot \frac{-1}{1 - \hat{y}} = \frac{1 - y}{1 - \hat{y}}
∂y^∂[−(1−y)log(1−y^)]=−(1−y)⋅1−y^−1=1−y^1−y )
合并:
∂ L B C E ∂ y ^ = − y y ^ + 1 − y 1 − y ^ \frac{\partial L_{BCE}}{\partial \hat{y}} = -\frac{y}{\hat{y}} + \frac{1 - y}{1 - \hat{y}} ∂y^∂LBCE=−y^y+1−y^1−y
-
计算 ( ∂ y ^ ∂ z \frac{\partial \hat{y}}{\partial z} ∂z∂y^ )(sigmoid 导数)
y ^ = σ ( z ) = 1 1 + e − z \hat{y} = \sigma(z) = \frac{1}{1 + e^{-z}} y^=σ(z)=1+e−z1
∂ y ^ ∂ z = e − z ( 1 + e − z ) 2 = 1 1 + e − z ⋅ e − z 1 + e − z = y ^ ( 1 − y ^ ) \frac{\partial \hat{y}}{\partial z} = \frac{e^{-z}}{(1 + e^{-z})^2} = \frac{1}{1 + e^{-z}} \cdot \frac{e^{-z}}{1 + e^{-z}} = \hat{y} (1 - \hat{y}) ∂z∂y^=(1+e−z)2e−z=1+e−z1⋅1+e−ze−z=y^(1−y^) -
合并计算完整梯度
∂ L B C E ∂ z = ( − y y ^ + 1 − y 1 − y ^ ) ⋅ y ^ ( 1 − y ^ ) \frac{\partial L_{BCE}}{\partial z} = \left( -\frac{y}{\hat{y}} + \frac{1 - y}{1 - \hat{y}} \right) \cdot \hat{y} (1 - \hat{y}) ∂z∂LBCE=(−y^y+1−y^1−y)⋅y^(1−y^)
化简:- 第一项:( − y y ^ ⋅ y ^ ( 1 − y ^ ) = − y ( 1 − y ^ ) -\frac{y}{\hat{y}} \cdot \hat{y} (1 - \hat{y}) = -y (1 - \hat{y}) −y^y⋅y^(1−y^)=−y(1−y^) )
- 第二项:(
1
−
y
1
−
y
^
⋅
y
^
(
1
−
y
^
)
=
(
1
−
y
)
y
^
\frac{1 - y}{1 - \hat{y}} \cdot \hat{y} (1 - \hat{y}) = (1 - y) \hat{y}
1−y^1−y⋅y^(1−y^)=(1−y)y^ )
总和:
∂ L B C E ∂ z = − y ( 1 − y ^ ) + ( 1 − y ) y ^ \frac{\partial L_{BCE}}{\partial z} = -y (1 - \hat{y}) + (1 - y) \hat{y} ∂z∂LBCE=−y(1−y^)+(1−y)y^
= − y + y y ^ + y ^ − y y ^ = y ^ − y = -y + y \hat{y} + \hat{y} - y \hat{y} = \hat{y} - y =−y+yy^+y^−yy^=y^−y
最终,二元交叉熵的梯度为:
∂ L B C E ∂ z = y ^ − y \frac{\partial L_{BCE}}{\partial z} = \hat{y} - y ∂z∂LBCE=y^−y
梯度特性分析
- 与误差成正比:梯度 ( y ^ − y \hat{y} - y y^−y ) 直接反映了预测值与真实值之间的偏差,形式简洁且线性。
- 预测错误时的行为:
- 当 (
y
=
1
y = 1
y=1 ),(
y
^
→
0
\hat{y} \to 0
y^→0 )(严重错误):
∂ L B C E ∂ z = 0 − 1 = − 1 \frac{\partial L_{BCE}}{\partial z} = 0 - 1 = -1 ∂z∂LBCE=0−1=−1
梯度为 -1,大小显著且方向明确,推动 ( z z z ) 增加,使 ( y ^ \hat{y} y^ ) 向 1 靠拢。 - 当 (
y
=
0
y = 0
y=0 ),(
y
^
→
1
\hat{y} \to 1
y^→1 )(严重错误):
∂ L B C E ∂ z = 1 − 0 = 1 \frac{\partial L_{BCE}}{\partial z} = 1 - 0 = 1 ∂z∂LBCE=1−0=1
梯度为 1,推动 ( z z z ) 减小,使 ( y ^ \hat{y} y^ ) 向 0 靠拢。
- 当 (
y
=
1
y = 1
y=1 ),(
y
^
→
0
\hat{y} \to 0
y^→0 )(严重错误):
- 无饱和问题:与 MSE 不同,这里没有 ( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^) ) 这样的乘积项,梯度不会因 ( y ^ \hat{y} y^ ) 接近 0 或 1 而趋于 0,确保错误预测时有强反馈。
二、通用交叉熵(多分类 Cross-Entropy)的梯度推导
定义
在多分类任务中,假设有 ( C C C ) 个类别,真实标签 ( y = [ y 1 , y 2 , … , y C ] \mathbf{y} = [y_1, y_2, \dots, y_C] y=[y1,y2,…,yC] )(one-hot 编码,例如 ( y k = 1 y_k = 1 yk=1 ),其他为 0),预测概率 ( y ^ = [ y ^ 1 , y ^ 2 , … , y ^ C ] \hat{\mathbf{y}} = [\hat{y}_1, \hat{y}_2, \dots, \hat{y}_C] y^=[y^1,y^2,…,y^C] ) 通过 softmax 从 logits ( z = [ z 1 , z 2 , … , z C ] \mathbf{z} = [z_1, z_2, \dots, z_C] z=[z1,z2,…,zC] ) 得到:
y ^ i = softmax ( z i ) = e z i ∑ j = 1 C e z j \hat{y}_i = \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^C e^{z_j}} y^i=softmax(zi)=∑j=1Cezjezi
交叉熵损失为:
L C E = − ∑ i = 1 C y i log ( y ^ i ) L_{CE} = -\sum_{i=1}^C y_i \log(\hat{y}_i) LCE=−i=1∑Cyilog(y^i)
梯度推导
计算 ( ∂ L C E ∂ z j \frac{\partial L_{CE}}{\partial z_j} ∂zj∂LCE ):
∂ L C E ∂ z j = ∑ i = 1 C ∂ L C E ∂ y ^ i ⋅ ∂ y ^ i ∂ z j \frac{\partial L_{CE}}{\partial z_j} = \sum_{i=1}^C \frac{\partial L_{CE}}{\partial \hat{y}_i} \cdot \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂LCE=i=1∑C∂y^i∂LCE⋅∂zj∂y^i
-
计算 ( ∂ L C E ∂ y ^ i \frac{\partial L_{CE}}{\partial \hat{y}_i} ∂y^i∂LCE )
∂ L C E ∂ y ^ i = ∂ ∂ y ^ i [ − ∑ k = 1 C y k log ( y ^ k ) ] = − y i ⋅ 1 y ^ i = − y i y ^ i \frac{\partial L_{CE}}{\partial \hat{y}_i} = \frac{\partial}{\partial \hat{y}_i} \left[ -\sum_{k=1}^C y_k \log(\hat{y}_k) \right] = -y_i \cdot \frac{1}{\hat{y}_i} = -\frac{y_i}{\hat{y}_i} ∂y^i∂LCE=∂y^i∂[−k=1∑Cyklog(y^k)]=−yi⋅y^i1=−y^iyi -
计算 ( ∂ y ^ i ∂ z j \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂y^i )(Softmax 导数)
- 当 (
i
=
j
i = j
i=j ):
∂ y ^ i ∂ z i = y ^ i ( 1 − y ^ i ) \frac{\partial \hat{y}_i}{\partial z_i} = \hat{y}_i (1 - \hat{y}_i) ∂zi∂y^i=y^i(1−y^i) - 当 (
i
≠
j
i \neq j
i=j ):
∂ y ^ i ∂ z j = − y ^ i y ^ j \frac{\partial \hat{y}_i}{\partial z_j} = -\hat{y}_i \hat{y}_j ∂zj∂y^i=−y^iy^j
- 当 (
i
=
j
i = j
i=j ):
-
合并计算完整梯度
∂ L C E ∂ z j = ∑ i = 1 C ( − y i y ^ i ) ⋅ ∂ y ^ i ∂ z j \frac{\partial L_{CE}}{\partial z_j} = \sum_{i=1}^C \left( -\frac{y_i}{\hat{y}_i} \right) \cdot \frac{\partial \hat{y}_i}{\partial z_j} ∂zj∂LCE=i=1∑C(−y^iyi)⋅∂zj∂y^i
分开 ( i = j i = j i=j ) 和 ( i ≠ j i \neq j i=j ):
∂ L C E ∂ z j = − y j y ^ j ⋅ y ^ j ( 1 − y ^ j ) + ∑ i ≠ j ( − y i y ^ i ) ( − y ^ i y ^ j ) \frac{\partial L_{CE}}{\partial z_j} = -\frac{y_j}{\hat{y}_j} \cdot \hat{y}_j (1 - \hat{y}_j) + \sum_{i \neq j} \left( -\frac{y_i}{\hat{y}_i} \right) (-\hat{y}_i \hat{y}_j) ∂zj∂LCE=−y^jyj⋅y^j(1−y^j)+i=j∑(−y^iyi)(−y^iy^j)- 第一项:( − y j ( 1 − y ^ j ) -y_j (1 - \hat{y}_j) −yj(1−y^j) )
- 第二项:(
∑
i
≠
j
y
i
y
^
j
=
y
^
j
∑
i
≠
j
y
i
\sum_{i \neq j} y_i \hat{y}_j = \hat{y}_j \sum_{i \neq j} y_i
∑i=jyiy^j=y^j∑i=jyi )
由于 ( y \mathbf{y} y ) 是 one-hot,假设 ( y k = 1 y_k = 1 yk=1 ),其他为 0: - 若 (
j
=
k
j = k
j=k ):(
∑
i
≠
j
y
i
=
0
\sum_{i \neq j} y_i = 0
∑i=jyi=0 )
∂ L C E ∂ z k = − 1 ⋅ ( 1 − y ^ k ) + 0 = y ^ k − 1 \frac{\partial L_{CE}}{\partial z_k} = -1 \cdot (1 - \hat{y}_k) + 0 = \hat{y}_k - 1 ∂zk∂LCE=−1⋅(1−y^k)+0=y^k−1 - 若 (
j
≠
k
j \neq k
j=k ):(
y
j
=
0
y_j = 0
yj=0 )
∂ L C E ∂ z j = 0 ⋅ ( 1 − y ^ j ) + y ^ j ⋅ 1 = y ^ j \frac{\partial L_{CE}}{\partial z_j} = 0 \cdot (1 - \hat{y}_j) + \hat{y}_j \cdot 1 = \hat{y}_j ∂zj∂LCE=0⋅(1−y^j)+y^j⋅1=y^j
最终,多分类交叉熵的梯度为:
∂ L C E ∂ z j = y ^ j − y j \frac{\partial L_{CE}}{\partial z_j} = \hat{y}_j - y_j ∂zj∂LCE=y^j−yj
惊人地,这与二元交叉熵一致,表明 softmax + 交叉熵的梯度形式非常优雅。
梯度特性分析
- 与误差成正比:( y ^ j − y j \hat{y}_j - y_j y^j−yj ) 直接衡量预测概率与真实标签的偏差。
- 预测错误时的行为:
- 假设 (
y
k
=
1
y_k = 1
yk=1 ),(
y
^
k
→
0
\hat{y}_k \to 0
y^k→0 )(正确类别概率很低):
∂ L C E ∂ z k = 0 − 1 = − 1 \frac{\partial L_{CE}}{\partial z_k} = 0 - 1 = -1 ∂zk∂LCE=0−1=−1
梯度为 -1,推动 ( z k z_k zk ) 增加。 - 对于 (
j
≠
k
j \neq k
j=k ),若 (
y
^
j
→
1
\hat{y}_j \to 1
y^j→1 )(错误类别概率很高):
∂ L C E ∂ z j = 1 − 0 = 1 \frac{\partial L_{CE}}{\partial z_j} = 1 - 0 = 1 ∂zj∂LCE=1−0=1
梯度为 1,推动 ( z j z_j zj ) 减小。
- 假设 (
y
k
=
1
y_k = 1
yk=1 ),(
y
^
k
→
0
\hat{y}_k \to 0
y^k→0 )(正确类别概率很低):
- 全局一致性:梯度对所有类别的调整是协同的,确保 ( ∑ j y ^ j = 1 \sum_{j} \hat{y}_j = 1 ∑jy^j=1 ) 的约束下,正确类别的概率上升,其他下降。
三、梯度特性比较与洞见
-
梯度大小与误差的直接关系
- 二元和多分类交叉熵的梯度均为 ( y ^ − y \hat{y} - y y^−y ) 或 ( y ^ j − y j \hat{y}_j - y_j y^j−yj ),与误差成正比,且不受 ( y ^ \hat{y} y^ ) 边界值的影响。这与 MSE 的 ( ( y ^ − y ) y ^ ( 1 − y ^ ) (\hat{y} - y) \hat{y} (1 - \hat{y}) (y^−y)y^(1−y^) ) 形成对比,后者在 ( y ^ → 0 \hat{y} \to 0 y^→0 ) 或 ( 1 1 1 ) 时梯度趋于 0。
-
错误预测时的强反馈
- 当预测严重错误时(如 ( y ^ → 0 \hat{y} \to 0 y^→0 ) 而 ( y = 1 y = 1 y=1 )),交叉熵梯度保持恒定(例如 -1),提供稳定的优化信号。这种特性源自对数损失的非线性惩罚,放大错误时的代价。
-
避免饱和
- 交叉熵的梯度推导中,sigmoid 或 softmax 的导数被巧妙抵消,使得梯度不依赖于 ( y ^ ( 1 − y ^ ) \hat{y} (1 - \hat{y}) y^(1−y^)) 这样的项。这避免了激活函数饱和时的梯度消失问题。
-
优化效率
- 梯度的简单形式(( y ^ − y \hat{y} - y y^−y ))计算成本低,且在深层网络中易于反向传播,确保快速收敛。
总结
二元交叉熵和多分类交叉熵的梯度推导均得出 ( ∂ L ∂ z = y ^ − y \frac{\partial L}{\partial z} = \hat{y} - y ∂z∂L=y^−y ),其特性在于与误差直接相关、在错误预测时提供大梯度、无饱和问题。这些特点使交叉熵非常适合分类任务,能够快速修正错误并推动模型收敛,而不像 MSE 那样因梯度消失而停滞。
后记
2025年3月21日21点18分于上海,在grok 3大模型辅助下完成。