当前位置: 首页 > article >正文

大模型系列——投机解码:Prompt Lookup Decoding代码解读

官方代码见:GitHub - apoorvumang/prompt-lookup-decoding

UPDATE 2: This method is now available in vLLM as well by setting speculative_model="[ngram]" 🥳

UPDATE: This has been added to the transformers library. Please see this for a code example, or simply add prompt_lookup_num_tokens=10 to your model.generate(...) call.

TLDR: We modify speculative decoding where we replace the draft model with simple string matching in the prompt to generate candidate token sequences. This results in significant speedups (2x-4x) in input-grounded tasks, with no effect on output quality. This method can be used with any decoder model without model changes or external datastore, and with both greedy and sampling techniques.

Intuition: In several LLM use cases where you're doing input grounded generation (summarization, document QA, multi-turn chat, code editing), there is high n-gram overlap between LLM input (prompt) and LLM output. This could be entity names, phrases, or code chunks that the LLM directly copies from the input while generating the output. Prompt lookup exploits this pattern to speed up autoregressive decoding in LLMs.

def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
    input_length = input_ids.size(1)

    for ngram_size in range(max_ngram_size, 0, -1):
        # Extract the last n tokens as our search ngram
        ngram = input_ids[0, -ngram_size:].tolist()

        # Create sliding windows of size ngram_size
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

        # Convert ngram to a tensor for comparison
        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)

        # Find where the windows match the ngram
        matches = (windows == ngram_tensor).all(dim=2)

        # Get the indices of matches
        match_indices = matches.nonzero(as_tuple=True)[1]

        # Iterate through match indices to find a valid continuation
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Ensure we don't go beyond the length of input_ids and avoid self-match
            if end_idx <= input_length and start_idx < input_length - ngram_size:
                return input_ids[0, start_idx:end_idx]

    # If no match is found, return an empty tensor
    return torch.tensor([], dtype=torch.long, device=input_ids.device)

ODOs/Thoughts/Future work

  • There's probably better ways to do string matching than the current one, and there are several obvious things to improve eg. what to do when there are multiple matches? Whats the ideal length of continuation?
  • We haven't yet tried sampling, although there's no reason it shouldn't work.
    • Here, one additional thing to test would be whether prompt lookup while sampling can affect hallucination rates, since this artifically increases probability of sampling exact sequences from input (this was suggest by my colleague Shwetha S)
  • Testing actual FLOPs impact and tradeoffs is needed
  • Also need to figure out best hyperparams - 3 and 10 were chosen on very little testing
  • It would be an interesting challenge to design the "best lookup function" for decoding, could even be a competition?

这个方法可能还是有问题的,正如坐着所说,可能存在幻觉,不一定ngram匹配上的就能加速


http://www.kler.cn/a/449309.html

相关文章:

  • 递归查询全量分页数据问题
  • GitCode 光引计划投稿|MilvusPlus:开启向量数据库新篇章
  • SRE 与 DevOps记录
  • 华院计算参与项目再次被《新闻联播》报道
  • 智能体实战(需求分析助手)一、需求概述及迭代规划
  • 解锁动态规划的奥秘:从零到精通的创新思维解析(3)
  • 使用pdf2zh遇到的问题
  • 海天味业:困境突围,再寻增长
  • CV实战项目----YOLO
  • SoftMoE:From sparse to soft mixtures of experts
  • Postman集合转JMeter脚本
  • AI应用-本地模型实现AI生成PPT(简易版)
  • C++ 函数编程题
  • 远程医疗:科技助力健康触手可及
  • linux socket编程之udp_dict_serve服务端--引入配置文件
  • 阿里云技术公开课直播预告:基于阿里云 Elasticsearch 构建 AI 搜索和可观测 Chatbot
  • python学习——洛谷 [NOIP1998 提高组] 拼数 两种方法
  • 盒子模型(内边距的设置)
  • 计算机组成原理的学习笔记(6)-- 存储器·其一 SRAM/DRAM/ROM/主存储器的初步认识
  • 学习threejs,scene.overrideMaterial全局材质效果
  • pycharm无法识别conda环境(已解决)
  • Jmeter对图片验证码的处理【超详细】
  • Html:点击图标链接发起QQ临时会话
  • 【LeetCode每日一题】——415.字符串相加
  • 【C++——迭代器】
  • 安科瑞能源物联网平台在老旧小区用电安全改造中的应用与优势