当前位置: 首页 > article >正文

《Long Context Compression with Activation Beacon》笔记

Activation Beacon出自智源与人大在2024年1月放在arxiv上的论文《Long Context Compression with Activation Beacon》(v1版的题目:Soaring from 4K to 400K: Extending LLM’s Context with Activation Beacon)。它引入了Beacon token将上下文信息蒸馏到其激活(activations);在压缩时将文本切分成固定大小的块(chunk),并根据压缩比 α \alpha α进一步将chunk分成更小的单元,beacon token插入在每个单元后面;LLM每次编码一个chunk,在自注意力机制执行过程中将chunk的信息蒸馏到beacon token的激活信息(activation)中,逐步地对整个长文本完成压缩过程,论文实验结果表明此方法可以有效加速推理过程并节省KV cache内存占用。

WeChatWorkScreenshot_51ab7dcb-bd63-4c27-ba89-8c01af79d511

实现思路

如论文图1所示意,对输入文本 X = [ x 1 , … , x n ] X = [x_1, \ldots, x_n] X=[x1,,xn],将其划分为相同尺寸w(如1024)的chunk:
[ x 1 , … , x n ] → Partition [ X 1 , … X ⌈ n / w ⌉ ] , X i = [ x ( i − 1 ) w + 1 , … , x i w ] = [ x 1 i , … , x w i ] [x_1, \ldots, x_n] \xrightarrow{\text{Partition}} [X_1, \ldots X_{\lceil n/w \rceil}], X_i=[x_{(i-1)w+1}, \ldots,x_{iw}] = [x^i_1, \ldots, x^i_w] [x1,,xn]Partition [X1,Xn/w],Xi=[x(i1)w+1,,xiw]=[x1i,,xwi]
对每一个chunk X i X_i Xi,使用一个压缩比 α i \alpha_i αi(w可由 α i \alpha_i αi整除),即将chunk划分到大小为 α \alpha α的更细粒度单元,一组共 k i = w / α i k_i=w/\alpha_i ki=w/αi个beacon token: B i = [ ⟨ b ⟩ 1 i , … , ⟨ b ⟩ k i i ] B_i=[\langle \mathbf{b} \rangle^i_1, \ldots, \langle \mathbf{b} \rangle^i_{k_i}] Bi=[⟨b1i,,bkii]被交替地插入到这些单元后。
X i → Interleave  B i X i ′ = [ x 1 i , … , x α i i , ⟨ b ⟩ 1 i , … , x w − α i + 1 i , … , x w i , ⟨ b ⟩ k i i ] X_i \xrightarrow{\text{Interleave} \ B_i} X^{\prime}_i = [x^i_1, \ldots, x^i_{\alpha_i}, \langle \mathbf{b} \rangle^i_1, \ldots, x^i_{w-\alpha_i +1}, \ldots, x^i_w, \langle \mathbf{b} \rangle^i_{k_i}] XiInterleave Bi Xi=[x1i,,xαii,b1i,,xwαi+1i,,xwi,bkii]
LLM逐一地编码这些chunk,在自注意力机制过程中将每个chunk的信息压缩到beacon token的激活(activations)中,在编码了 X i ′ X^{\prime}_i Xi后,将 X i X_i Xi的所有原始token(raw tokens)的激活信息给丢弃,但一直保留并累积beacon token B i B_i Bi的激活信息;在编码下一个chunk X i + 1 ′ X^{\prime}_{i+1} Xi+1时,LLM将累积的beacon激活作为原始上下文 X ≤ i X_{\le i} Xi的代理。

WeChatWorkScreenshot_6d367a38-54aa-4262-8f65-ffa19c9e7ddb

