[Paper Reading] Knowledge Fusion of Large Language Models
Knowledge Fusion of Large Language Models (FuseLLM)
Methodology
The overall pipeline is shown in the figure below.
In the figure, different animals represent different LLMs. The first and second columns on the left are the Ensemble and Weight Merging approaches, respectively; the rightmost column is FuseLLM, the method proposed in this paper.
- Ensemble: combine the predictions of multiple models, e.g., by (weighted) averaging.
- Weight Merging: merge models at the weight/parameter level, which is usually limited to models with identical architectures.
- FuseLLM: the core idea is to fuse the probabilistic matrices produced by multiple LLMs (possibly of different architectures) into a fused matrix, which is then fed to the target model as a training signal, serving a role similar to knowledge distillation.
A key issue arises here:
- Different LLMs may use different tokenizers with different settings (e.g., model_max_length), so tokenizing the same sentence can yield different results (e.g., different total numbers of tokens), and their vocabularies may differ as well. Consequently, the resulting probabilistic matrices may differ in shape. How do we align them? This is precisely the token alignment problem, and the paper focuses on its solution.
Problem Definition
Suppose we have a corpus $\mathcal{C}$ and $K$ source LLMs. For a text $t \in \mathcal{C}$, passing it through the $K$ LLMs yields the corresponding probabilistic distribution matrices $\{\mathbf{P}^{\theta_j}_t\}^K_{j=1}$, where $\theta_j$ denotes the parameters of the $j$-th LLM. The goal is to fuse these $K$ probability distribution matrices and feed the result to the target LLM as an auxiliary training signal:
$$\begin{align} \mathbf{P}_t=\mathbb{F}\mathrm{usion}(\mathbf{P}_t^{\theta_1},\mathbf{P}_t^{\theta_2},\ldots,\mathbf{P}_t^{\theta_K}), \end{align}$$
where $\mathbf{P}_t$ is the fused representation matrix. To transfer $\mathbf{P}_t$ into the target model, let $\mathbf{Q}_t$ denote the target model's output representation matrix; the knowledge fusion training objective is then:
$$\begin{align} \mathcal{L}_{\mathrm{Fusion}}=-\mathbb{E}_{t\sim\mathcal{C}}\left[\mathbb{D}(\mathbf{Q}_t,\mathbf{P}_t)\right]. \end{align}$$
where $\mathbb{D}(\cdot, \cdot)$ is a discrepancy function between the two matrices, implemented in practice as the KL divergence.
The overall training loss is:
$$\begin{align}\mathcal{L}=\lambda\mathcal{L}_{\mathrm{CLM}}+(1-\lambda)\mathcal{L}_{\mathrm{Fusion}}.\end{align}$$
where $\mathcal{L}_{\mathrm{CLM}}$ is the standard causal language modeling loss against the ground truth, and $\lambda$ is a weighting coefficient.
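To make the objective concrete, here is a minimal PyTorch-style sketch of the combined loss (not the authors' code). It assumes the fused matrix `p_fused` has already been token-aligned to the target model's vocabulary; the names `fusellm_loss`, `p_fused`, and the default `lam=0.9` are illustrative.

```python
import torch
import torch.nn.functional as F

def fusellm_loss(student_logits, p_fused, labels, lam=0.9):
    """Sketch of L = lambda * L_CLM + (1 - lambda) * L_Fusion.

    student_logits: (T, V) logits Q_t from the target model
    p_fused:        (T, V) fused probability matrix P_t, already token-aligned
    labels:         (T,) ground-truth token ids
    lam:            trade-off coefficient lambda (illustrative default)
    """
    # Standard causal LM loss against the ground truth (L_CLM).
    clm_loss = F.cross_entropy(student_logits, labels)

    # Discrepancy D(Q_t, P_t): cross-entropy between the fused distribution and
    # the student's predictions, which equals KL(P_t || Q_t) up to a constant.
    log_q = F.log_softmax(student_logits, dim=-1)
    fusion_loss = -(p_fused * log_q).sum(dim=-1).mean()

    return lam * clm_loss + (1 - lam) * fusion_loss
```

Here $\lambda$ controls the trade-off between supervision from the ground-truth text and supervision from the fused distributions.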
Implementation Details
Token Alignment
Suppose we have two LLMs that use different tokenizers. Tokenizing the same piece of text yields different token sequences of different lengths:
As shown in the figure above, tokenizing with the DeepSeek and TinyLlama tokenizers gives completely different results, so the predicted probability distribution matrices differ as well.
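This mismatch is easy to reproduce with a few lines of `transformers` code. The following is a hedged sketch; the model identifiers are assumptions for illustration, and any two models with different tokenizers will show the same effect.

```python
from transformers import AutoTokenizer

# Hypothetical model choices -- any two LLMs with different tokenizers will do.
tok_a = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-llm-7b-base")
tok_b = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

text = "Knowledge fusion of large language models."
tokens_a = tok_a.tokenize(text)
tokens_b = tok_b.tokenize(text)

# The two sequences generally differ in both content and length, so the
# per-step probability matrices cannot be compared position by position.
print(len(tokens_a), tokens_a)
print(len(tokens_b), tokens_b)
print(len(tok_a), len(tok_b))  # the vocabulary sizes differ as well
```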
Token-Level Alignment
To solve this problem, FuseLLM uses a dynamic programming strategy based on the minimal edit distance (MinED) to align the sequences at the token level, as illustrated in the figure below:
The implementation in other.py is as follows:
```python
import numpy as np


def dtw(series_1, series_2, norm_func=np.linalg.norm):
    """Use dynamic time warping to align two tokenizers, modified from:
    https://github.com/talcs/simpledtw/blob/master/simpledtw.py

    Parameters
    ----------
    series_1: List[str]
        blending_input_tokens
    series_2: List[str]
        base_input_tokens
    norm_func: function
        edit distance evaluation between 2 tokens

    Return Values
    ----------
    matches: List[Tuple]
        matched pairs between a base token and a blending token
    matrix[-1, -1]: int
        the total cost for mapping the two series of tokens
    mappings_series_1: List[List]
        mapping from blending tokens to base tokens
        eg: [0], [1, 2], [3, 4, 5], [6], ...
    mappings_series_2: List[List]
        mapping from base tokens to blending tokens
    matrix: List[int]
        the dtw matrix
    """
    # Cumulative-cost matrix with an extra row/column of infinities as boundary.
    matrix = np.zeros((len(series_1) + 1, len(series_2) + 1))
    matrix[0, :] = np.inf
    matrix[:, 0] = np.inf
    matrix[0, 0] = 0
    for i, vec1 in enumerate(series_1):
        for j, vec2 in enumerate(series_2):
            cost = norm_func(vec1, vec2)
            matrix[i + 1, j + 1] = cost + min(
                matrix[i, j + 1], matrix[i + 1, j], matrix[i, j]
            )
    matrix = matrix[1:, 1:]
    i = matrix.shape[0] - 1
    j = matrix.shape[1] - 1
    matches = []
    mappings_series_1 = [list() for v in range(matrix.shape[0])]
    mappings_series_2 = [list() for v in range(matrix.shape[1])]
    # Backtrack from the bottom-right corner to recover the optimal alignment path.
    while i > 0 or j > 0:
        matches.append((i, j))
        mappings_series_1[i].append(j)
        mappings_series_2[j].append(i)
        option_diag = matrix[i - 1, j - 1] if i > 0 and j > 0 else np.inf
        option_up = matrix[i - 1, j] if i > 0 else np.inf
        option_left = matrix[i, j - 1] if j > 0 else np.inf
        move = np.argmin([option_diag, option_up, option_left])
        if move == 0:
            i -= 1
            j -= 1
        elif move == 1:
            i -= 1
        else:
            j -= 1
    matches.append((0, 0))
    mappings_series_1[0].append(0)
    mappings_series_2[0].append(0)
    matches.reverse()
    for mp in mappings_series_1:
        mp.reverse()
    for mp in mappings_series_2:
        mp.reverse()
    return matches, matrix[-1, -1], mappings_series_1, mappings_series_2, matrix
```
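A brief usage sketch with hypothetical token lists follows; the per-token cost is the edit distance after stripping the tokenizers' word-boundary markers, mirroring `dist_fn` in `transform_step_logits` below.

```python
import editdistance

# Hypothetical token sequences produced by two different tokenizers for the same text.
blending_tokens = ["▁Knowledge", "▁fusion", "▁of", "▁LLM", "s"]
base_tokens = ["ĠKnowledge", "Ġfusion", "Ġof", "ĠLLMs"]

def token_dist(a, b):
    # Strip the word-boundary markers ("▁" for SentencePiece, "Ġ" for BPE)
    # before measuring the edit distance between the two tokens.
    return editdistance.eval(a.lstrip("▁"), b.lstrip("Ġ"))

matches, total_cost, blending_to_base, base_to_blending, _ = dtw(
    blending_tokens, base_tokens, norm_func=token_dist
)
# base_to_blending[i] lists the blending-token positions aligned to base token i,
# e.g. [[0], [1], [2], [3, 4]]: the base token "ĠLLMs" maps to "▁LLM" + "s".
print(base_to_blending)
```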
Logit-Level Alignment
Using this token alignment, the representation matrices produced by the different LLMs are aligned as well. The key code in other.py is as follows:
```python
from typing import Dict, List

import editdistance
import transformers

# Note: TOKENIZER_TO_SPECIAL_TOKEN (tokenizer class -> word-boundary marker) and
# `logger` are defined elsewhere in other.py.


def transform_step_logits(
    base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    base_model_vocab: Dict[str, int],
    base_model_input_ids: List[int],
    blending_model_input_ids: List[int],
    blending_model_per_step_logits: List[List[float]],
    blending_model_per_step_indices: List[List[int]],
    vocab_align_type: str = "hard",
    blending_to_base_mapping: Dict[str, str] = None,
):
    """Align blending model per-step logits & indices with the base model.

    Parameters
    ----------
    base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase
    blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase
    base_model_vocab: Dict[str, int]
        mapping from token to id using the vocabulary of the base model
    base_model_input_ids: List[int]
        ids of base_model_input_tokens
    blending_model_input_ids: List[int]
        ids of blending_model_input_tokens
    blending_model_per_step_logits: List[List[float]]
        logits for each token in blending_model_input_tokens
    blending_model_per_step_indices: List[List[int]]
        indices corresponding to the logits for each token in blending_model_input_tokens
    vocab_align_type: str = "hard"
    blending_to_base_mapping: Dict[str, str] = None
        mapping from each blending token to its corresponding base token

    Return Values
    ----------
    aligned_blending_model_per_step_logits: List[List[float]]
        aligned logits for each token in base_model_input_tokens for the FuseLLM training
    aligned_blending_model_per_step_indices: List[List[int]]
        aligned indices corresponding to the aligned logits for each token in
        base_model_input_tokens for the FuseLLM training; use the base model
        vocabulary to look up the token.
    """
    base_model_tokens = base_model_tokenizer.convert_ids_to_tokens(base_model_input_ids)
    blending_model_tokens = blending_model_tokenizer.convert_ids_to_tokens(
        blending_model_input_ids
    )
    base_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[
        base_model_tokenizer.__class__
    ]
    blending_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[
        blending_model_tokenizer.__class__
    ]

    def dist_fn(a, b):
        """Edit distance between two tokens; a is from the blending model, b from the base model."""
        aa = a.replace(blending_model_special_token, "")
        bb = b.replace(base_model_special_token, "")
        dist = editdistance.eval(aa, bb)
        return dist

    # Token-level alignment via MinED-based dynamic programming (dtw above).
    _, _, _, base_to_blending, _ = dtw(
        blending_model_tokens, base_model_tokens, norm_func=dist_fn
    )
    aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = (
        [],
        [],
    )
    for i, blending_idx in enumerate(base_to_blending):
        aligned_blending_model_per_step_logit = []
        aligned_blending_model_per_step_index = []
        if len(blending_idx) == 1:  # one base token maps to one blending token
            j = blending_idx[0]
            base_token = base_model_tokens[i]
            blending_token = blending_model_tokens[j].replace(
                blending_model_special_token, base_model_special_token
            )
            if (
                (
                    blending_model_tokenizer.__class__
                    == transformers.GPTNeoXTokenizerFast
                    or blending_model_tokenizer.__class__
                    == transformers.GPT2TokenizerFast
                )
                and i == 0
                and base_token.startswith(base_model_special_token)
                and not blending_token.startswith(base_model_special_token)
            ):
                blending_token = (
                    base_model_special_token + blending_token
                )  # special case for mpt
            if vocab_align_type == "hard":
                if (
                    base_token == blending_token
                ):  # aligned mapping found, use the corresponding logits
                    # the logits and indices at this step
                    for blending_logit, blending_index in zip(
                        blending_model_per_step_logits[j],
                        blending_model_per_step_indices[j],
                    ):
                        # the token corresponding to this logit and index
                        blending_t = blending_model_tokenizer.convert_ids_to_tokens(
                            [blending_index]
                        )[0].replace(
                            blending_model_special_token, base_model_special_token
                        )
                        if blending_t in base_model_vocab:
                            aligned_index = base_model_vocab[
                                blending_t
                            ]  # the index of the token in the base model vocab
                            if (
                                aligned_index
                                not in aligned_blending_model_per_step_index
                            ):
                                aligned_blending_model_per_step_index.append(
                                    aligned_index
                                )
                                aligned_blending_model_per_step_logit.append(
                                    blending_logit
                                )
                else:  # no aligned mapping found, fall back to the one-hot logits
                    aligned_blending_model_per_step_index.append(
                        base_model_vocab[base_token]
                    )
                    aligned_blending_model_per_step_logit.append(1.0)
            elif vocab_align_type == "soft":
                if (base_token == blending_token) or (
                    blending_token in blending_to_base_mapping
                    and base_token == blending_to_base_mapping[blending_token]
                ):  # aligned mapping found, use the corresponding logits
                    # the logits and indices at this step
                    for blending_logit, blending_index in zip(
                        blending_model_per_step_logits[j],
                        blending_model_per_step_indices[j],
                    ):
                        # the token corresponding to this logit and index
                        blending_t = blending_model_tokenizer.convert_ids_to_tokens(
                            [blending_index]
                        )[0].replace(
                            blending_model_special_token, base_model_special_token
                        )
                        blending_t = blending_to_base_mapping[blending_t]
                        if blending_t in base_model_vocab:
                            aligned_index = base_model_vocab[
                                blending_t
                            ]  # the index of the token in the base model vocab
                            if (
                                aligned_index
                                not in aligned_blending_model_per_step_index
                            ):
                                aligned_blending_model_per_step_index.append(
                                    aligned_index
                                )
                                aligned_blending_model_per_step_logit.append(
                                    blending_logit
                                )
                        else:
                            logger.warning(
                                f"blending_t: {blending_t} not in base_model_vocab!"
                            )
                else:  # no aligned mapping found, fall back to the one-hot logits
                    aligned_blending_model_per_step_index.append(
                        base_model_vocab[base_token]
                    )
                    aligned_blending_model_per_step_logit.append(1.0)
            else:
                logger.warning(
                    f"The vocab_align_type: '{vocab_align_type}' is not supported!"
                )
                raise NotImplementedError
        else:  # one base token maps to multiple blending tokens; only fit the base token, use the one-hot logits
            base_token = base_model_tokens[i]
            aligned_blending_model_per_step_index.append(base_model_vocab[base_token])
            aligned_blending_model_per_step_logit.append(1.0)
        aligned_blending_model_per_step_indices.append(
            aligned_blending_model_per_step_index
        )
        aligned_blending_model_per_step_logits.append(
            aligned_blending_model_per_step_logit
        )
    return (
        aligned_blending_model_per_step_logits,
        aligned_blending_model_per_step_indices,
    )
```
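A usage sketch (not from the repository): the model names and the fake top-k values and indices below are illustrative, and the call assumes that `TOKENIZER_TO_SPECIAL_TOKEN` in other.py covers both tokenizer classes. In FuseLLM the per-step logits/indices come from the blending model's top-k truncated predictions saved offline.

```python
from transformers import AutoTokenizer

# Illustrative choice: Llama-2 as the base (target) model, OpenLLaMA as a blending model.
base_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
blend_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b", use_fast=False)

text = "Knowledge fusion of large language models."
base_ids = base_tok(text, add_special_tokens=False)["input_ids"]
blend_ids = blend_tok(text, add_special_tokens=False)["input_ids"]

# Fake top-k (k=2) per-step predictions of the blending model: for each of its
# input positions, a list of logits and the matching vocabulary indices.
per_step_logits = [[5.0, 2.0] for _ in blend_ids]
per_step_indices = [[tid, 0] for tid in blend_ids]

aligned_logits, aligned_indices = transform_step_logits(
    base_model_tokenizer=base_tok,
    blending_model_tokenizer=blend_tok,
    base_model_vocab=base_tok.get_vocab(),
    base_model_input_ids=base_ids,
    blending_model_input_ids=blend_ids,
    blending_model_per_step_logits=per_step_logits,
    blending_model_per_step_indices=per_step_indices,
    vocab_align_type="hard",
)
# One (indices, logits) pair per *base-model* step, expressed in the base model's
# vocabulary and ready to be fused with the outputs of other source LLMs.
print(len(base_ids), len(aligned_indices))
```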
Fusion Strategies
Once the aligned representation matrices are available, and since different LLMs differ in quality, the cross-entropy (CE) loss between each model's probability distribution matrix and the ground truth can be used to gauge how good each LLM is and to decide which models should take part in the knowledge fusion; the lower the CE loss, the better the model. Concretely, the authors propose two fusion strategies (see the sketch after this list):
- MinCE: use only the representation matrix with the lowest CE loss for knowledge fusion.
- AvgCE: use a weighted average of the representation matrices, with weights derived from each model's CE loss.
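A minimal sketch of the two strategies, assuming the per-model matrices have already been aligned to a common shape `(T, V)` and that `ce_losses` holds each source model's CE loss on the current text; the AvgCE weighting shown here (softmax over negative CE losses) is an illustrative choice, not necessarily the paper's exact formula.

```python
import torch

def fuse_matrices(prob_matrices, ce_losses, strategy="min_ce"):
    """Fuse aligned per-model probability matrices into a single P_t.

    prob_matrices: list of K tensors, each of shape (T, V), already token-aligned
    ce_losses:     list of K floats, CE loss of each source model on this text
    """
    stacked = torch.stack(prob_matrices)          # (K, T, V)
    losses = torch.tensor(ce_losses)

    if strategy == "min_ce":
        # MinCE: keep only the matrix from the model with the lowest CE loss.
        return stacked[losses.argmin()]
    elif strategy == "avg_ce":
        # AvgCE: weighted average, giving larger weight to lower-CE (better) models.
        weights = torch.softmax(-losses, dim=0)   # illustrative weighting
        return (weights[:, None, None] * stacked).sum(dim=0)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
```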
The overall algorithm is as follows:
- Note: Eq. 5 there corresponds to Eq. 3 above in this post.
Some Thoughts
The idea of this paper is to treat the probability distribution matrices output by multiple LLMs as knowledge, fuse that knowledge, and feed it to the target LLM during training, so as to combine the strengths of several models and improve the target model. Looking at the actual implementation, however, the logit-level alignment either reuses blending_model_per_step_logits/indices directly or simply falls back to the ground-truth one-hot vector as the fused knowledge; it never really accounts for the discrepancy between blending_model_per_step_logits and base_model_per_step_logits at the logit level. To address this, Probabilistic Token Alignment for Large Language Model Fusion proposes a Probabilistic Token Alignment method that performs the alignment at the logit level.