
[Paper Reading] Knowledge Fusion of Large Language Models

Knowledge Fusion of Large Language Models (FuseLLM)


Methodology

The overall pipeline is shown in the figure below.
(Figure: overall FuseLLM pipeline)

Different animals represent different LLMs. The first and second groups from the left are the Ensemble and Weight Merging approaches, respectively; the rightmost is FuseLLM, proposed in this paper.

  • Ensemble: fuse the predictions of multiple models, e.g. by weighted averaging.
  • Weight Merging: fuse models at the weight/parameter level, which is usually limited to models with identical architectures.
  • FuseLLM: the core idea is to fuse the probabilistic matrices of multiple LLMs (possibly of different architectures) into a fused matrix, which is then fed to the target model as a knowledge-distillation signal; a toy sketch follows.
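To make the contrast concrete, here is a toy numpy sketch (illustrative only, not the paper's code) of how the three approaches would combine two models:

import numpy as np

# Next-token distributions from two source models over a tiny 3-token vocabulary.
p_a = np.array([0.7, 0.2, 0.1])
p_b = np.array([0.5, 0.3, 0.2])

# Ensemble: combine predictions at inference time; all source models must be kept.
p_ensemble = 0.5 * p_a + 0.5 * p_b

# Weight merging: average the parameters instead (requires identical architectures):
#   merged_state[k] = 0.5 * state_a[k] + 0.5 * state_b[k]

# FuseLLM: fuse the distributions offline into a single training target and
# distill it into one target model; only that model runs at inference time.
fused_target = p_ensemble  # the paper selects/weights by CE loss (see Fusion Strategies)
print(fused_target)        # [0.6  0.25 0.15]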

This involves one key problem:

  • Different LLMs may use different tokenizers with different settings (e.g. model_max_length): tokenizing the same sentence can produce different token sequences with different total lengths, and the vocabularies may differ too, so the resulting probabilistic matrices can differ in their dimensions. How do we align them? This is exactly the token alignment problem, and the paper focuses on its solution.

Problem Definition

Suppose we have a corpus $\mathcal{C}$ and $K$ source LLMs. For a text $t \in \mathcal{C}$, running the $K$ LLMs yields the corresponding probabilistic distribution matrices $\{\mathbf{P}^{\theta_j}_t\}^K_{j=1}$, where $\theta_j$ denotes the parameters of the $j$-th LLM. What we want is to fuse these $K$ probability distribution matrices and feed the result into the target LLM to assist its training:
\begin{align} \mathbf{P}_t=\mathbb{F}\mathrm{usion}(\mathbf{P}_t^{\theta_1},\mathbf{P}_t^{\theta_2},\ldots,\mathbf{P}_t^{\theta_K}), \end{align}
where $\mathbf{P}_t$ is the resulting fused representation matrix.

To transfer $\mathbf{P}_t$ to the target model, let $\mathbf{Q}_t$ denote the target model's output representation matrix. The knowledge-fusion training objective is:
\begin{align} \mathcal{L}_{\mathrm{Fusion}}=-\mathbb{E}_{t\sim\mathcal{C}}\left[\mathbb{D}(\mathbf{Q}_t,\mathbf{P}_t)\right]. \end{align}
Here $\mathbb{D}(\cdot, \cdot)$ is a discrepancy function, which can be instantiated as KL divergence.
The overall loss is:
\begin{align}\mathcal{L}=\lambda\mathcal{L}_{\mathrm{CLM}}+(1-\lambda)\mathcal{L}_{\mathrm{Fusion}}.\end{align}
Here $\mathcal{L}_{\mathrm{CLM}}$ is the ordinary causal language modeling loss against the ground truth, and $\lambda$ is a weighting coefficient.
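A minimal PyTorch sketch of this combined objective (variable names are my own, and the code assumes the fused matrix $\mathbf{P}_t$ has already been built and aligned):

import torch
import torch.nn.functional as F

def fusellm_loss(logits, labels, fused_probs, lam=0.9):
    # logits:      target-model outputs Q_t (as logits), (batch, seq_len, vocab)
    # labels:      ground-truth token ids, (batch, seq_len)
    # fused_probs: fused distribution matrix P_t, (batch, seq_len, vocab)
    # lam:         the coefficient lambda in Eq. 3
    # L_CLM: standard causal-LM cross entropy against the ground truth.
    loss_clm = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    # L_Fusion: discrepancy D(Q_t, P_t), instantiated here as KL(P_t || Q_t).
    log_q = F.log_softmax(logits, dim=-1)
    loss_fusion = F.kl_div(log_q, fused_probs, reduction="batchmean")
    return lam * loss_clm + (1 - lam) * loss_fusion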

Implementation Details

Token Alignment

Suppose two LLMs use different tokenizers. Tokenizing the same piece of text gives different token sequences of different lengths:
(Figure: the same text tokenized by the DeepSeek and TinyLlama tokenizers)
As shown above, DeepSeek's and TinyLlama's tokenizers produce completely different results, so the predicted probability distribution matrices differ as well.
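The mismatch is easy to reproduce (the model names below are examples for illustration, not necessarily the checkpoints used in the paper):

from transformers import AutoTokenizer

tok_a = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tok_b = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-llm-7b-base")

text = "Knowledge fusion of large language models"
print(tok_a.tokenize(text))  # one segmentation and length
print(tok_b.tokenize(text))  # a different segmentation, typically a different length
# Different token sequences mean the per-step probability matrices do not line up.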

Token-Level Alignment

To solve this problem, FuseLLM adopts a dynamic-programming strategy based on Minimal Edit Distance (MinED) to achieve token-level alignment, illustrated below:
(Figure: token-level alignment via MinED)
The concrete implementation in other.py is as follows:


import numpy as np


def dtw(series_1, series_2, norm_func=np.linalg.norm):
    """Use dynamic time warping to align two tokenizations, modified from:
    https://github.com/talcs/simpledtw/blob/master/simpledtw.py

    Parameters
    ----------
    series_1: List[str]
        blending_input_tokens
    series_2: List[str]
        base_input_tokens
    norm_func: function
        edit distance evaluation between 2 tokens

    Return Values
    ----------
    matches: List[Tuple]
        matched pairs between a base token and a blending token
    matrix[-1, -1]: float
        the total cost for mapping the two series of tokens
    mappings_series_1: List[List]
        mapping from blending tokens to base tokens
        e.g. [0], [1, 2], [3, 4, 5], [6], ...
    mappings_series_2: List[List]
        mapping from base tokens to blending tokens
    matrix: np.ndarray
        the accumulated-cost (DTW) matrix
    """

    # Fill the (len(series_1)+1) x (len(series_2)+1) accumulated-cost matrix.
    matrix = np.zeros((len(series_1) + 1, len(series_2) + 1))
    matrix[0, :] = np.inf
    matrix[:, 0] = np.inf
    matrix[0, 0] = 0
    for i, vec1 in enumerate(series_1):
        for j, vec2 in enumerate(series_2):
            cost = norm_func(vec1, vec2)
            matrix[i + 1, j + 1] = cost + min(
                matrix[i, j + 1], matrix[i + 1, j], matrix[i, j]
            )
    matrix = matrix[1:, 1:]
    # Backtrack from the bottom-right corner to recover the optimal alignment path.
    i = matrix.shape[0] - 1
    j = matrix.shape[1] - 1
    matches = []
    mappings_series_1 = [list() for v in range(matrix.shape[0])]
    mappings_series_2 = [list() for v in range(matrix.shape[1])]
    while i > 0 or j > 0:
        matches.append((i, j))
        mappings_series_1[i].append(j)
        mappings_series_2[j].append(i)
        option_diag = matrix[i - 1, j - 1] if i > 0 and j > 0 else np.inf
        option_up = matrix[i - 1, j] if i > 0 else np.inf
        option_left = matrix[i, j - 1] if j > 0 else np.inf
        move = np.argmin([option_diag, option_up, option_left])
        if move == 0:
            i -= 1
            j -= 1
        elif move == 1:
            i -= 1
        else:
            j -= 1
    matches.append((0, 0))
    mappings_series_1[0].append(0)
    mappings_series_2[0].append(0)
    matches.reverse()
    for mp in mappings_series_1:
        mp.reverse()
    for mp in mappings_series_2:
        mp.reverse()

    return matches, matrix[-1, -1], mappings_series_1, mappings_series_2, matrix
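A hypothetical usage of dtw, aligning two tokenizations of the word "translation" with plain edit distance as the token-level cost:

import editdistance

blending_tokens = ["▁trans", "lation"]    # from the blending model's tokenizer
base_tokens = ["▁tran", "s", "lation"]    # from the base model's tokenizer

matches, cost, blend_to_base, base_to_blend, _ = dtw(
    blending_tokens, base_tokens, norm_func=editdistance.eval
)
print(base_to_blend)  # [[0], [0], [1]]: base tokens 0 and 1 both align to blending token 0
print(blend_to_base)  # [[0, 1], [2]]:   blending token 0 spans base tokens 0 and 1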


Logit-Level Alignment

Using this alignment result, the representation matrices produced by the different LLMs can be aligned. The key code in other.py is as follows:


from typing import Dict, List

import editdistance
import transformers

# `logger` and `TOKENIZER_TO_SPECIAL_TOKEN` are module-level definitions in
# other.py; the latter maps a tokenizer class to its word-boundary marker
# (e.g. "▁" for sentencepiece-style tokenizers, "Ġ" for GPT-2-style BPE).


def transform_step_logits(
    base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    base_model_vocab: Dict[str, int],
    base_model_input_ids: List[int],
    blending_model_input_ids: List[int],
    blending_model_per_step_logits: List[List[float]],
    blending_model_per_step_indices: List[List[int]],
    vocab_align_type: str = "hard",
    blending_to_base_mapping: Dict[str, str] = None,
):
    """Align the blending model's per-step logits & indices with the base model.

    Parameters
    ----------
    base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase
    blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase
    base_model_vocab: Dict[str, int]
        mapping from token to id in the base model's vocabulary
    base_model_input_ids: List[int]
        ids of base_model_input_tokens
    blending_model_input_ids: List[int]
        ids of blending_model_input_tokens
    blending_model_per_step_logits: List[List[float]]
        logits for each token in blending_model_input_tokens
    blending_model_per_step_indices: List[List[int]]
        indices corresponding to the logits for each token in blending_model_input_tokens
    vocab_align_type: str = "hard"
    blending_to_base_mapping: Dict[str, str] = None
        mapping from each blending token to its corresponding base token

    Return Values
    ----------
    aligned_blending_model_per_step_logits: List[List[float]]
        aligned logits for each token in base_model_input_tokens, for FuseLLM training
    aligned_blending_model_per_step_indices: List[List[int]]
        aligned indices corresponding to the aligned logits for each token in
        base_model_input_tokens, for FuseLLM training; look the tokens up in the
        base model's vocabulary.
    """



    base_model_tokens = base_model_tokenizer.convert_ids_to_tokens(base_model_input_ids)
    blending_model_tokens = blending_model_tokenizer.convert_ids_to_tokens(
        blending_model_input_ids
    )
    base_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[
        base_model_tokenizer.__class__
    ]
    blending_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[
        blending_model_tokenizer.__class__
    ]


    def dist_fn(a, b):
        """Calculate editdistance between two tokens, a is from blending model, b is from base model."""
        aa = a.replace(blending_model_special_token, "")
        bb = b.replace(base_model_special_token, "")
        dist = editdistance.eval(aa, bb)
        return dist

    _, _, _, base_to_blending, _ = dtw(
        blending_model_tokens, base_model_tokens, norm_func=dist_fn
    )
    aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = (
        [],
        [],
    )
    for i, blending_idx in enumerate(base_to_blending):
        aligned_blending_model_per_step_logit = []
        aligned_blending_model_per_step_index = []
        if len(blending_idx) == 1:  # one base token maps to one blending token
            j = blending_idx[0]
            base_token = base_model_tokens[i]
            blending_token = blending_model_tokens[j].replace(
                blending_model_special_token, base_model_special_token
            )
            if (
                (
                    blending_model_tokenizer.__class__
                    == transformers.GPTNeoXTokenizerFast
                    or blending_model_tokenizer.__class__
                    == transformers.GPT2TokenizerFast
                )
                and i == 0
                and base_token.startswith(base_model_special_token)
                and not blending_token.startswith(base_model_special_token)
            ):
                blending_token = (
                    base_model_special_token + blending_token
                )  # special case for mpt
            if vocab_align_type == "hard":
                if (
                    base_token == blending_token
                ):  # find the aligned mapping, use the corresponding logits
                    # the logits and indices at this step
                    for blending_logit, blending_index in zip(
                        blending_model_per_step_logits[j],
                        blending_model_per_step_indices[j],
                    ):
                        # the token corresponds to the logit and indices
                        blending_t = blending_model_tokenizer.convert_ids_to_tokens(
                            [blending_index]
                        )[0].replace(
                            blending_model_special_token, base_model_special_token
                        )
                        if blending_t in base_model_vocab:
                            aligned_index = base_model_vocab[
                                blending_t
                            ]  # the index of the token in base model vocab
                            if (
                                aligned_index
                                not in aligned_blending_model_per_step_index
                            ):
                                aligned_blending_model_per_step_index.append(
                                    aligned_index
                                )
                                aligned_blending_model_per_step_logit.append(
                                    blending_logit
                                )
                else:  # mismatched mapping, fall back to the one-hot logits
                    aligned_blending_model_per_step_index.append(
                        base_model_vocab[base_token]
                    )
                    aligned_blending_model_per_step_logit.append(1.0)
            elif vocab_align_type == "soft":
                if (base_token == blending_token) or (
                    blending_token in blending_to_base_mapping
                    and base_token == blending_to_base_mapping[blending_token]
                ):  # find the aligned mapping, use the corresponding logits
                    # the logits and indices at this step
                    for blending_logit, blending_index in zip(
                        blending_model_per_step_logits[j],
                        blending_model_per_step_indices[j],
                    ):
                        # the token corresponds to the logit and indices
                        blending_t = blending_model_tokenizer.convert_ids_to_tokens(
                            [blending_index]
                        )[0].replace(
                            blending_model_special_token, base_model_special_token
                        )
                        blending_t = blending_to_base_mapping[blending_t]
                        if blending_t in base_model_vocab:
                            aligned_index = base_model_vocab[
                                blending_t
                            ]  # the index of the token in base model vocab
                            if (
                                aligned_index
                                not in aligned_blending_model_per_step_index
                            ):
                                aligned_blending_model_per_step_index.append(
                                    aligned_index
                                )
                                aligned_blending_model_per_step_logit.append(
                                    blending_logit
                                )
                        else:
                            logger.warning(
                                f"blending_t: {blending_t} not in base_model_vocab!"
                            )
                else:  # mismatched mapping, fall back to the one-hot logits
                    aligned_blending_model_per_step_index.append(
                        base_model_vocab[base_token]
                    )
                    aligned_blending_model_per_step_logit.append(1.0)
            else:
                logger.warning(
                    f"The vocab_align_type: '{vocab_align_type}' is not support!"
                )
                raise NotImplementedError
        else:  # one base token maps to multiple blending tokens; fit only the base token and use one-hot logits
            base_token = base_model_tokens[i]
            aligned_blending_model_per_step_index.append(base_model_vocab[base_token])
            aligned_blending_model_per_step_logit.append(1.0)
        aligned_blending_model_per_step_indices.append(
            aligned_blending_model_per_step_index
        )
        aligned_blending_model_per_step_logits.append(
            aligned_blending_model_per_step_logit
        )
    return (
        aligned_blending_model_per_step_logits,
        aligned_blending_model_per_step_indices,
    )
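Note the two alignment modes in the code above: with vocab_align_type="hard", a blending token's logits are kept only when it exactly matches the base token after special-token normalization; with "soft", a precomputed blending_to_base_mapping is consulted in addition, so near-matches can still contribute their logits. In either mode, any position that fails to align falls back to a one-hot distribution on the base token.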

Fusion Strategies

Once the aligned representation matrices are obtained, and since different LLMs differ in capability, the cross-entropy (CE) loss between each model's probability distribution matrix and the ground truth can be used to assess each LLM's quality and to decide which LLMs participate in the knowledge fusion; a lower CE loss indicates a better model. Concretely, the authors propose two fusion strategies (a minimal sketch follows the list):

  1. MinCE: select only the representation matrix with the lowest CE loss for knowledge fusion.
  2. AvgCE: take a weighted average of the representation matrices, with weights based on each model's CE loss.
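Here is a minimal sketch of both strategies (my own code; it assumes K aligned per-step distribution matrices, and the softmax weighting in AvgCE is just one plausible instantiation of CE-based weights):

import torch
import torch.nn.functional as F

def fuse(probs_list, labels, strategy="MinCE"):
    # probs_list: K tensors of shape (seq_len, vocab), already aligned
    # labels:     (seq_len,) ground-truth token ids
    # CE of each source model against the ground truth: lower = better model.
    ces = torch.stack(
        [F.nll_loss(p.clamp_min(1e-12).log(), labels) for p in probs_list]
    )
    if strategy == "MinCE":
        # Keep only the distribution matrix with the smallest CE loss.
        return probs_list[int(ces.argmin())]
    # AvgCE: weighted average of all matrices; lower-CE models get larger weights.
    weights = F.softmax(-ces, dim=0)
    return sum(w * p for w, p in zip(weights, probs_list))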

The overall algorithm is as follows:
(Figure: the FuseLLM training algorithm)

  • Note: Eq. 5 in the figure corresponds to Eq. 3 above.

Some Thoughts

The idea of this paper is to treat the probability distribution matrices output by multiple LLMs as knowledge, fuse them, and feed the result to the target LLM during training, so as to combine the knowledge of several models and improve the target model's performance. In the actual implementation, however, we find that the logit-level alignment either directly reuses blending_model_per_step_logits/indices or simply uses the ground-truth one-hot vector as the fused knowledge, without fully accounting for the differences between blending_model_per_step_logits and base_model_per_step_logits at the logit level. To address this, Probabilistic Token Alignment for Large Language Model Fusion proposes a Probabilistic Token Alignment method that performs the alignment at the logit level.

