文本表征的Scaling Laws:Scaling Laws For Dense Retrieval
论文链接:Scaling Laws For Dense Retrieval (arxiv.org)
代码链接:jingtaozhan/DRScale (github.com)
本次解读的论文是清华大学和小红书合作完成的一篇SIGIR2024 best paper论文。该论文最大的意义在于告诉工业界如何选择文本表征的编码器参数、训练数据量的数据,对于工业界迭代文本表征具有借鉴意义。
前言
文本表征在工业界具有非常大的作用,应用场景主要是搜索排序相关性,估计该论文会成为文本表征迭代的指导手册。不得不说,小红书对这篇论文的宣传真下成本,中国大陆机构首次!小红书搜索与清华合作获得SIGIR2024最佳论文奖 (qq.com)
贡献
- 信息检索领域常用的指标,例如NDCG,因其离散的特性,难以稳定平滑地表明效果变化趋势。提出一种新的信息检索指标—对比熵(contrastive entropy)
- 系统地分析了模型参数量、训练数据量、数据标注质量对最终效果的影响。
问题定义
给定一个语料库,检索任务需要为一个特定的query,查找最相关的句子。定义
q
q
q和
p
p
p分别是query和句子,
f
(
⋅
;
θ
)
f(\cdot;\theta)
f(⋅;θ)表示从文本到表征的映射函数,参数为
θ
\theta
θ,
q
q
q和
p
p
p的相关性得分为:
s
(
q
,
p
;
θ
)
=
<
f
(
q
;
θ
)
,
f
(
p
;
θ
)
>
(1)
s(q,p;\theta)=<f(q;\theta),f(p;\theta)>\tag{1}
s(q,p;θ)=<f(q;θ),f(p;θ)>(1)
训练数据包括多个query,每个query会标注有多条相关的passage,训练数据集的构成为
(
q
i
,
p
i
+
)
i
=
1
n
{(q_i,p_i^+)}_{i=1}^n
(qi,pi+)i=1n,其中n表示训练数据包括
n
n
n个query,
q
i
q_i
qi表示第i个query,
p
i
+
p_i^+
pi+表示第i个query的相关passages(即正样本)。
模型架构
整体模型架构采用普遍的BERT模型,提取文本的表征向量。本文主要测试了不同参数大小的BERT模型,来比较检索任务效果的好坏。
对于英文场景,选择了Google开源的24个BERT模型,参数量从0.5M到82M;
对于中文场景,选择了baidu开源的ERNIE系列。
为了确保对比的公正性,不同尺寸的模型输出文本表征向量维度都通过全连接层映射到768维。
训练数据
优选选择大规模的数据集进行训练。
对于英文场景,选择[MS MARCO Passage Ranking dataset](MS MARCO (microsoft.github.io)),该数据集有8.8 million的passage,以及0.5 million的query;
对于中文场景,选择[T2Ranking](THUIR/T2Ranking: T2Ranking: A large-scale Chinese benchmark for passage ranking. (github.com)),该数据集有超过2 million的passage,以及300k的query。
训练设置
重点是训练loss的选择。
对于每个query-passage pair
(
q
i
,
p
i
+
)
(q_i, p_i^+)
(qi,pi+),随机采样不相关的passage作为该query的负样本,优化如下对比损失:
L
(
θ
)
=
−
1
B
∑
i
=
1
B
l
o
g
e
x
p
(
s
(
q
i
,
p
i
+
;
θ
)
)
e
x
p
(
s
(
q
i
,
p
i
+
;
θ
)
)
+
∑
j
e
x
p
(
s
(
q
i
,
p
j
−
;
θ
)
)
(2)
L(\theta)=-\frac{1}{B}\sum_{i=1}^Blog\frac{exp(s(q_i,p_i^+;\theta))}{exp(s(q_i,p_i^+;\theta))+\sum_jexp(s(q_i,p_j^-;\theta))} \tag{2}
L(θ)=−B1i=1∑Blogexp(s(qi,pi+;θ))+∑jexp(s(qi,pj−;θ))exp(s(qi,pi+;θ))(2)其中
B
B
B表示batch size,
p
j
−
{p_j^-}
pj−表示负样本passage集合,
s
(
q
,
p
;
θ
)
s(q,p;\theta)
s(q,p;θ)表示query和passage的相关性得分。
对于所有不同尺寸的模型,利用同样的训练数据微调10000轮,每一轮均采样固定的256个负样本。
这里我感觉有待商榷,不同尺寸的模型可能收敛速度不一样,均训练10000轮可能有问题
评估策略
信息检索领域常用的指标有NDCG@K和MAP@K。但这些指标有如下缺陷:
- 离散,对模型、数据微小的变化不够敏感
- 裁剪,仅能考虑与query相距top k的passage之间的性能差异,k+1及以后的排序关系无法考虑到。
为了弥补上述的问题,提出对比熵的指标,如下
− l o g e x p ( s ( q i , p i + ; θ ) ) e x p ( s ( q i , p i + ; θ ) ) + ∑ j e x p ( s ( q i , p j − ; θ ) ) (3) -log\frac{exp(s(q_i,p_i^+;\theta))}{exp(s(q_i,p_i^+;\theta))+\sum_jexp(s(q_i,p_j^-;\theta))} \tag{3} −logexp(s(qi,pi+;θ))+∑jexp(s(qi,pj−;θ))exp(s(qi,pi+;θ))(3)
和loss公式一模一样。有疑问的是,如何选择指标中的负样本?论文中没有提及,代码也没有公布
为了验证对比熵和常用的离散指标——NDCG@K和MAP@K之间的关系,将他们同时绘制在图表中,如下,能够发现对比熵和通用的离散指标之间具备明显的正相关关系,说明可以用对比熵替代NDCG@K、MAP@K及RECALL@1000等指标。
SCALING LAWS
重头戏来了,文本表征的scaling laws到底是怎样的?
模型大小
scaling laws公式为:
L
(
N
)
=
(
A
N
)
α
+
δ
N
(4)
L(N)=(\frac{A}{N})^{\alpha}+\delta_N\tag{4}
L(N)=(NA)α+δN(4)
,其中
N
N
N表示模型大小,
L
(
⋅
)
L(\cdot)
L(⋅)表示测试集上的对比熵,参数
A
A
A,
α
\alpha
α,
δ
N
\delta_N
δN表示需要拟合的系数。
中英文场景数据集下,绘制的图表如下:
模型参数到对比熵的散点图可以利用一条直线来拟合,如上图。对中英文场景数据集拟合的系数指分别是:
,其中
R
2
R^2
R2表示拟合的好坏。
训练集大小
scaling laws公式为:
L
(
D
)
=
(
B
D
)
β
+
δ
D
(5)
L(D)=(\frac{B}{D})^{\beta}+\delta_D\tag{5}
L(D)=(DB)β+δD(5),其中D表示训练集大小,
B
B
B、
β
\beta
β、
δ
D
\delta_D
δD表示需要拟合的系数。
中英文场景数据集下,绘制的图表如下:
训练集大小到对比熵的散点图可以利用一条直线来拟合,如上图,不同数据集拟合的系数见上图。
训练集质量
在[MS MARCO Passage Ranking dataset](MS MARCO (microsoft.github.io))数据集中,采用三种额外的方式来标注数据
- Inverse Cloze Task (ICT):首先选中passage,其次从passage里选择一个sentence,作为query,构建query-passage pair,用于训练文本表征模型。
- Supervised Generation Models:利用docT5query模型来为每一个passage生成多个query,构建query-passage pair,用于训练文本表征模型
- Large Language Models (LLMs):利用ChatGLM3模型来为每一个passage生成多个query,构建query-passage pair,用于训练文本表征模型
三种方式的标注质量是越来越高的,其对应的效果比较如下:
横坐标为训练集大小,纵坐标为对比熵,可以发现,标注质量非常影响最终信息检索效果。
模型-数据联合的scaling laws
L
(
N
,
D
)
=
[
(
A
N
)
α
β
+
B
D
]
β
+
δ
(6)
L(N,D)=\left[(\frac{A}{N})^{\frac{\alpha}{\beta}}+\frac{B}{D}\right]^{\beta}+\delta \tag{6}
L(N,D)=[(NA)βα+DB]β+δ(6),其中
N
N
N,
D
D
D分别表示模型大小和数据集大小,
A
A
A,
B
B
B,
α
\alpha
α,
β
\beta
β,
δ
\delta
δ表示需要拟合的参数。对应的拟合图表如下
scaling laws用于成本估计
利用参数量为
N
N
N的模型,训练集大小为
D
D
D的训练推理成本为:
Z
(
N
,
D
)
=
Z
d
a
t
a
⋅
D
+
Z
t
r
a
i
n
⋅
N
+
Z
i
n
f
e
r
⋅
N
(7)
Z(N,D)=Z_{data}\cdot D+Z_{train}\cdot N+Z_{infer}\cdot N\tag{7}
Z(N,D)=Zdata⋅D+Ztrain⋅N+Zinfer⋅N(7)
,其中
Z
d
a
t
a
Z_{data}
Zdata,
Z
t
r
a
i
n
Z_{train}
Ztrain,
Z
i
n
f
e
r
Z_{infer}
Zinfer表示数据标注,模型训练,模型推理的成本系数。
成本系数的一般计算逻辑如下:
- Z d a t a Z_{data} Zdata,单条query-passage pair需要0.6美金
-
Z
t
r
a
i
n
Z_{train}
Ztrain,通过浮点数运算量,来估计需要的GPU小时,乘上GPU小时费用,得出需要的训练成本。
- 单台A100 一小时的浮点数运算量为 312 T × 3600 × 25 % 312T\times 3600\times25\% 312T×3600×25%,其中 312 T 312T 312T表示A100 在TP32下的每秒浮点运算量,3600表示1小时3600秒,25%表示GPU实际利用率一般25%。
- transformer模型训练的浮点数运算量为 10000 × ( 30 + 2 × 60 ) × 256 × 6 10000\times (30+2\times 60)\times256\times6 10000×(30+2×60)×256×6, 10000 10000 10000表示单次训练仅包括10000个step,每个step的batchsize为256,batch内每条样本包括1个query(30个token)和正、负passage样本(各60个token),训练阶段浮点数运算量数值上约等于6倍的transformer处理的token量。
- A100 GPU单小时租赁费用为3.93美金
- Z t r a i n = 10000 × ( 30 + 2 × 60 ) × 256 × 6 312 T × 3600 × 25 % × 3.93 = 3.22 × 1 0 − 8 Z_{train}=\frac{10000\times (30+2\times 60)\times256\times6}{312T\times 3600\times25\%}\times 3.93=3.22\times10^{-8} Ztrain=312T×3600×25%10000×(30+2×60)×256×6×3.93=3.22×10−8
-
Z
i
n
f
e
r
Z_{infer}
Zinfer,类似于
Z
t
r
a
i
n
Z_{train}
Ztrain的计算逻辑
- 单台A100 一小时的浮点数运算量为 312 T × 3600 × 25 % 312T\times 3600\times25\% 312T×3600×25%
- transformer模型推理的浮点数运算量为 30 × 1 0 12 × 512 × 2 30\times 10^{12}\times 512\times 2 30×1012×512×2, 30 × 1 0 12 30\times10^{12} 30×1012表示一小时对30trillion的网页进行索引, 512 512 512表示每个网页需要512个token,推理阶段浮点数运算量数值上约等于2倍的transformer处理的token量。
- A100 GPU单小时租赁费用为3.93美金
- Z t r a i n = 30 × 1 0 12 × 512 × 2 312 T × 3600 × 25 % × 3.93 = 0.43 Z_{train}=\frac{30\times 10^{12}\times 512\times 2}{312T\times 3600\times25\%}\times 3.93=0.43 Ztrain=312T×3600×25%30×1012×512×2×3.93=0.43
- 大规模推理真贵啊!!!
标注-训练的成本曲线
横坐标为模型参数,曲线表示固定成本,增加模型参数,减少训练样本数,得到的效果变化图。
- 钱越多,效果越好,money is all your need
- 固定成本下,模型大小一般都大于1B
标注-训练-推理的成本曲线
加上推理之后,需要的费用直线迅速飙升,固定成本下,最优的的模型大小一般在10M、20M、30M、40M等,模型参数量都比较小。
可以发现:
- 推理太贵了
总结
这篇文章在调研了信息检索领域文本表征的scaling law,分析了模型参数量、训练数据量、训练质量对最终效果的影响。最重要的是,分析了标注、训练、推理对综合成本的关系,对工业界实践具备重要的指导性意义。