超大规模分类(二):InfoNCE
结合噪声对比估计(Noise Contrastive Estimation,NCE)的思想,通过互信息(Mutual Information)最小化来优化大规模分类任务,2019年,DeepMind的研究人员提出InfoNCE损失。
相较于NCE损失,InfoNCE损失有如下区别:
- InfoNCE的噪声采样自原数据的分布,而NCE的噪声采样自设定的某一分布
- InfoNCE在实操层面通过多分类实现,而NCE在实操层面通过二分类实现
- InfoNCE实现了数据表征和类别表征的互信息最小化,而NCE未实现
问题建模
已知
t
t
t时间段上下文
c
c
c,预测
t
+
k
t+k
t+k时间段的数据为
x
t
+
k
x_{t+k}
xt+k。该问题常见的应用场景有:
(1)文本、音频、图像生成任务,已知
t
t
t时间段之前的所有文本、音频、图像,生成
t
+
k
t+k
t+k时间段的文本、音频、图像
(2)图文对比学习任务,已知文本,预测其对应的图像是哪一个
(3)分类任务,已知聚类信息,预测某数据的类别
在case(2)和case(3)中,
t
t
t时间段上下文
c
c
c分别表示文本和聚类信息,表明InfoNCE具备较强的泛化能力,能够应用到各种场景中。
NCE的具体做法为利用模型拟合 P ( x t + k ∣ c ) P(x_{t+k}|c) P(xt+k∣c),InfoNCE认为该拟合方式粒度不够细,例如借助高维隐层向量 c c c来复原图像 x t + k x_{t+k} xt+k任务,隐层向量维度可能太低了,不足以复原图像。
InfoNCE将上下文 c c c和数据 x t + k x_{t+k} xt+k同时建模成维度一致的向量表示,通过互信息最大化,缩小两者之间的互信息,使得借助上下文 c c c来复原数据 x t + k x_{t+k} xt+k成为可能。
互信息
互信息的公式表示为:
I
(
x
t
+
k
,
c
t
)
=
∑
x
t
+
k
,
c
t
p
(
x
t
+
k
,
c
t
)
l
o
g
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
(1)
I(x_{t+k},c_t)=\sum_{x_{t+k},c_t}p(x_{t+k},c_t)log\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}\tag{1}
I(xt+k,ct)=xt+k,ct∑p(xt+k,ct)logp(xt+k)p(xt+k∣ct)(1)
互信息表示熵的差值,有:
I
(
x
t
+
k
,
c
t
)
=
H
(
x
t
+
k
)
−
H
(
x
t
+
k
∣
c
t
)
(2)
I(x_{t+k},c_t)=H(x_{t+k})-H(x_{t+k}|c_t)\tag{2}
I(xt+k,ct)=H(xt+k)−H(xt+k∣ct)(2)
其中,
H
(
⋅
)
H(\cdot)
H(⋅)表示熵。
熵表示信息的不确定性或者信息量。互信息等于 c t c_t ct变量引入后, x t + k x_{t+k} xt+k的熵的变化。变化越少,说明 x t + k x_{t+k} xt+k和 c t c_t ct越接近。也就是说,互信息可以代表两个变量的相关性。
优化目标
公式(1)的互信息可以写成:
I
(
x
t
+
k
,
c
t
)
=
∑
x
t
+
k
,
c
t
p
(
x
t
+
k
,
c
t
)
l
o
g
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
=
E
x
t
+
k
,
c
t
(
l
o
g
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
)
(3)
\begin{equation}\begin{aligned} I(x_{t+k},c_t)&=\sum_{x_{t+k},c_t}p(x_{t+k},c_t)log\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}\\ &=E_{x_{t+k},c_t}\left({log\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}}\right) \end{aligned} \end{equation} \tag{3}
I(xt+k,ct)=xt+k,ct∑p(xt+k,ct)logp(xt+k)p(xt+k∣ct)=Ext+k,ct(logp(xt+k)p(xt+k∣ct))(3)
从互信息的角度来看,InfoNCE损失希望模型参数
f
θ
(
x
t
+
k
,
c
t
)
f_{\theta}(x_{t+k},c_t)
fθ(xt+k,ct)拟合到
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}
p(xt+k)p(xt+k∣ct),也就是
f
θ
(
x
t
+
k
,
c
t
)
∝
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
f_{\theta}(x_{t+k},c_t)\propto\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}
fθ(xt+k,ct)∝p(xt+k)p(xt+k∣ct)。
观察 p ( x t + k ∣ c t ) p ( x t + k ) \frac{p(x_{t+k}|c_{t})}{p(x_{t+k})} p(xt+k)p(xt+k∣ct),类似于NCE损失,InfoNCE将分子 p ( x t + k ∣ c t ) p(x_{t+k}|c_{t}) p(xt+k∣ct)作为正样本,分母 p ( x t + k ) p(x_{t+k}) p(xt+k)作为负样本。
- NCE损失通过二分类,将不同的负样本 p ( x t + k ) p(x_{t+k}) p(xt+k)合并成一个类,区分正样本类别和负样本类别。
- InfoNCE损失通过多分类,每一个负样本 p ( x t + k ) p(x_{t+k}) p(xt+k)独立成各自的类别,若有 N − 1 N-1 N−1个负样本, 1 1 1个正样本,通过 N N N多分类任务进行区分。
于是,训练数据有
N
N
N个样本
V
=
{
v
1
,
v
2
,
.
.
.
,
v
N
}
V=\{v_1,v_2,...,v_N\}
V={v1,v2,...,vN},其中
1
1
1个采样自正样本
p
(
x
t
+
k
∣
c
t
)
p(x_{t+k}|c_{t})
p(xt+k∣ct),
N
−
1
N-1
N−1个采样自负样本
p
(
x
t
+
k
)
p(x_{t+k})
p(xt+k), 每个样本具有各自的类别。那么
v
i
v_i
vi采样自正样本,以及
v
j
≠
i
v_{j\neq i}
vj=i采样自负样本的概率,即
v
i
v_i
vi的数据分布为:
p
v
i
=
p
(
x
t
+
k
∣
c
t
)
∏
l
≠
t
+
k
p
(
x
l
)
∑
j
p
(
x
j
∣
c
t
)
∏
l
≠
j
p
(
x
l
)
=
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
∑
j
p
(
x
j
∣
x
t
)
p
(
x
j
)
(4)
\begin{equation}\begin{aligned} p_{v_i}=\frac{p(x_{t+k}|c_t)\prod_{l\neq{t+k}}p(x_l)}{\sum_jp(x_j|c_t)\prod_{l\neq j}p(x_l)}=\frac{\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}}{\sum_j\frac{p(x_j|x_t)}{p(x_j)}} \end{aligned}\end{equation} \tag{4}
pvi=∑jp(xj∣ct)∏l=jp(xl)p(xt+k∣ct)∏l=t+kp(xl)=∑jp(xj)p(xj∣xt)p(xt+k)p(xt+k∣ct)(4)
观察公式(4),由于
f
θ
(
x
t
+
k
,
c
t
)
∝
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
f_{\theta}(x_{t+k},c_t)\propto\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}
fθ(xt+k,ct)∝p(xt+k)p(xt+k∣ct),公式(4)可以写成:
f
θ
(
x
t
+
k
,
c
t
)
∑
j
f
θ
(
x
j
,
c
t
)
(5)
\begin{equation}\begin{aligned} \frac{f_{\theta}(x_{t+k},c_t)}{\sum_jf_{\theta}(x_j,c_t)} \end{aligned}\end{equation} \tag{5}
∑jfθ(xj,ct)fθ(xt+k,ct)(5)
损失函数直接等于对数似然函数均值的负数形式,有:
L
=
−
E
V
l
o
g
p
v
h
=
−
E
V
l
o
g
[
f
θ
(
x
t
+
k
,
c
t
)
∑
j
f
θ
(
x
j
,
c
t
)
]
(6)
\begin{equation}\begin{aligned} L&=-\mathbb{E}_Vlogp_{v_h}\\&=-\mathbb{E}_Vlog\left[\frac{f_{\theta}(x_{t+k},c_t)}{\sum_jf_{\theta}(x_j,c_t)}\right] \end{aligned}\end{equation} \tag{6}
L=−EVlogpvh=−EVlog[∑jfθ(xj,ct)fθ(xt+k,ct)](6)
在具体实现过程中,
f
θ
(
x
t
+
k
,
c
t
)
f_{\theta}(x_{t+k},c_t)
fθ(xt+k,ct)一般表示
x
t
+
k
x_{t+k}
xt+k和
c
t
c_t
ct的向量表征之间的余弦相似度。
InfoNCE与互信息的关系
InfoNCE利用
f
(
x
t
+
k
,
c
t
)
f(x_{t+k},c_t)
f(xt+k,ct)来拟合
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}
p(xt+k)p(xt+k∣ct),如果拟合成功,则有最优损失形式:
L
=
−
E
V
l
o
g
[
f
θ
(
x
t
+
k
,
c
t
)
∑
j
f
θ
(
x
j
,
c
t
)
]
=
−
E
V
l
o
g
[
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
)
+
∑
v
j
∈
V
n
e
g
p
(
x
j
∣
c
t
)
p
(
x
j
)
]
=
E
V
l
o
g
[
1
+
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
∑
v
j
∈
V
n
e
g
p
(
x
j
∣
c
t
)
p
(
x
j
)
]
≈
E
V
l
o
g
[
1
+
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
(
N
−
1
)
E
v
j
p
(
x
j
∣
c
t
)
p
(
x
j
)
]
=
E
V
l
o
g
[
1
+
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
(
N
−
1
)
]
=
E
V
l
o
g
[
p
(
x
t
+
k
∣
c
t
)
p
(
x
t
+
k
∣
c
t
)
+
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
(
N
−
1
)
]
≥
E
V
l
o
g
[
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
+
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
(
N
−
1
)
]
=
E
V
l
o
g
[
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
N
]
=
E
V
l
o
g
[
p
(
x
t
+
k
)
p
(
x
t
+
k
∣
c
t
)
]
+
l
o
g
N
=
−
I
(
x
t
+
k
,
c
t
)
+
l
o
g
(
N
)
(7)
\begin{equation}\begin{aligned} L&=-\mathbb{E}_Vlog\left[\frac{f_{\theta}(x_{t+k},c_t)}{\sum_jf_{\theta}(x_j,c_t)}\right]\\ &=-\mathbb{E}_Vlog\left[\frac{\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}}{\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}+\sum_{v_j \in V_{neg}}\frac{p(x_{j}|c_t)}{p(x_{j})}}\right]\\ &=\mathbb{E}_Vlog\left[1+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}\sum_{v_j \in V_{neg} }\frac{p(x_{j}|c_t)}{p(x_{j})}\right]\\ &\approx\mathbb{E}_Vlog\left[1+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\mathbb{E}_{v_j}\frac{p(x_{j}|c_t)}{p(x_{j})}\right]\\ &=\mathbb{E}_Vlog\left[1+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\ &=\mathbb{E}_Vlog\left[\frac{p(x_{t+k}|c_t)}{p(x_{t+k}|c_t)}+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\ &\geq\mathbb{E}_Vlog\left[\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\ &=\mathbb{E}_Vlog\left[\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}N\right]\\ &=\mathbb{E}_Vlog\left[\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}\right]+logN\\ &=-I(x_{t+k},c_t)+log(N) \end{aligned}\end{equation} \tag{7}
L=−EVlog[∑jfθ(xj,ct)fθ(xt+k,ct)]=−EVlog
p(xt+k)p(xt+k∣ct)+∑vj∈Vnegp(xj)p(xj∣ct)p(xt+k)p(xt+k∣ct)
=EVlog
1+p(xt+k∣ct)p(xt+k)vj∈Vneg∑p(xj)p(xj∣ct)
≈EVlog[1+p(xt+k∣ct)p(xt+k)(N−1)Evjp(xj)p(xj∣ct)]=EVlog[1+p(xt+k∣ct)p(xt+k)(N−1)]=EVlog[p(xt+k∣ct)p(xt+k∣ct)+p(xt+k∣ct)p(xt+k)(N−1)]≥EVlog[p(xt+k∣ct)p(xt+k)+p(xt+k∣ct)p(xt+k)(N−1)]=EVlog[p(xt+k∣ct)p(xt+k)N]=EVlog[p(xt+k∣ct)p(xt+k)]+logN=−I(xt+k,ct)+log(N)(7)
代码实现
基于最经典的CLIP模型,来理解InfoNCE的代码实现。
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_i, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f= image_encoder(I) #[n, d_i]
T_f= text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_i), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t) / 2
CLIP模型分别计算了文本侧和图像侧的InfoNCE损失,具体而言,包含如下步骤:
- 明确CLIP做的是图文对比学习任务:假设给定 M M M个图文对 { ( I i , T i ) } \{(I_i,T_i)\} {(Ii,Ti)}( I I I表示图像, T T T表示文本),我们需要拉近 I i I_i Ii和 T i T_i Ti间的距离,拉远 I i I_i Ii和 T j ≠ i T_{j\neq i} Tj=i距离,拉远 T i T_i Ti和 I j ≠ i I_{j\neq i} Ij=i的距离。最直接的方式是将每一个图文对归成一个类,通过 M M M类别的多分类任务来实现。但图文对 M M M的数量会非常大,可能几百万,可能几亿,压根无法训练。
- 为此,CLIP通过InfoNCE损失来替代 M M M类别的多分类任务。在一个batch内,组成训练数据集合 V V V,共 n n n个,如公式(4)所示。
- 首先提取图像表征 I _ f I\_f I_f和文本表征 T _ f T\_f T_f。
- 对于图像表征
I
_
f
i
I\_f_i
I_fi来说
- 文本表征 T _ f i T\_f_i T_fi可以理解为InfoNCE中的正样本 p ( x t + k ∣ c t ) p(x_{t+k}|c_{t}) p(xt+k∣ct),有 1 1 1个正样本
- T _ f j ≠ i T\_f_{j \neq i} T_fj=i可以理解为InfoNCE中的负样本 p ( x t + k ) p(x_{t+k}) p(xt+k),共有 n − 1 n-1 n−1个负样本
- 计算图像表征 I _ f i I\_f_i I_fi到所有文本表征 { T _ f 1 , T _ f 2 , . . . , T _ f n } \{T\_f_1,T\_f_2,...,T\_f_n\} {T_f1,T_f2,...,T_fn}的相似度距离,CLIP中采用temperature对相似度距离进行平滑及锐化
- 调用torch的cross_entroy_loss方法,计算损失
- 有 n n n个图像表征 { I _ f 1 , I _ f 2 , . . . , I _ f n } \{I\_f_1,I\_f_2,...,I\_f_n\} {I_f1,I_f2,...,I_fn},每一个图像表征计算一个损失,平均后得到 l o s s i loss_i lossi,平均的操作已集成到cross_entropy_loss方法中
- 对于文本表征 T _ f i T\_f_i T_fi来说,上述流程类似,得到 l o s s t loss_t losst
- 两个损失相加,得到最终损失