如论文图2所示,Activation Beacon与一般的LLM相比只做少许修改。对于第i个chunk X i ′ X^{\prime}_i Xi,编码过程可以写作:
LLM ⁡ ( ⟨ b ⟩ 1 i , … , ⟨ b ⟩ k i − 1 i − 1 ⏟ beacon activations accumulated from  X < i ′ , x 1 i , … , x α i i , ⟨ b ⟩ 1 i , … , x w − α i + 1 i , … , x w i , ⟨ b ⟩ k i i ⏟ the current chunk  X i ′ ) , \operatorname{LLM}(\underbrace{\langle\mathbf{b}\rangle_1^i, \ldots,\langle\mathbf{b}\rangle_{k_{i-1}}^{i-1}}_{\text {beacon activations accumulated from } X_{<i}^{\prime}}, \underbrace{x_1^i, \ldots, x_{\alpha_i}^i,\langle\mathbf{b}\rangle_1^i, \ldots, x_{w-\alpha_i+1}^i, \ldots, x_w^i,\langle\mathbf{b}\rangle_{k_i}^i}_{\text {the current chunk } X_i^{\prime}}), LLM(beacon activations accumulated from X<i b1i,,bki1i1,the current chunk Xi x1i,,xαii,b1i,,xwαi+1i,,xwi,bkii),
也就是LLM的输入是前面chunk的激活累积和当前chunk需要被编码的token的混合物。设D表示LLM的隐藏层尺寸, H ∈ R ( w + k i ) × D \boldsymbol{H} \in \mathbb{R}^{(w+k_i) \times D} HR(w+ki)×D表示LLM任意层的self attention的输入隐藏状态。我们会区分raw token和beacon token:
I r = { j ∣ x j i ≠ ⟨ b ⟩ } , I b = { j ∣ x j i = ⟨ b ⟩ } ; H r = H [ I r ] , H b = H [ I b ] . \mathbb{I}^r=\left\{j \mid x_j^i \neq\langle\mathbf{b}\rangle\right\}, \quad \mathbb{I}^b=\left\{j \mid x_j^i=\langle\mathbf{b}\rangle\right\} ; \quad \boldsymbol{H}^r=\boldsymbol{H}\left[\mathbb{I}^r\right], \quad \boldsymbol{H}^b=\boldsymbol{H}\left[\mathbb{I}^b\right] . Ir={jxji=b},Ib={jxji=b};Hr=H[Ir],Hb=H[Ib].
将隐状态变成query, key, value:
Q r = W Q r H r , K r = W K r H r , V r = W V r H r , Q b = W Q b H b , K b = W K b H b , V b = W V b H b , \begin{array}{lll} \boldsymbol{Q}^r=\boldsymbol{W}_Q^r \boldsymbol{H}^r, & \boldsymbol{K}^r=\boldsymbol{W}_K^r \boldsymbol{H}^r, & \boldsymbol{V}^r=\boldsymbol{W}_V^r \boldsymbol{H}^r, \\ \boldsymbol{Q}^b=\boldsymbol{W}_Q^b \boldsymbol{H}^b, & \boldsymbol{K}^b=\boldsymbol{W}_K^b \boldsymbol{H}^b, & \boldsymbol{V}^b=\boldsymbol{W}_V^b \boldsymbol{H}^b, \end{array} Qr=WQrHr,Qb=WQbHb,Kr=WKrHr,Kb=WKbHb,Vr=WVrHr,Vb=WVbHb,
上式中 W ∗ r \boldsymbol{W}^r_* Wr是LLM原来的投影矩阵, W ∗ b \boldsymbol{W}^b_* Wb是新引入的只处理beacon token的投影矩阵。再将raw token和beacon token的query/key/value状态来得到 Q , K , V ∈ R ( w + k i ) × D \boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V} \in \mathbb{R}^{(w+k_i) \times D} Q,K,VR(w+ki)×D
Q [ I r ] = Q r , Q [ I b ] = Q b , K [ I r ] = K r , K [ I b ] = K b , V [ I r ] = V r , V [ I b ] = V b \boldsymbol{Q}\left[\mathbb{I}^r\right]= \boldsymbol{Q}^r,\boldsymbol{Q}\left[\mathbb{I}^b\right]= \boldsymbol{Q}^b, \quad \boldsymbol{K}\left[\mathbb{I}^r\right]= \boldsymbol{K}^r,\boldsymbol{K}\left[\mathbb{I}^b\right]= \boldsymbol{K}^b, \quad \boldsymbol{V}\left[\mathbb{I}^r\right]= \boldsymbol{V}^r,\boldsymbol{V}\left[\mathbb{I}^b\right]= \boldsymbol{V}^b Q[Ir]=Qr,Q[Ib]=Qb,K[Ir]=Kr,K[Ib]=Kb,V[Ir]=Vr,V[Ib]=Vb
最后,用标准方法计算self attention:
A = softmax ( mask ( Q { K a c ; K } T D ) ) , V = A { V a c ; V } \boldsymbol{A} = \text{softmax}\left(\text{mask} \left( \frac{\boldsymbol{Q}\{\boldsymbol{K}^{ac}; \boldsymbol{K} \}^T }{\sqrt{D}} \right)\right), \quad \boldsymbol{V} = \boldsymbol{A}\{\boldsymbol{V}^{ac};\boldsymbol{V} \} A=softmax(mask(D Q{Kac;K}T)),V=A{Vac;V}
上式中的 { ⋅ ; ⋅ } \{ \cdot ; \cdot\} {;}表示矩阵连接, K a c , V a c ∈ R m i − 1 × D \boldsymbol{K}^{ac}, \boldsymbol{V}^{ac} \in \mathbb{R}^{m_{i-1} \times D} Kac,VacRmi1×D是从之前的chunk累积得到的beacon token的激活参数, m i − 1 = ∑ j = 1 i − 1 k j m_{i-1} = \sum^{i-1}_{j=1} k_j mi1=j=1i1kj, mask就是causal attention mask。在self attention过程中,所有的token与其他token进行交互,使得beacon tokens的key和value( K b , V b \boldsymbol{K}^{b}, \boldsymbol{V}^{b} Kb,Vb)蒸馏了 X i X_i Xi的上下文信息,它们会增量累积:
K a c = { K a c ; K b } , V a c = { V a c ; V b } \boldsymbol{K}^{ac} = \{\boldsymbol{K}^{ac}; \boldsymbol{K}^{b}\}, \boldsymbol{V}^{ac} = \{\boldsymbol{V}^{ac};\boldsymbol{V}^{b} \} Kac={Kac;Kb},Vac={Vac;Vb}


