《Long Context Compression with Activation Beacon》笔记
Activation Beacon出自智源与人大在2024年1月放在arxiv上的论文《Long Context Compression with Activation Beacon》(v1版的题目:Soaring from 4K to 400K: Extending LLM’s Context with Activation Beacon)。它引入了Beacon token将上下文信息蒸馏到其激活(activations);在压缩时将文本切分成固定大小的块(chunk),并根据压缩比 α \alpha α进一步将chunk分成更小的单元,beacon token插入在每个单元后面;LLM每次编码一个chunk,在自注意力机制执行过程中将chunk的信息蒸馏到beacon token的激活信息(activation)中,逐步地对整个长文本完成压缩过程,论文实验结果表明此方法可以有效加速推理过程并节省KV cache内存占用。
实现思路
如论文图1所示意,对输入文本
X
=
[
x
1
,
…
,
x
n
]
X = [x_1, \ldots, x_n]
X=[x1,…,xn],将其划分为相同尺寸w(如1024)的chunk:
[
x
1
,
…
,
x
n
]
→
Partition
[
X
1
,
…
X
⌈
n
/
w
⌉
]
,
X
i
=
[
x
(
i
−
1
)
w
+
1
,
…
,
x
i
w
]
=
[
x
1
i
,
…
,
x
w
i
]
[x_1, \ldots, x_n] \xrightarrow{\text{Partition}} [X_1, \ldots X_{\lceil n/w \rceil}], X_i=[x_{(i-1)w+1}, \ldots,x_{iw}] = [x^i_1, \ldots, x^i_w]
[x1,…,xn]Partition[X1,…X⌈n/w⌉],Xi=[x(i−1)w+1,…,xiw]=[x1i,…,xwi]
对每一个chunk
X
i
X_i
Xi,使用一个压缩比
α
i
\alpha_i
αi(w可由
α
i
\alpha_i
αi整除),即将chunk划分到大小为
α
\alpha
α的更细粒度单元,一组共
k
i
=
w
/
α
i
k_i=w/\alpha_i
ki=w/αi个beacon token:
B
i
=
[
⟨
b
⟩
1
i
,
…
,
⟨
b
⟩
k
i
i
]
B_i=[\langle \mathbf{b} \rangle^i_1, \ldots, \langle \mathbf{b} \rangle^i_{k_i}]
Bi=[⟨b⟩1i,…,⟨b⟩kii]被交替地插入到这些单元后。
X
i
→
Interleave
B
i
X
i
′
=
[
x
1
i
,
…
,
x
α
i
i
,
⟨
b
⟩
1
i
,
…
,
x
w
−
α
i
+
1
i
,
…
,
x
w
i
,
⟨
b
⟩
k
i
i
]
X_i \xrightarrow{\text{Interleave} \ B_i} X^{\prime}_i = [x^i_1, \ldots, x^i_{\alpha_i}, \langle \mathbf{b} \rangle^i_1, \ldots, x^i_{w-\alpha_i +1}, \ldots, x^i_w, \langle \mathbf{b} \rangle^i_{k_i}]
XiInterleave BiXi′=[x1i,…,xαii,⟨b⟩1i,…,xw−αi+1i,…,xwi,⟨b⟩kii]
LLM逐一地编码这些chunk,在自注意力机制过程中将每个chunk的信息压缩到beacon token的激活(activations)中,在编码了
X
i
′
X^{\prime}_i
Xi′后,将
X
i
X_i
Xi的所有原始token(raw tokens)的激活信息给丢弃,但一直保留并累积beacon token
B
i
B_i
Bi的激活信息;在编码下一个chunk
X
i
+
1
′
X^{\prime}_{i+1}
Xi+1′时,LLM将累积的beacon激活作为原始上下文
X
≤
i
X_{\le i}
X≤i的代理。
如论文图2所示,Activation Beacon与一般的LLM相比只做少许修改。对于第i个chunk
X
i
′
X^{\prime}_i
Xi′,编码过程可以写作:
LLM
(
⟨
b
⟩
1
i
,
…
,
⟨
b
⟩
k
i
−
1
i
−
1
⏟
beacon activations accumulated from
X
<
i
′
,
x
1
i
,
…
,
x
α
i
i
,
⟨
b
⟩
1
i
,
…
,
x
w
−
α
i
+
1
i
,
…
,
x
w
i
,
⟨
b
⟩
k
i
i
⏟
the current chunk
X
i
′
)
,
\operatorname{LLM}(\underbrace{\langle\mathbf{b}\rangle_1^i, \ldots,\langle\mathbf{b}\rangle_{k_{i-1}}^{i-1}}_{\text {beacon activations accumulated from } X_{<i}^{\prime}}, \underbrace{x_1^i, \ldots, x_{\alpha_i}^i,\langle\mathbf{b}\rangle_1^i, \ldots, x_{w-\alpha_i+1}^i, \ldots, x_w^i,\langle\mathbf{b}\rangle_{k_i}^i}_{\text {the current chunk } X_i^{\prime}}),
LLM(beacon activations accumulated from X<i′
⟨b⟩1i,…,⟨b⟩ki−1i−1,the current chunk Xi′
x1i,…,xαii,⟨b⟩1i,…,xw−αi+1i,…,xwi,⟨b⟩kii),
也就是LLM的输入是前面chunk的激活累积和当前chunk需要被编码的token的混合物。设D表示LLM的隐藏层尺寸,
H
∈
R
(
w
+
k
i
)
×
D
\boldsymbol{H} \in \mathbb{R}^{(w+k_i) \times D}
H∈R(w+ki)×D表示LLM任意层的self attention的输入隐藏状态。我们会区分raw token和beacon token:
I
r
=
{
j
∣
x
j
i
≠
⟨
b
⟩
}
,
I
b
=
{
j
∣
x
j
i
=
⟨
b
⟩
}
;
H
r
=
H
[
I
r
]
,
H
b
=
H
[
I
b
]
.
\mathbb{I}^r=\left\{j \mid x_j^i \neq\langle\mathbf{b}\rangle\right\}, \quad \mathbb{I}^b=\left\{j \mid x_j^i=\langle\mathbf{b}\rangle\right\} ; \quad \boldsymbol{H}^r=\boldsymbol{H}\left[\mathbb{I}^r\right], \quad \boldsymbol{H}^b=\boldsymbol{H}\left[\mathbb{I}^b\right] .
Ir={j∣xji=⟨b⟩},Ib={j∣xji=⟨b⟩};Hr=H[Ir],Hb=H[Ib].
将隐状态变成query, key, value:
Q
r
=
W
Q
r
H
r
,
K
r
=
W
K
r
H
r
,
V
r
=
W
V
r
H
r
,
Q
b
=
W
Q
b
H
b
,
K
b
=
W
K
b
H
b
,
V
b
=
W
V
b
H
b
,
\begin{array}{lll} \boldsymbol{Q}^r=\boldsymbol{W}_Q^r \boldsymbol{H}^r, & \boldsymbol{K}^r=\boldsymbol{W}_K^r \boldsymbol{H}^r, & \boldsymbol{V}^r=\boldsymbol{W}_V^r \boldsymbol{H}^r, \\ \boldsymbol{Q}^b=\boldsymbol{W}_Q^b \boldsymbol{H}^b, & \boldsymbol{K}^b=\boldsymbol{W}_K^b \boldsymbol{H}^b, & \boldsymbol{V}^b=\boldsymbol{W}_V^b \boldsymbol{H}^b, \end{array}
Qr=WQrHr,Qb=WQbHb,Kr=WKrHr,Kb=WKbHb,Vr=WVrHr,Vb=WVbHb,
上式中
W
∗
r
\boldsymbol{W}^r_*
W∗r是LLM原来的投影矩阵,
W
∗
b
\boldsymbol{W}^b_*
W∗b是新引入的只处理beacon token的投影矩阵。再将raw token和beacon token的query/key/value状态来得到
Q
,
K
,
V
∈
R
(
w
+
k
i
)
×
D
\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V} \in \mathbb{R}^{(w+k_i) \times D}
Q,K,V∈R(w+ki)×D
Q
[
I
r
]
=
Q
r
,
Q
[
I
b
]
=
Q
b
,
K
[
I
r
]
=
K
r
,
K
[
I
b
]
=
K
b
,
V
[
I
r
]
=
V
r
,
V
[
I
b
]
=
V
b
\boldsymbol{Q}\left[\mathbb{I}^r\right]= \boldsymbol{Q}^r,\boldsymbol{Q}\left[\mathbb{I}^b\right]= \boldsymbol{Q}^b, \quad \boldsymbol{K}\left[\mathbb{I}^r\right]= \boldsymbol{K}^r,\boldsymbol{K}\left[\mathbb{I}^b\right]= \boldsymbol{K}^b, \quad \boldsymbol{V}\left[\mathbb{I}^r\right]= \boldsymbol{V}^r,\boldsymbol{V}\left[\mathbb{I}^b\right]= \boldsymbol{V}^b
Q[Ir]=Qr,Q[Ib]=Qb,K[Ir]=Kr,K[Ib]=Kb,V[Ir]=Vr,V[Ib]=Vb
最后,用标准方法计算self attention:
A
=
softmax
(
mask
(
Q
{
K
a
c
;
K
}
T
D
)
)
,
V
=
A
{
V
a
c
;
V
}
\boldsymbol{A} = \text{softmax}\left(\text{mask} \left( \frac{\boldsymbol{Q}\{\boldsymbol{K}^{ac}; \boldsymbol{K} \}^T }{\sqrt{D}} \right)\right), \quad \boldsymbol{V} = \boldsymbol{A}\{\boldsymbol{V}^{ac};\boldsymbol{V} \}
A=softmax(mask(DQ{Kac;K}T)),V=A{Vac;V}
上式中的
{
⋅
;
⋅
}
\{ \cdot ; \cdot\}
{⋅;⋅}表示矩阵连接,
K
a
c
,
V
a
c
∈
R
m
i
−
1
×
D
\boldsymbol{K}^{ac}, \boldsymbol{V}^{ac} \in \mathbb{R}^{m_{i-1} \times D}
Kac,Vac∈Rmi−1×D是从之前的chunk累积得到的beacon token的激活参数,
m
i
−
1
=
∑
j
=
1
i
−
1
k
j
m_{i-1} = \sum^{i-1}_{j=1} k_j
mi−1=∑j=1i−1kj, mask就是causal attention mask。在self attention过程中,所有的token与其他token进行交互,使得beacon tokens的key和value(
K
b
,
V
b
\boldsymbol{K}^{b}, \boldsymbol{V}^{b}
Kb,Vb)蒸馏了
X
i
X_i
Xi的上下文信息,它们会增量累积:
K
a
c
=
{
K
a
c
;
K
b
}
,
V
a
c
=
{
V
a
c
;
V
b
}
\boldsymbol{K}^{ac} = \{\boldsymbol{K}^{ac}; \boldsymbol{K}^{b}\}, \boldsymbol{V}^{ac} = \{\boldsymbol{V}^{ac};\boldsymbol{V}^{b} \}
Kac={Kac;Kb},Vac={Vac;Vb}
### 下面代码是activation beacon在实现时,interleave插入beacon token的代码,位于model_beacon.py的Memory类的_step函数
input_len = input_ids.shape[1]
if beacon_size > 0:
# insert beacon tokens in between raw tokens,对应论文中的式(2)
input_ids_with_beacons = input_ids.new_full((input_ids.shape[0], input_len + beacon_size), self.beacon_token.item())
raw_token_indices = torch.arange(input_ids_with_beacons.shape[1], device=input_ids.device)
interleave_start_idx = compression_ratio - self._interleave_remainder
raw_token_indices = raw_token_indices[raw_token_indices % (compression_ratio + 1) != interleave_start_idx].unsqueeze(0).expand_as(input_ids)
input_ids_with_beacons = input_ids_with_beacons.scatter(dim=1, index=raw_token_indices, src=input_ids)
input_ids = input_ids_with_beacons
# attention mask
## beacon token是参与attention的,所以默认值为1
attention_mask_with_beacons = attention_mask.new_full((attention_mask.shape[0], attention_mask.shape[1] + beacon_size), 1)
attention_mask_with_beacons = attention_mask_with_beacons.scatter(dim=1, index=raw_token_indices, src=attention_mask)
attention_mask = attention_mask_with_beacons
# labels
if labels is not None:
## beacon token不参与loss的计算,所以标签为-100
labels_with_beacons = labels.new_full((labels.shape[0], labels.shape[1] + beacon_size), -100)
labels_with_beacons = labels_with_beacons.scatter(dim=1, index=raw_token_indices, src=labels)
labels = labels_with_beacons
训练过程
Activation Beacon的学习目标是在当前chunk上下文和之前压缩信息的条件下提高生成质量,损失函数如下:
min
Θ
b
.
∑
i
=
2
⌈
N
/
w
⌉
∑
j
=
1
w
Pr
(
x
j
i
∣
⟨
b
⟩
1
1
,
…
,
⟨
b
⟩
k
i
−
1
i
−
1
,
x
1
i
,
…
x
j
−
1
i
;
Θ
,
Θ
b
)
.
\min _{\boldsymbol{\Theta}^b} . \sum_{i=2}^{\lceil N / w\rceil} \sum_{j=1}^w \operatorname{Pr}\left(x_j^i \mid\langle\mathbf{b}\rangle_1^1, \ldots,\langle\mathbf{b}\rangle_{k_{i-1}}^{i-1}, x_1^i, \ldots x_{j-1}^i ; \mathbf{\Theta}, \boldsymbol{\Theta}^b\right) .
Θbmin.i=2∑⌈N/w⌉j=1∑wPr(xji∣⟨b⟩11,…,⟨b⟩ki−1i−1,x1i,…xj−1i;Θ,Θb).
上式中
Θ
\mathbf{\Theta}
Θ是LLM的参数,在训练过程中被冻结,
Θ
b
\mathbf{\Theta^b}
Θb是每一层中beacon token对应的投影矩阵
W
∗
b
\boldsymbol{W}^b_*
W∗b和beacon token 的embedding
e
⟨
b
⟩
\mathbf{e}_{\langle b \rangle}
e⟨b⟩(所有beacon token使用共享embedding ),训练时beacon token不参与损失计算(标签被设置为-100)因为它们仅用作压缩。
训练时第i个chunk的压缩比 α i \alpha_i αi是随机地从{2, 4, 8, 16, 32}中选取的,意在让模型灵活地支持不同的压缩粒度。而在推理时可以根据下游任务选择一个压缩比并应用到所有chunk。
训练过程分为预训练和微调,消融实验表明两个阶段对模型效果都有提升。
注意,activation beacon默认的方式是将其交替地插入在原始上下文中(代码中的interleave),论文做消融实验时尝试将beacon token全部放在chunk的最后时效果是会下降的(代码中的append)。
注:Activation Beacon与MemoRAG是同一个团队出的,理解这篇思路之后,就能更好地理解MemoRAG的记忆模型了。(对比这两篇论文对应的记忆模型的代码,几乎是一样的,有点奇怪为什么memorag没有引用这篇文章,也没有对代码做说明。因为不理解memorag的记忆模型的代码,通过搜索关键字beacon搜到了这篇论文)。