【论文笔记】LLaVA-KD: A Framework of Distilling Multimodal Large Language Models
Abstract
大语言模型(Large Language Models, LLM)的成功,使得研究者为了统一视觉和语言的理解去探索多模态大预言模型(Multimodal Large Language Models, MLLM)。
但是MLLM庞大的模型和复杂的计算使其很难应用在资源受限的环境,小型MLLM(s-MLLM)的表现又远不如大型的MLLM(l-MLLM)。
基于上述提到的问题,本文提出了全新的LLaVA-KD框架,将l-MLLM的知识转移到s-MLLM。具体地,本文提出:
- 多模态蒸馏(Multimodal Distillation, MDist),减小l-MLLM和s-MLLM之间的视觉-文本输出分布的差异;
- 关系蒸馏(Relation Distillation, RDist),迁移l-MLLM对视觉特征之间相关性的建模能力。
本文还提出了三阶段训练方案,充分挖掘s-MLLM的潜力: - 预训练时蒸馏,对齐视觉文本表示;
- 有监督地微调,使模型具备多模态理解;
- 微调时蒸馏,进一步迁移l-MLLM的能力。
本文方法在不改变s-MLLM结构的情况下显著提升了其性能。
github仓库
Introduction
本文从研究各种训练策略的角度出发,在不改变模型架构的情况下,探索提高、s-MLLM的性能。
图1:为了训练s-MLLM,(a)已有方法遵循两步训练,包括预训练(Pre-Training, PT)和监督微调(Supervised Fine-Tuning, SFT);(b)本文的LLaVA-KD提出了三步训练,包含蒸馏预训练(Distilled Pre-Training, DPT)来对齐视觉文本表示、SFT来提升模型的多模态理解能力、蒸馏微调(Distilled Fine-Tuning, DFT)来转移l-MLLM的能力;©本文将LLaVA-KD和其他sota的MLLM在五个热门的多模态benchmark上进行比较。
如图1所示,已有的s-MLLM遵循两步训练策略,包含PT和SFT。PT阶段将视觉特征投影到文本嵌入空间;SFT阶段增强模型的理解和推离能力。
但是s-MLLM的模型容量很小,很难像l-MLLM那样捕获复杂的知识。本文将研究如何借助蒸馏提升s-MLLM的训练。
3 LLaVA-KD
图2:LLaVA-KD的总图,包含三阶段训练。1) DPT,向l-MLLM对齐视觉和文本信息。2) SFT,为s-MLLM带来多模态理解能力。3)DFT,向s-MLLM迁移l-MLLM的能力。在DPT和DFT中应用MDist,使用RDist来使得s-MLLM捕获视觉信息的复杂关系。
3.1 Composition of Distilled MLLM Architecture
图2左侧展示了MLLM的蒸馏过程,包含l-MLLM作为教师模型,和s-MLLM作为学生模型,分别包含三个部分:
Frozen Visual Encoder:用于获得强力的视觉特征。给定输入图像 X v ∈ R H × W × 3 X_v\in\mathbb{R}^{H\times W\times 3} Xv∈RH×W×3,排序成2D patches P v ∈ R N p × S p 2 × 3 P_v\in\mathbb{R}^{N_p\times S_p^2\times 3} Pv∈RNp×Sp2×3,其中 S p S_p Sp和 N p N_p Np表示patch的大小和数量。最后的transformer层将 P v P_v Pv变成 Z v ∈ R N p × C Z_v\in\mathbb{R}^{N_p\times C} Zv∈RNp×C,其中特征维度为 C C C。教师和学生都使用相同的Frozen Visual Encoder。
Visual Projector:包含两个MLP层,带有激活函数GELU,将 Z v Z_v Zv映射到文本嵌入空间 H v ∈ R N p × D H_v\in\mathbb{R}^{N_p\times D} Hv∈RNp×D,其中 D D D是嵌入空间维度。
Large Language Model (LLM):用于实现对视觉和语言信息的统一认识。给定视觉嵌入的多模态输入 H v H_v Hv和文本嵌入 H t H_t Ht,LLM将二者的连接 H = [ H v , H t ] H=[H_v,H_t] H=[Hv,Ht]作为输入,生成输出 y = [ y p , y v , y r ] = { y t } t = 1 T y=[y_p,y_v,y_r]=\{y_t\}_{t=1}^T y=[yp,yv,yr]={yt}t=1T,其中 y p , y v , y r y_p,y_v,y_r yp,yv,yr分别代表prompt、视觉和响应tokens, T T T代表所有预测token的长度。本文将教师和学生的LLM分别称为l-LLM和s-LLM。
3.2 Training Scheme of Teacher Model L-MLLM
Pre-Training:Visual Encoder和l-LLM冻结,只有Projector被优化,用于对齐视觉和文本特征。训练过程中,使用图像-描述对,对应的目标公式表示为:
L
reg
=
−
∑
m
=
1
M
log
ϕ
l
(
y
m
∣
y
<
m
)
(1)
\mathcal{L}_\text{reg}=-\sum_{m=1}^M \log\phi_l(y_m|y_{<m})\tag{1}
Lreg=−m=1∑Mlogϕl(ym∣y<m)(1)
其中
M
M
M表示预测的响应tokens的长度,
ϕ
l
(
y
m
∣
y
<
m
)
\phi_l(y_m|y_{<m})
ϕl(ym∣y<m)表示响应tokens
y
m
y_m
ym的分布基于先前预测
y
<
m
y_{<m}
y<m的条件。
Supervised Fine-Tuning:该阶段保持Visual Encoder的冻结,旨在联合优化Projector和l -LLM,以增强教师模型l-MLLM的理解和教学跟随能力。训练过程中,利用高质量的对话数据集,训练目标 L S F T \mathcal{L}_{SFT} LSFT如Eq.1所示。
3.3 Framework of LLaVA-KD
3.3.1 MLLM-Oriented KD Strategy
Multimodal Distillation (MDist):考虑到MLLM本质上是利用LLM进行多模态信息理解和推理,我们沿用LLM的朴素蒸馏方法,即利用KL散度(KLD)对响应预测进行蒸馏。训练目标可以定义为:
L
res
=
∑
m
=
1
M
KLD
(
ϕ
l
(
y
m
∣
y
<
m
)
,
ϕ
s
(
y
m
∣
y
<
m
)
)
=
∑
m
=
1
M
∑
j
=
1
V
ϕ
l
(
Y
j
∣
y
<
m
)
log
(
ϕ
l
(
Y
j
∣
y
<
m
)
ϕ
s
(
Y
j
∣
y
<
m
)
)
(2)
\begin{aligned} \mathcal{L}_\text{res}&=\sum_{m=1}^M \text{KLD}(\phi_l(y_m|y_{<m}),\phi_s(y_m|y_{<m})) \\ &=\sum_{m=1}^M \sum_{j=1}^V \phi_l(Y_j|y_{<m})\log (\frac{\phi_l(Y_j|y_{<m})}{\phi_s(Y_j|y_{<m})})\tag{2} \end{aligned}
Lres=m=1∑MKLD(ϕl(ym∣y<m),ϕs(ym∣y<m))=m=1∑Mj=1∑Vϕl(Yj∣y<m)log(ϕs(Yj∣y<m)ϕl(Yj∣y<m))(2)
其中
M
M
M表示响应tokens的长度,
V
V
V表示词汇空间。
ϕ
l
\phi_l
ϕl和
ϕ
s
\phi_s
ϕs表示l-MLLM和s-MLLM的参数,
ϕ
l
(
Y
j
∣
y
<
m
)
\phi_l(Y_j|y_{<m})
ϕl(Yj∣y<m)和
ϕ
s
(
Y
j
∣
y
<
m
)
\phi_s(Y_j|y_{<m})
ϕs(Yj∣y<m)表示由l-MLLM和s-MLLM预测的词汇
Y
j
Y_j
Yj出现在token
y
m
y_m
ym的概率。
同时,视觉表征对于LLM的多模态理解也至关重要。因此,进一步优化教师和学生输出视觉分布之间的KLD:
L
vis
=
∑
k
=
1
K
∑
j
=
1
V
ϕ
l
(
Y
j
∣
y
<
k
)
log
(
ϕ
l
(
Y
j
∣
y
<
k
)
ϕ
s
(
Y
j
∣
y
<
k
)
)
(3)
\mathcal{L}_\text{vis}=\sum_{k=1}^K \sum_{j=1}^V \phi_l(Y_j|y_{<k})\log (\frac{\phi_l(Y_j|y_{<k})}{\phi_s(Y_j|y_{<k})})\tag{3}
Lvis=k=1∑Kj=1∑Vϕl(Yj∣y<k)log(ϕs(Yj∣y<k)ϕl(Yj∣y<k))(3)
其中
K
K
K表示视觉token的长度,
ϕ
l
(
Y
j
∣
y
<
k
)
\phi_l(Y_j|y_{<k})
ϕl(Yj∣y<k)和
ϕ
s
(
Y
j
∣
y
<
k
)
\phi_s(Y_j|y_{<k})
ϕs(Yj∣y<k)分别表示由l-MLLM和s-MLLM预测的词汇
Y
j
Y_j
Yj出现在token
y
k
y_k
yk的概率。
本文在DPT阶段用MDist来对齐s-MLLM中的视觉和语言特征,加强了s-MLLM的理解能力。
Relation Distillation (RDist):为了使学生模型能够捕获视觉信息中的复杂关系,本文从LLM输出的视觉tokens中构造自相关矩阵。通过优化矩阵之间的相似性,学生模型继承了教师模型理解视觉tokens之间错综复杂关系的能力。为了达到这个目的,首先计算自相关矩阵如下:
{
R
v
s
=
y
v
s
⊗
y
v
s
∈
R
N
p
×
N
p
R
v
t
=
y
v
t
⊗
y
v
t
∈
R
N
p
×
N
p
\begin{equation} \left\{ \begin{aligned} R_v^s &= y_v^s\otimes y_v^s\in\mathbb{R}^{N_p\times N_p} \\ R_v^t &= y_v^t\otimes y_v^t\in\mathbb{R}^{N_p\times N_p} \end{aligned} \right.\tag{4} \end{equation}
{RvsRvt=yvs⊗yvs∈RNp×Np=yvt⊗yvt∈RNp×Np(4)
其中
⊗
\otimes
⊗表示矩阵乘法,
y
v
s
y_v^s
yvs和
y
v
t
y_v^t
yvt表示学生和教师的视觉logits,
N
p
N_p
Np表示视觉token的数量。目标是最大化
R
v
s
R_v^s
Rvs和
R
v
t
R_v^t
Rvt的余弦相似度:
L
rel
=
1
−
Cos
(
R
v
s
,
R
v
t
)
=
1
−
R
v
s
⋅
R
v
t
∣
∣
R
v
s
∣
∣
∣
∣
R
v
t
∣
∣
(5)
\mathcal{L}_\text{rel}=1-\text{Cos}(R_v^s,R_v^t)=1-\frac{R_v^s\cdot R_v^t}{||R_v^s||\ ||R_v^t||}\tag{5}
Lrel=1−Cos(Rvs,Rvt)=1−∣∣Rvs∣∣ ∣∣Rvt∣∣Rvs⋅Rvt(5)
用RDist可以进一步提升s-MLLM在DPT和DFT阶段的视觉表达能力。
3.3.2 Three-stage Distillation Scheme
Distilled Pre-Training (DPT):该阶段的主要目的是将视觉特征投射到文本嵌入空间。在LLaVA-KD中,使用蒸馏过程来像l-MLLM一样更好地对齐视觉和文本信息。
具体地,冻结visual encoder和s-MLLM中的LLM,只优化projector。在训练过程中,通过MDist最小化学生模型和教师模型在视觉和反应的输出分布上的差异。
为了优化这个目标,可以进一步促进投影的视觉特征与文本嵌入的对齐。此外,我们利用RDist来增强视觉特征的质量,使学生模型能够借鉴教师模型处理复杂视觉信息的能力。
总的来说,除了优化自回归预测结果,还使用了MDist和RDist:
L
DPT
=
L
PT
+
α
L
res
+
β
L
vis
+
γ
L
rel
(6)
\mathcal{L}_\text{DPT}=\mathcal{L}_\text{PT}+\alpha\mathcal{L}_\text{res}+\beta\mathcal{L}_\text{vis}+\gamma\mathcal{L}_\text{rel}\tag{6}
LDPT=LPT+αLres+βLvis+γLrel(6)
Supervised Fine-Tuning (SFT):这个阶段遵循l-MLLM训练阶段的通用SFT过程(Sec.3.2)。通过联合训练Projector和l-LLM,使模型具有推理能力和指令跟踪能力。训练目标由Eq.1定义,表示为 L SFT ′ \mathcal{L}_\text{SFT}' LSFT′。
Distilled Fine-Tuning (DFT):该阶段的主要目标是进一步增强s-MLLM的理解和推理能力。具体来说,采用了MDist和RDist相结合的蒸馏策略,冻结了Visual Encoder,优化了Projector和sLLM。通过使用MDist,可以对s-MLLM中的小规模s-LLM进行充分优化,从而更好地模拟大规模l-LLM的推理能力。和RDist可以进一步促进s-MLLM学习l-MLLM的视觉表征。
总体训练目标可以表示为:
L
D
F
T
=
L
reg
+
α
′
L
res
+
β
′
L
vis
+
γ
′
L
rel
(7)
\mathcal{L}_{DFT}=\mathcal{L}_\text{reg}+\alpha'\mathcal{L}_\text{res}+\beta'\mathcal{L}_\text{vis}+\gamma'\mathcal{L}_\text{rel}\tag{7}
LDFT=Lreg+α′Lres+β′Lvis+γ′Lrel(7)
其中
L
reg
\mathcal{L}_\text{reg}
Lreg表示自回归预测loss。