### 下面代码是activation beacon在实现时,interleave插入beacon token的代码,位于model_beacon.py的Memory类的_step函数
						input_len = input_ids.shape[1]
            if beacon_size > 0:
                # insert beacon tokens in between raw tokens,对应论文中的式(2)
                input_ids_with_beacons = input_ids.new_full((input_ids.shape[0], input_len + beacon_size), self.beacon_token.item())
                raw_token_indices = torch.arange(input_ids_with_beacons.shape[1], device=input_ids.device)
                interleave_start_idx = compression_ratio - self._interleave_remainder
                raw_token_indices = raw_token_indices[raw_token_indices % (compression_ratio + 1) != interleave_start_idx].unsqueeze(0).expand_as(input_ids)
                input_ids_with_beacons = input_ids_with_beacons.scatter(dim=1, index=raw_token_indices, src=input_ids)
                input_ids = input_ids_with_beacons
                # attention mask
                ## beacon token是参与attention的,所以默认值为1
                attention_mask_with_beacons = attention_mask.new_full((attention_mask.shape[0], attention_mask.shape[1] + beacon_size), 1)
                attention_mask_with_beacons = attention_mask_with_beacons.scatter(dim=1, index=raw_token_indices, src=attention_mask)
                attention_mask = attention_mask_with_beacons
                # labels
                if labels is not None:
                    ## beacon token不参与loss的计算,所以标签为-100
                    labels_with_beacons = labels.new_full((labels.shape[0], labels.shape[1] + beacon_size), -100)
                    labels_with_beacons = labels_with_beacons.scatter(dim=1, index=raw_token_indices, src=labels)
                    labels = labels_with_beacons

训练过程

