当前位置: 首页 > article >正文

推理阶段不同batch size对大模型推理结果的影响

大模型推理阶段,进行batch inference批处理推理解码,会像预期的那样速度很快推完吗?会不会有什么问题?

batch inference推理的结果居然会和一条一条推理结果差的很远?!!
Batch Decoding/Inference of LLMs will cause different outputs with different batch size?!

行为表现

测试中可以发现,即使是在推理阶段,不是在训练阶段,对于多模态大模型VLLM,如果推理时候为了加速推理,不是一条一条数据让模型推,而是一次推理batch_size>1条数据,对比batch_size=1和batch_size>1的结果,会发现这两份结果分布是不一样的,why?

  • 可能表现为,batch_size=1测下来的模型推理结果基本上都是对的,例如本身让模型回复“是”或者“否”,很短的回答,模型回答的挺好的,不仅正确而且简短没有废话
  • 调试好了之后大规模数据上batch inference批处理,batch_size>1,发现推理没有变快,推理结果还有问题,准确性大幅下降,模型甚至给出了很多长回复(例如模型开始解释,或者开始模棱两可说为什么不能回答这个问题)

是否是模型随机种子、采样影响

推理阶段影响模型随机性的参数

推理阶段影响模型随机性的参数,可以控制的主要有3个,分别为temperature、topK和topP:

  • temperature:温度参数影响输出的概率分布,当温度接近0时,模型会变得非常确定性,几乎总是选择具有最高概率的下一个词,从而产生更加一致但可能较为重复或缺乏创意的输出。如果希望减少随机性,可以将温度设得低一些
  • topK:只考虑最有可能的k个词汇,并从中进行随机选取,设置一个较小的k值可以帮助减小随机性,因为只有少量高概率的词汇会被选中
  • topP:它基于累积概率来决定候选词汇集。具体来说,模型会选择累积概率达到p阈值的最少数量的词汇作为候选。p值通常设定为0.8或0.9左右,这意味着大约80%或90%的累计概率被覆盖。如果想进一步降低随机性,可以提高这个值

有的人会觉得batch inference结果的差异是模型本身随机性导致的,从分布里面采样,采出来结果不可能每次都一样。为了去掉随机性干扰,可以把temperature设置为0,topK和topP都做类似的设置。当然最好的方式,直接设置do_sample或者是sampling=False,解码时不进行随机采样,这样结果按理就是确定的。

结果表明,即使sampling=False,batch inference的结果还是会和batch_size=1不一样

batch inference结果受到哪些因素影响

在do_sample=False的情况下,已知的,会影响推理结果的因素主要有:

  • batch size,不同batch size大小结果会很不一样
  • padding side (left/right),无论左还是右都不能消除,左和右的推理结果也会很不一样
  • padding value (<unk><eos><bos><0>),用不同的padding值都不能消除差异,不同padding的值也会右影响
  • dtype,数据类型也会有影响(有人怀疑是RMSNorm导致的浮点溢出的问题),FP32、BF16、FP16等都会有影响,即使和原始模型的dtype一样,batch inference解码结果也会不一样
  • KV-cache,是否打开KV-cache也会影响,但是关闭KV-cache并不能解决问题

已知的,会影响到的模型包括所有使用旋转位置编码的模型

解决方法:无,目前还没有修复,可以参考下面的github上gante的comment

缓解的方式

在使用多模型模型MiniCPM-V-2.6尝试,保证每个batch里面输入的token长度是一样的(都是问同一个问题,并且图片的数量一样),这种情况下就不需要padding,得到的batch inference的结果统计下来和batch_size=1的结果是一致

参考

  1. 探究inference阶段batch inference差异的论文:The batch size can affect inference results,Openreview
  2. github上的深入探究:huggingface的探究

http://www.kler.cn/a/313283.html

相关文章:

  • PyTorch深度学习与企业级项目实战-预训练语言模型GPT
  • RabbitMQ高效的消息队列中间件原理及实践
  • 申论1_概括、分析
  • flutter 发版的时候设置版本号
  • 标准C++ 字符串
  • 新版 idea 编写 idea 插件时,启动出现 ClassNotFound
  • 部分解决FDTD安装后,matlab指令fopen报错
  • C++初阶学习——探索STL奥秘——标准库中的priority_queue与模拟实现
  • Go语言笔记
  • 什么是HTTP DDOS,如何防护
  • 【Android】浅析MVC与MVP
  • 低代码可视化工具--vue条件判断v-if可视化设置-代码生成器
  • 【Android】sendevent和getevent
  • 江西金融发展集团通过ZStack Zaku容器云推进数字化转型
  • 前端框架对比与选择:如何在现代Web开发中做出最佳决策
  • 系统架构设计师|数据库基础-006
  • Docker 里面按照ifconfig
  • AppStore评论爬虫
  • 了解深度学习,张量,线性代数,激活函数的概念
  • 计算机网络传输层---课后综合题
  • Day24笔记-异常和错误
  • JVM 调优篇8 调优案例5- 逃逸分析
  • docker 安装mongo 集群
  • 4款音频转文字在线转换工具帮你解锁新的记录模式。
  • Python 装饰器使用详解
  • 【Java集合】LinkedList