深度学习|表示学习|作为损失函数的交叉熵|04
如是我闻: Cross Entropy(交叉熵)是一种用于衡量两个概率分布之间差异的度量方式,被广泛的应用,但是总是记不住。
背后的理论
交叉熵源于信息论,它是描述两个概率分布之间相似度的重要工具。假设有两个概率分布:
- p ( x ) p(x) p(x) : 真实分布(通常是one-hot编码的标签)
- q ( x ) q(x) q(x) : 预测分布(通常是模型的输出,经过Softmax)
交叉熵衡量的是使用预测分布 q ( x ) q(x) q(x) 去表达真实分布 p ( x ) p(x) p(x) 所需的信息量(或不确定性)。
交叉熵的公式为:
H
(
p
,
q
)
=
−
∑
x
p
(
x
)
log
q
(
x
)
H(p, q) = - \sum_{x} p(x) \log q(x)
H(p,q)=−x∑p(x)logq(x)
在分类任务中,通常 p ( x ) p(x) p(x) 是one-hot编码的概率分布,例如对于分类为类别 i i i,只有 p ( x i ) = 1 p(x_i) = 1 p(xi)=1,其余为0。此时,交叉熵简化为:
H ( p , q ) = − log q ( x i ) H(p, q) = - \log q(x_i) H(p,q)=−logq(xi)
其中 q ( x i ) q(x_i) q(xi) 是模型预测类别 i i i 的概率。
在机器学习中的应用
在神经网络中,交叉熵通常作为分类问题的损失函数。例如:
-
多分类任务(Softmax 输出):模型的输出是类别的概率分布。
- 损失函数为:
L = − 1 N ∑ i = 1 N ∑ j = 1 C y i j log ( y ^ i j ) L = - \frac{1}{N} \sum_{i=1}^N \sum_{j=1}^C y_{ij} \log(\hat{y}_{ij}) L=−N1i=1∑Nj=1∑Cyijlog(y^ij)
其中:- N N N: 样本数量
- C C C: 类别数量
- y i j y_{ij} yij: 第 i i i 个样本的真实标签(one-hot编码)
- y ^ i j \hat{y}_{ij} y^ij: 第 i i i 个样本预测为类别 j j j 的概率
- 损失函数为:
-
二分类任务(Sigmoid 输出):用于只有两个类别的场景。
- 损失函数为:
L = − 1 N ∑ i = 1 N ( y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ) L = - \frac{1}{N} \sum_{i=1}^N \left( y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right) L=−N1i=1∑N(yilog(y^i)+(1−yi)log(1−y^i))
其中:- y i y_i yi: 第 i i i 个样本的真实标签(0 或 1)
- y ^ i \hat{y}_i y^i: 第 i i i个样本的预测概率
- 损失函数为:
为什么使用 Cross Entropy?
-
惩罚错误预测:
- 如果预测结果 q ( x i ) q(x_i) q(xi) 接近真实值 p ( x i ) p(x_i) p(xi),交叉熵值很低。
- 如果预测值远离真实值,交叉熵值较高,模型会被更强烈地惩罚。
-
理论基础:
- 交叉熵直接源自最大似然估计(MLE)。通过最小化交叉熵,实际上是在最大化模型对训练数据的似然。
-
概率解释:
- 交叉熵可以理解为用预测分布 q ( x ) q(x) q(x) 表达真实分布 p ( x ) p(x) p(x) 的代价,反映了预测分布的准确性。
举例说明
假设我们有 3 类分类任务,真实标签为 [ 1 , 0 , 0 ] [1, 0, 0] [1,0,0](表示类别 1),模型输出的概率分布为 [ 0.7 , 0.2 , 0.1 ] [0.7, 0.2, 0.1] [0.7,0.2,0.1]。
交叉熵计算:
H
(
p
,
q
)
=
−
(
1
⋅
log
(
0.7
)
+
0
⋅
log
(
0.2
)
+
0
⋅
log
(
0.1
)
)
H(p, q) = - \left( 1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1) \right)
H(p,q)=−(1⋅log(0.7)+0⋅log(0.2)+0⋅log(0.1))
H
(
p
,
q
)
=
−
log
(
0.7
)
≈
0.356
H(p, q) = - \log(0.7) \approx 0.356
H(p,q)=−log(0.7)≈0.356
如果模型输出错误的概率分布
[
0.1
,
0.3
,
0.6
]
[0.1, 0.3, 0.6]
[0.1,0.3,0.6],交叉熵将变大:
H
(
p
,
q
)
=
−
log
(
0.1
)
≈
2.302
H(p, q) = - \log(0.1) \approx 2.302
H(p,q)=−log(0.1)≈2.302
从中可以看出:交叉熵随着预测接近真实分布而减小。
以上