Activation Beacon的学习目标是在当前chunk上下文和之前压缩信息的条件下提高生成质量,损失函数如下:
min ⁡ Θ b . ∑ i = 2 ⌈ N / w ⌉ ∑ j = 1 w Pr ⁡ ( x j i ∣ ⟨ b ⟩ 1 1 , … , ⟨ b ⟩ k i − 1 i − 1 , x 1 i , … x j − 1 i ; Θ , Θ b ) . \min _{\boldsymbol{\Theta}^b} . \sum_{i=2}^{\lceil N / w\rceil} \sum_{j=1}^w \operatorname{Pr}\left(x_j^i \mid\langle\mathbf{b}\rangle_1^1, \ldots,\langle\mathbf{b}\rangle_{k_{i-1}}^{i-1}, x_1^i, \ldots x_{j-1}^i ; \mathbf{\Theta}, \boldsymbol{\Theta}^b\right) . Θbmin.i=2N/wj=1wPr(xjib11,,bki1i1,x1i,xj1i;Θ,Θb).
上式中 Θ \mathbf{\Theta} Θ是LLM的参数,在训练过程中被冻结, Θ b \mathbf{\Theta^b} Θb是每一层中beacon token对应的投影矩阵 W ∗ b \boldsymbol{W}^b_* Wb和beacon token 的embedding e ⟨ b ⟩ \mathbf{e}_{\langle b \rangle} eb(所有beacon token使用共享embedding ),训练时beacon token不参与损失计算(标签被设置为-100)因为它们仅用作压缩。

训练时第i个chunk的压缩比 α i \alpha_i αi是随机地从{2, 4, 8, 16, 32}中选取的,意在让模型灵活地支持不同的压缩粒度。而在推理时可以根据下游任务选择一个压缩比并应用到所有chunk。

训练过程分为预训练和微调,消融实验表明两个阶段对模型效果都有提升。

注意,activation beacon默认的方式是将其交替地插入在原始上下文中(代码中的interleave),论文做消融实验时尝试将beacon token全部放在chunk的最后时效果是会下降的(代码中的append)。

WeChatWorkScreenshot_de754912-a289-4c7b-b3c9-a1d1b3331676

注:Activation Beacon与MemoRAG是同一个团队出的,理解这篇思路之后,就能更好地理解MemoRAG的记忆模型了。(对比这两篇论文对应的记忆模型的代码,几乎是一样的,有点奇怪为什么memorag没有引用这篇文章,也没有对代码做说明。因为不理解memorag的记忆模型的代码,通过搜索关键字beacon搜到了这篇论文)。


http://www.kler.cn/a/513408.html

相关文章:

  • [Easy] leetcode-500 键盘行
  • 基于python+Django+mysql鲜花水果销售商城网站系统设计与实现
  • 为什么相关性不是因果关系?人工智能中的因果推理探秘
  • 接口自动化测试
  • 计算机系统原理:一些断言
  • 记录一下OpenCV Contrib 编译踩的坑
  • mybatis(19/134)
  • 【HarmonyOS NEXT】华为分享-碰一碰开发分享
  • 初创企业或中小企业如何进行海外市场问卷调查?
  • HTML中的`<!DOCTYPE html>`是什么意思?
  • Java爬虫调用API时的异常处理策略
  • 算法---冒泡法
  • 推荐一个小而美的 Toast 插件 (一键复制使用)
  • Dart语言的学习路线
  • YOLOv10-1.1部分代码阅读笔记-dist.py
  • 61,【1】BUUCTF WEB BUU XSS COURSE 11
  • 大牙的2024年创作总结
  • 求解ssp 问题建模
  • 个人职业发展与AI赋能的前端开发
  • 交换机Console密码忘记无法登录设备怎么办?
  • ubuntu16.04 VSCode下cmake+clang+lldb调试c++
  • 线程池实现
  • 36. K11364 剑法
  • Erlang语言的面向对象编程
  • 以 RFID 为钥,开启民兵装备管理的科技之门
  • linux 下tensorrt的yolov8的前向推理(python 版本)的实现