从矩阵乘法探秘Transformer
目录
- 前言
- 1. transformer背景
- 1.1 回顾线性代数的知识
- 1.1.1 矩阵和行向量
- 1.1.2 矩阵相乘和算子作用
- 1.1.3 从分块矩阵的乘法来看 Q K T V QK^TV QKTV
- 1.2 encoder-decoder
- 1.3 低阶到高阶语义向量的转换
- 1.4 核心的问题
- 2. transformer网络结构
- 2.1 基于KV查询的相似性计算
- 2.2 在一个低维空间做attention
- 2.3 在多个低维空间做attention
- 2.4 位置无关的全连接
- 2.5 归一化+残差网络
- 2.6 整体的变换
- 3. transformer参数和计算量
- 3.1 关于参数量
- 3.2 参数的分布
- 3.3 linear transformer
- 4. 补充—线性Attention的探索
- 结语
- 参考
前言
学习连博的另外一篇文章,从矩阵乘法的角度来理解 transformer,仅供自己参考
refer1:从矩阵乘法探秘transformer+代码讲解
refer2:深入理解transformer
以下内容来自于连博的博客:深入理解transformer,强烈建议阅读原文🤗
1. transformer背景
1.1 回顾线性代数的知识
我们先来回顾下线性代数的一些知识,这是因为《Attention Is All Your Need》这篇文章中 attention 的公式全都是一些矩阵相乘比较晦涩难懂,我们把矩阵剖解从行向量来看可能更容易理解
原文公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) \begin{aligned}\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\dfrac{QK^T}{\sqrt{d_k}})V \\ \mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_h)W^{O} \\ \mathrm{head}_i=\mathrm{Attention}(QW_i^Q, KW^{K}_i,VW^V_i) \end{aligned} Attention(Q,K,V)=softmax(dkQKT)VMultiHead(Q,K,V)=Concat(head1,…,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)
1.1.1 矩阵和行向量
我们先来看矩阵和行向量
定义矩阵 X ∈ R N × F X\in R^{N\times F} X∈RN×F,其中 X = ( X 11 , X 12 , … , X 1 F X 21 , X 22 , … , X 2 F ⋮ X N 1 , X N 2 , … , X N F ) X=\begin{pmatrix}X_{11}, X_{12},\ldots, X_{1F} \\ X_{21}, X_{22},\ldots, X_{2F} \\ \vdots\\ X_{N1}, X_{N2},\ldots, X_{NF}\end{pmatrix} X= X11,X12,…,X1FX21,X22,…,X2F⋮XN1,XN2,…,XNF
矩阵 X X X 每一行定义为行向量 X i X_i Xi,其中 X i = ( X i 1 , X i 2 , … , X i F ) , X i ∈ R 1 × F X_{i}=\begin{pmatrix} X_{i1}, X_{i2},\ldots, X_{iF}\end{pmatrix}, X_i \in R^{1\times F} Xi=(Xi1,Xi2,…,XiF),Xi∈R1×F
矩阵 X X X 可以看作是 N N N 个行向量堆叠的结果,也就是说 X = ( X 1 X 2 ⋮ X N ) X=\begin{pmatrix} X_1 \\ X_2 \\ \vdots \\ X_N \end{pmatrix} X= X1X2⋮XN
比如 pytorch 中的 nn.Embedding
它其实就是按照行向量来组织数据的
import torch
import torch.nn as nn
N = 3
F = 8
embed = nn.Embedding(N, F)
idx = torch.tensor([0, 1, 2])
X = embed(idx)
print(X.shape) # torch.Size([3, 8])
我们举个简单的例子,假设有 N N N 个 token, F F F 是 embedding 的维度,每行对应于一个 token 的 embedding 行向量,那么对应的矩阵如下所示:
t o k e n s = ( hello world pad pad pad ) X = ( [ 0.59 , 0.20 , 0.04 , 0.96 ] [ 0.96 , 0.30 , 0.16 , 0.63 ] [ 0.02 , 0.19 , 0.34 , 0.25 ] [ 0.02 , 0.19 , 0.34 , 0.25 ] [ 0.02 , 0.19 , 0.34 , 0.25 ] ) tokens=\begin{pmatrix} \text{hello} \\ \text{world} \\ \text{pad} \\ \text{pad} \\ \text{pad} \end{pmatrix} \\ X=\begin{pmatrix} [0.59, 0.20, 0.04, 0.96] \\ [0.96, 0.30, 0.16, 0.63] \\ [0.02, 0.19, 0.34, 0.25] \\ [0.02, 0.19, 0.34, 0.25] \\ [0.02, 0.19, 0.34, 0.25] \end{pmatrix} tokens= helloworldpadpadpad X= [0.59,0.20,0.04,0.96][0.96,0.30,0.16,0.63][0.02,0.19,0.34,0.25][0.02,0.19,0.34,0.25][0.02,0.19,0.34,0.25]
1.1.2 矩阵相乘和算子作用
我们接着看矩阵相乘和算子作用的一些知识
- 定义线性算子
A
\mathcal{A}
A
- 它可以作用到行向量上 A ( X i ) = X i A \mathcal{A}(X_i)=X_iA A(Xi)=XiA
- 也可以作用到矩阵上 A ( X ) = X A \mathcal{A}(X)=XA A(X)=XA
- 右乘矩阵等于对每个行向量逐个施加行变换
X A = ( X 1 X 2 ⋮ X N ) A = ( X 1 A X 2 A ⋮ X N A ) = ( A ( X 1 ) A ( X 2 ) ⋮ A ( X N ) ) = A ( X ) XA=\begin{pmatrix} X_1\\ X_2\\ \vdots\\ X_N \end{pmatrix}A= \begin{pmatrix} X_1 A\\ X_2 A\\ \vdots\\ X_N A \end{pmatrix}= \begin{pmatrix} \mathcal{A}(X_1) \\ \mathcal{A}(X_2) \\ \vdots\\ \mathcal{A}(X_N) \end{pmatrix}=\mathcal{A}(X) XA= X1X2⋮XN A= X1AX2A⋮XNA = A(X1)A(X2)⋮A(XN) =A(X)
代码对应于 nn.Linear
import torch
import torch.nn as nn
F = 6
linear = nn.Linear(in_features=F, out_features=F)
X_i = torch.rand(1, 6)
X = torch.rand(3, 6)
print(linear(X_i).shape) # torch.Size([1, 6])
print(linear(X).shape) # torch.Size([3, 6])
Note:pytorch/tensorflow 的代码都是按照作用于行向量来组织的
1.1.3 从分块矩阵的乘法来看 Q K T V QK^TV QKTV
我们从分块矩阵乘法看看 Q K T V QK^TV QKTV 具体做了什么事情
首先 S = Q K T S=QK^T S=QKT 是行向量两两计算点积相似性
( Q 1 Q 2 ⋮ Q N ) ( K 1 T , K 2 T , … , K N T ) = ( Q i K j T ) i j = S \begin{pmatrix} Q_{1}\\ Q_{2}\\ \vdots\\ Q_N \end{pmatrix} \begin{pmatrix} K_{1}^T, K_2^T,\ldots,K_N^T\\ \end{pmatrix}=(Q_{i}K_j^T)_{ij}=S Q1Q2⋮QN (K1T,K2T,…,KNT)=(QiKjT)ij=S
接着 S V SV SV 是对 V V V 行向量做加权求和
( S 11 , S 12 , … , S 1 N S 21 , S 22 , … , S 2 N ⋮ S N 1 , S N 2 , … , S N N ) ( V 1 V 2 ⋮ V N ) = ( ∑ j S 1 j V j ∑ j S 2 j V j ⋮ ∑ j S N j V j ) \begin{pmatrix} S_{11},S_{12},\ldots, S_{1N}\\ S_{21},S_{22},\ldots, S_{2N}\\ \vdots\\ S_{N1},S_{N2},\ldots, S_{NN}\\ \end{pmatrix} \begin{pmatrix} V_{1}\\ V_{2}\\ \vdots\\ V_N \end{pmatrix}= \begin{pmatrix} \sum\limits_{j}S_{1j}V_j\\ \sum\limits_{j}S_{2j}V_j\\ \vdots\\ \sum\limits_{j}S_{Nj}V_j \end{pmatrix} S11,S12,…,S1NS21,S22,…,S2N⋮SN1,SN2,…,SNN V1V2⋮VN = j∑S1jVjj∑S2jVj⋮j∑SNjVj
因此我们可以认为 attention 的计算首先是基于 Q , K Q,K Q,K 计算相似性,然后基于 V V V 来加权求和。其中 Q K T V QK^TV QKTV 的每个行向量都是 V V V 行向量的一个加权求和
值得注意的是:
- 论文:一般会有行/列向量两种表示方式
- 列向量表现为左乘以一个矩阵
- 左乘以一个矩阵相当于对每个列向量来施加变化
- 代码:基本都是行向量来作为数据组织的标准
- 本文:
- 向量都按照行向量的形式来组织
- 按照作用于单个行向量的方式来讲解 transformer
1.2 encoder-decoder
接着来看下 encoder-decoder
大部分 seq2seq 的任务建模为 encoder-decoder 的结构,如机器翻译、语音识别、文本摘要、问答系统等等,原论文 《Attention Is All Your Need》 中的 Transformer 结构就是 encoder-decoder 的结构,如下图所示:
Transformer 中的 encoder 用于把离散的 token 序列 x 1 , x 2 , … , x N x_1,x_2,\ldots,x_N x1,x2,…,xN 转化为语义向量序列 Y 1 , Y 2 , … , Y N Y_1,Y_2,\ldots,Y_N Y1,Y2,…,YN,一般组织为多层的网络的形式:
- 第一层:基础语义向量序列 x 1 , x 2 , … , x N → ( X 1 , X 2 , … , X N ) x_1,x_2,\ldots,x_N\rightarrow (X_{1}, X_2,\ldots, X_N) x1,x2,…,xN→(X1,X2,…,XN)
- 其它层:从低阶语义向量转化为高阶语义向量序列 ( X 1 , X 2 , … , X N ) → ( Y 1 , Y 2 , … , Y N ) (X_{1}, X_2,\ldots, X_N)\rightarrow (Y_{1}, Y_2,\ldots, Y_N) (X1,X2,…,XN)→(Y1,Y2,…,YN)
而 decoder 则基于 Y 1 , Y 2 , … , Y N Y_1,Y_2,\ldots,Y_N Y1,Y2,…,YN 自回归式的逐个 token 解码
那像翻译这类的任务通常涉及输入(源语言)和输出(目标之间的映射)之间的映射,因此需要基于 encoder-decoder 这样的架构,encoder 用于处理输入数据(源数据),decoder 则生成输出数据(目标语言)。而像 GPT、DeepSeek 这样的语言模型,它们的任务主要是生成文本(比如对话生成、文本补全等),这些任务并不需要明确的输入和输出对,而只需要基于一个上下文来生成接下来的文本,因此,单独的 decoder 就足够处理这些任务了。
这些模型也被称为自回归模型(Autoregressive Models),因为它们在生成每个词时,依赖于之前生成的词。换句话说,它们是一步一步地生成文本,每生成一个词,就把它作为上下文输入到模型中预测下一个词。例如,GPT、DeepSeek 都是通过给定一段文本(输入),然后依次预测每一个后续词语。这种逐步生成的过程使得它们成为了自回归的模型。
Note:本文主要聚焦到 encoder 部分来理解 transformer
1.3 低阶到高阶语义向量的转换
encoder 的主要工作是寻找算子 T \mathcal{T} T 将低阶的语义向量序列变换为高阶的语义向量序列即
T ( X 1 X 2 ⋮ X N ) → ( Y 1 Y 2 ⋮ Y N ) \mathcal{T}\begin{pmatrix} X_1\\ X_2\\ \vdots\\ X_N \end{pmatrix} \rightarrow\begin{pmatrix} Y_1\\ Y_2\\ \vdots\\ Y_N \end{pmatrix} T X1X2⋮XN → Y1Y2⋮YN
- 输入: X X X 低阶语义向量序列
- 输出: Y Y Y 高阶语义向量序列
- 意义
- Y i = f ( X 1 , X 2 , … , X N ) Y_{i}=f(X_{1}, X_2, \ldots, X_{N}) Yi=f(X1,X2,…,XN)
- 对低阶语义向量做加工组合处理和抽象,变换为一个高阶的语义向量序列
- 高阶语义向量考虑了上下文的语义向量表达
- 用算子作用来表达
- Y = T ( X ) Y=\mathcal{T}(X) Y=T(X)
- X ∈ R N × F , Y ∈ R N × F : R N × F → R N × F X \in R^{N\times F},Y\in R^{N\times F}: \quad R^{N\times F}\rightarrow R^{N\times F} X∈RN×F,Y∈RN×F:RN×F→RN×F
- 这个算子天然可以复合嵌套,形成多层的网络结构 Y = T L ∘ T L − 1 ∘ … ∘ T 1 ( X ) Y=\mathcal{T}_{L}\circ \mathcal{T}_{L-1}\circ \ldots \circ \mathcal{T}_{1}(X) Y=TL∘TL−1∘…∘T1(X)
1.4 核心的问题
我们现在的核心问题是如何设计 Y i = f ( X 1 , X 2 , … , X N ) Y_i=f(X_1,X_2,\ldots,X_N) Yi=f(X1,X2,…,XN),满足:
- Y 1 , … , Y N Y_1,\ldots,Y_N Y1,…,YN 能够并行得到
- Y i Y_i Yi 能够高效的建立起对周围 token 的远程依赖
我们可以先看下 RNN,看它是如何做的:
RNN 的特性如下:
- 递归语义序列 Y 0 → Y 1 → … → Y N Y_0 \rightarrow Y_1 \rightarrow \ldots \rightarrow Y_N Y0→Y1→…→YN
- Y i = tanh ( X i W + Y i − 1 U ) Y_i=\tanh(X_iW+Y_{i-1}U) Yi=tanh(XiW+Yi−1U)
- 串行
- 单方向的依赖关系,例如 Y 3 Y_3 Y3 直接依赖于 Y 2 , X 3 Y_2,X_3 Y2,X3,间接依赖于 X 1 X_1 X1
接着再看下 CNN:
CNN 的特性如下:
- Y i = ( X i − 1 , X i , X i + 1 ) W Y_i=(X_{i-1},X_{i},X_{i+1})W Yi=(Xi−1,Xi,Xi+1)W
- 并行
- 假设窗口宽度是 3,即 kernel_size = 3
- 它不能长距离依赖,一层卷积只能依赖于当前窗口内,不能对窗口外的形成依赖
- 例如 Y 3 Y_3 Y3 依赖于 X 2 , X 3 , X 4 X_2,X_3,X_4 X2,X3,X4,但它没有办法和 X 1 X_1 X1 建议起依赖关系
transformer 要解决的问题就是设计 Y i = f ( X 1 , X 2 , … , X N ) Y_i=f(X_1,X_2,\ldots,X_N) Yi=f(X1,X2,…,XN) 使得:
- Y 1 , … , Y N Y_1,\ldots,Y_N Y1,…,YN 可以做并行计算
- 同时解决长距离依赖的问题
如上图所示,我们在计算 Y 2 Y_2 Y2 时就希望对所有 token 的低阶语义向量序列都能够建议起依赖关系来
整体思路的话就是做两次矩阵的变换即 Y ′ = F ( Y ) = F ∘ A ( X ) Y^{\prime}=\mathcal{F}(Y)= \mathcal{F}\circ \mathcal{A}(X) Y′=F(Y)=F∘A(X)
-
Y
=
A
(
X
)
Y=\mathcal{A}(X)
Y=A(X)
- 第一次矩阵变换
- MultiHead Attention 多头注意力机制
- 高阶的语义等于对全部的低阶语义向量基于相似性(Attention)做加权平均
- A ( X i ) = ∑ j = 1 N s i m ( X i , X j ) X j ∑ j = 1 N s i m ( X i , X j ) \begin{aligned}\mathcal{A}(X_i) &= \frac{\sum_{j=1}^{N} sim(X_i,X_j) X_j}{\sum_{j=1}^N sim(X_i,X_j)} \end{aligned} A(Xi)=∑j=1Nsim(Xi,Xj)∑j=1Nsim(Xi,Xj)Xj
- attention = 相似性
-
Y
′
=
F
(
Y
)
Y^{\prime}=\mathcal{F}(Y)
Y′=F(Y)
- 第二次矩阵变换
- Position-wise Feedforward 前馈神经网络层
- 再施加若干线性变换
2. transformer网络结构
下面我们就来看看 transformer 的网络结构
2.1 基于KV查询的相似性计算
首先看 transformer 第一部分相似性(attention)的计算即
A ( X i ) = ∑ j = 1 N s i m ( X i , X j ) X j ∑ j = 1 N s i m ( X i , X j ) \begin{aligned}\mathcal{A}(X_i) &= \frac{\sum_{j=1}^{N} sim(X_i,X_j) X_j}{\sum_{j=1}^N sim(X_i,X_j)} \end{aligned} A(Xi)=∑j=1Nsim(Xi,Xj)∑j=1Nsim(Xi,Xj)Xj
前面我们说了 transformer 的 motivation 就是把 X i X_i Xi 这样一个低阶语义向量和周围所有的低阶语义向量去做一个相似性,然后再做一个加权平均
那我们怎么来算 A ( X i ) \mathcal{A}(X_i) A(Xi) 这个相似性呢?如果直接计算相似性会发现参数太少,模型复杂度低无法有效学习。那一种自然而然的想法就是我们投影到别的空间来计算相似度即 X i → X i W X_i \rightarrow X_iW Xi→XiW,而不是直接来计算
因此我们可以在原有公式基础上都乘以相应的矩阵 W W W,投影到更高维的空间,那此时相似性公式如下:
A ( X i ) = ∑ j = 1 N s i m ( X i W 1 , X j W 2 ) X j W 3 ∑ j = 1 N s i m ( X i W 1 , X j W 2 ) \begin{aligned} \mathcal{A}(X_i) &= \frac{\sum_{j=1}^{N} sim(X_iW_1,X_jW_{2}) X_jW_3}{\sum_{j=1}^N sim(X_iW_1,X_jW_2)} \end{aligned} A(Xi)=∑j=1Nsim(XiW1,XjW2)∑j=1Nsim(XiW1,XjW2)XjW3
如果我们记 Q i = X i W 1 , K i = X i W 2 , V i = X i W 3 Q_i=X_iW_1,K_i=X_iW_2,V_i=X_iW_3 Qi=XiW1,Ki=XiW2,Vi=XiW3 则有:
A ( X i ) = ∑ j = 1 N s i m ( Q i , K j ) V j ∑ j = 1 N s i m ( Q i , K j ) \begin{aligned}\mathcal{A}(X_i) &= \frac{\sum_{j=1}^{N} sim(Q_i,K_j) V_j}{\sum_{j=1}^N sim(Q_i,K_j)} \end{aligned} A(Xi)=∑j=1Nsim(Qi,Kj)∑j=1Nsim(Qi,Kj)Vj
那这个公式和原文中的类似,那我们怎么去理解 KV 查询呢?
首先我们把 X i X_i Xi 投影出三个向量 Q i , K i , V i Q_i,K_i,V_i Qi,Ki,Vi,其中 K , V K,V K,V 是大家熟悉的 key-value 存储, K j → V j K_j \rightarrow V_j Kj→Vj 相互对应,而 Q Q Q 是查询使用的 query 向量 Q i Q_i Qi
Q , K , V Q,K,V Q,K,V 的查询方法是 query 查询多个 key,获取多个 value,最后把这些 value 加权平均,即
Q i ⇒ ( K 1 → V 1 K 2 → V 2 ⋮ K N → V N ) ⇒ ( s i m ( Q i , K 1 ) V 1 s i m ( Q i , K 2 ) V 2 ⋮ s i m ( Q i , K N ) V N ) ⇒ ∑ j = 1 N s i m ( Q i , K j ) V j Q_i\Rightarrow \begin{pmatrix} K_{1}\rightarrow V_{1}\\ K_2\rightarrow V_2\\ \vdots\\ K_N\rightarrow V_N \end{pmatrix} \Rightarrow \begin{pmatrix} sim(Q_i,K_1)V_{1} \\ sim(Q_i,K_2)V_{2} \\ \vdots\\ sim(Q_i,K_N)V_N \end{pmatrix}\Rightarrow\sum_{j=1}^N sim(Q_i,K_j)V_j Qi⇒ K1→V1K2→V2⋮KN→VN ⇒ sim(Qi,K1)V1sim(Qi,K2)V2⋮sim(Qi,KN)VN ⇒j=1∑Nsim(Qi,Kj)Vj
那我们怎么理解呢?举个简单的例子,假设我们现在有 3 个 token 对应 3 个低阶语义向量 X 1 , X 2 , X 3 X_1,X_2,X_3 X1,X2,X3,接着我们会把 X 1 X_1 X1 投影出三个向量来分别是 Q 1 , K 1 , V 1 Q_1,K_1,V_1 Q1,K1,V1,同理 X 2 X_2 X2 投影出三个向量分别是 Q 2 , K 2 , V 2 Q_2,K_2,V_2 Q2,K2,V2, X 3 X_3 X3 投影出三个向量分别是 Q 3 , K 3 , V 3 Q_3,K_3,V_3 Q3,K3,V3。投影完成之后我们可以把 K 1 → Y 1 , K 2 → Y 2 , K 3 → Y 3 K_1\rightarrow Y_1,K_2\rightarrow Y_2,K_3\rightarrow Y_3 K1→Y1,K2→Y2,K3→Y3 当成一个 K → V K \rightarrow V K→V 查询体系
假设我们要计算低阶语义向量 X 2 X_2 X2 对应的高阶语义向量 Y 2 Y_2 Y2,那么我们先要用 Q 2 Q_2 Q2 查询 K → V K \rightarrow V K→V 体系中的 K K K 即 K 1 , K 2 , K 3 K_1,K_2,K_3 K1,K2,K3,然后分别计算它们的 s i m sim sim,最后把对应的 V V V 做一个加权平均,也就是 Y 2 = ∑ j = 1 3 s i m ( Q 2 , K j ) V j ∑ j = 1 3 s i m ( Q 2 , K j ) \begin{aligned}Y_2 &= \frac{\sum_{j=1}^{3} sim(Q_2,K_j) V_j}{\sum_{j=1}^3 sim(Q_2,K_j)} \end{aligned} Y2=∑j=13sim(Q2,Kj)∑j=13sim(Q2,Kj)Vj
也就是对应到前面的公式:
A ( X i ) = ∑ j = 1 N s i m ( Q i , K j ) V j ∑ j = 1 N s i m ( Q i , K j ) \begin{aligned}\mathcal{A}(X_i) &= \frac{\sum_{j=1}^{N} sim(Q_i,K_j) V_j}{\sum_{j=1}^N sim(Q_i,K_j)} \end{aligned} A(Xi)=∑j=1Nsim(Qi,Kj)∑j=1Nsim(Qi,Kj)Vj
那做完这些后我们会发现已经加了一些有效的参数出来了,也就是对应于 Q , K , V Q,K,V Q,K,V 产生的三个投影矩阵 W Q , W K , W V W_Q,W_K,W_V WQ,WK,WV
2.2 在一个低维空间做attention
下面看一下它的一个实现,单个行向量做 attention 的流程如下:
step 1. 把 X i X_i Xi 从 F F F 维空间投影到 D D D 维空间
- Q i = X i W Q , W Q ∈ R F × D Q_i = X_iW_Q, \quad W_Q\in R^{F \times D} Qi=XiWQ,WQ∈RF×D
- K i = X i W K , W K ∈ R F × D K_i = X_iW_K, \quad W_K\in R^{F \times D} Ki=XiWK,WK∈RF×D
- V i = X i W V , W V ∈ R F × M V_i = X_iW_V, \quad W_V\in R^{F \times M} Vi=XiWV,WV∈RF×M
step 2. Q i Q_i Qi 和所有的 K j K_j Kj 做基于点积的相似度计算
- Q i K T = Q i ( K 1 T , … , K N T ) = ( Q i K 1 T , … , Q i K N T ) Q_iK^{T}=Q_i(K^T_1, \ldots, K^T_N)=(Q_iK^T_1, \ldots, Q_iK^T_N) QiKT=Qi(K1T,…,KNT)=(QiK1T,…,QiKNT)
- Note:简单起见,我们这里省略了 scaling 缩放因子 1 D \frac{1}{\sqrt{D}} D1
step 3. 对相似度的分布做 softmax
- S = s o f t ( Q i K 1 T , … , Q i K N T ) = ( s i 1 , … , s i N ) S=\mathrm{soft}(Q_iK^T_1, \ldots, Q_iK^T_N)=(s_{i1},\ldots, s_{iN}) S=soft(QiK1T,…,QiKNT)=(si1,…,siN)
- s i , j = e x p ( Q i K j T ) ∑ j = 1 N e x p ( Q i K j T ) s_{i,j}= \dfrac{exp(Q_iK_j^T)}{\sum_{j=1}^N exp(Q_iK_j^T)} si,j=∑j=1Nexp(QiKjT)exp(QiKjT)
step 4. 加权平均
- A ( X i ) = ∑ j = 1 N s j V j = ( s i 1 , … , s i N ) ( V 1 V 2 ⋮ V N ) \mathcal{A}(X_i)=\sum_{j=1}^Ns_jV_j=(s_{i1},\ldots, s_{iN}) \begin{pmatrix} V_1 \\ V_2\\ \vdots\\ V_N\end{pmatrix} A(Xi)=∑j=1NsjVj=(si1,…,siN) V1V2⋮VN
- A ( X i ) = s o f t ( Q i K T ) V = s o f t ( X i W Q W K T X T ) X W V \mathcal{A}(X_i) = \mathrm{soft}(Q_iK^{T})V = \mathrm{soft}(X_iW_QW_K^TX^T)XW_V A(Xi)=soft(QiKT)V=soft(XiWQWKTXT)XWV
扩展到多个行向量即对应的矩阵表达式如下:
Y = A ( X ) = ( A ( X 1 ) A ( X 2 ) ⋮ A ( X N ) ) = ( s o f t ( Q 1 K T ) V s o f t ( Q 2 K T ) V ⋮ s o f t ( Q N K T ) V ) = s o f t ( Q K T ) V Y=\mathcal{A}(X) =\begin{pmatrix} \mathcal{A}(X_1)\\ \mathcal{A}(X_2)\\ \vdots\\ \mathcal{A}(X_N) \end{pmatrix} =\begin{pmatrix} \mathrm{soft}(Q_1K^T)V\\ \mathrm{soft}(Q_2K^T)V\\ \vdots \\ \mathrm{soft}(Q_NK^T)V \end{pmatrix}=\mathrm{soft}(QK^T)V Y=A(X)= A(X1)A(X2)⋮A(XN) = soft(Q1KT)Vsoft(Q2KT)V⋮soft(QNKT)V =soft(QKT)V
对应的代码实现如下:
import math
import torch.nn as nn
from torch.nn import functional as F
class SingleHeadAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.F = config["hidden_dim"] # F
self.D = config["subspace_dim"] # D
self.q_proj = nn.Linear(self.F, self.D)
self.k_proj = nn.Linear(self.F, self.D)
self.v_proj = nn.Linear(self.F, self.D)
def forward(self, x):
# x->[B, N, F]
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y = att @ v
return y
Note:当 D ≠ F D\neq F D=F 时, A ( X ) \mathcal{A}(X) A(X) 不可用
2.3 在多个低维空间做attention
我们思考下为什么要在多个低维空间中去做 attention 呢?在单个低维空间做不就行了吗?🤔
原文中的描述是:
Multi-Head atttention allows the model to jointly attend to information from different representation subspaces at different positions.
也就是说多头注意力可以让模型从不同低维空间表达中去学习不同的语义。我们知道单个词往往有多个含义,把 F F F 维的语义向量投影到 H H H 个不同的子空间中去计算相似加权组合可能会得到完整的语义
具体的做法如下:
- 每个头做独立的 attention 变换
A
h
(
X
)
\mathcal{A}^h(X)
Ah(X)
- 假设有 H H H 个头,每个头作用的低维空间维度是 D D D
- D × H = F D\times H=F D×H=F
- 对
H
H
H 个
D
D
D 维行向量拼接,之后再做一次矩阵变换
- A ( X ) = c o n c a t ( A 1 ( X ) , A 2 ( X ) , … , A H ( X ) ) W O \mathcal{A}(X) = \mathrm{concat}(\mathcal{A}^1(X), \mathcal{A}^2(X), \ldots, \mathcal{A}^{H}(X)) W_O A(X)=concat(A1(X),A2(X),…,AH(X))WO
- W O ∈ R F × F W_O \in R^{F\times F} WO∈RF×F
- 对前面的符号简化
- 在第 j j j 个子空间做单头注意力 Y j = s i m ( Q j , K j ) V j Y^j=sim(Q^j,K^j)V^j Yj=sim(Qj,Kj)Vj
- 合并 Y = ( Y 1 , … , Y H ) W o Y=(Y^1, \ldots ,Y^H)W_o Y=(Y1,…,YH)Wo
代码实现如下:
import math
import torch.nn as nn
from torch.nn import functional as F
class SelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.H = config["n_head"] # H
self.F = config["hidden_dim"] # F
self.D = self.F // self.H # D
# 一次把 qkv 全部映射完成, 对应 $W_Q$, $W_K$, $W_V$
self.qkv_proj = nn.Linear(self.F, 3 * self.F)
# 最后的投影, 对应于 $W_O$
self.out_proj = nn.Linear(self.F, self.F)
def forward(self, x):
# x->[B, N, F]
B, N, _ = x.size()
q, k, v = self.qkv_proj(x).split(self.F, dim=-1)
# matmul 只能在最后两个维度相乘, 需要对 NxD 的矩阵相乘, 做 1,2 维度的交换
# [B, H, N, D]
q = q.view(B, N, self.H, self.D).transpose(1, 2)
k = k.view(B, N, self.H, self.D).transpose(1, 2)
v = v.view(B, N, self.H, self.D).transpose(1, 2)
# 一次把多个头的映射全部完成, 对任意的 (batch, head)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
# [B, H, N, D]
y = att @ v
# [B, N, H, D]
y = y.transpose(1, 2)
# 最后两个维度合并
y = y.contiguous().view(B, N, F)
y = self.out_proj(y)
return y
Note:代码参考自:https://github.com/karpathy/minGPT/tree/master/mingpt
代码示意图如下所示:
输入 X X X 是 (B,N,F) 的三维矩阵,经过 linear(F,3F) 算子后维度变成了 (B,N,3F),通过 split 拿到我们独立的 Q , K , V Q,K,V Q,K,V 矩阵。接着通过 view 和 transpose 之后将 Q , K , V Q,K,V Q,K,V 的最后两维变成我们关心的子矩阵,此时维度是 (B,H,N,D)
然后对不同子空间的 Q i K i V i Q^{i}K^{i}V^{i} QiKiVi 做 attention,做完之后再经过 transpose 和 view 得到我们的输出 Y Y Y,这样我们的输入 X X X 和输出 Y Y Y 是完全能够 match 上的
2.4 位置无关的全连接
前面我们已经讲完了 transformer 中的 attention 变换,下面我们来看 transformer 中另外一个变换 Feedforward 即位置无关的全连接
公式表达如下:
F ( X i ) = ( g ( X i W 1 ) + b 1 ) W 2 + b 2 \mathcal{F}(X_i)=(g(X_iW_1)+b_1)W_2+b_2 F(Xi)=(g(XiW1)+b1)W2+b2
作用到每个行向量 X i X_i Xi 上时是先右乘一个矩阵 W 1 W_1 W1,然后做个激活 g g g 再加个偏置 b 1 b_1 b1,接着再乘以一个 W 2 W_2 W2 加上 b 2 b_2 b2,其实也就是两层的全连接
代码实现如下:
import torch.nn as nn
class PWiseFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.F = config["F"]
self.proj_wide = nn.Linear(self.F, 4 * self.F)
self.proj_narrow = nn.Linear(4 * self.F, self.F)
self.act = nn.ReLU()
def forward(self, x):
x = self.proj_wide(x)
x = self.act(x)
x = self.proj_narrow(x)
return x
2.5 归一化+残差网络
前面我们已经把 Transformer 的核心架构部分讲完了,也就是 T ( X ) = F ∘ A ( X ) \mathcal{T}(X)=\mathcal{F}\circ\mathcal{A}(X) T(X)=F∘A(X),它包含两部分先做 A \mathcal{A} A 变换再做一个 F \mathcal{F} F 变换
那其实在 Transformer 网络结构的中间部分还加入了一些归一化和残差网络,下面我们简单说明下
Transformer 中的 Normalization 层一般都是采用 LayerNorm 来对 Tensor 进行归一化,LayerNorm 的公式如下:
A ′ ( X ) = N ∘ A ( X ) L a y e r N o r m : y = x − μ σ γ + β μ = 1 d ∑ i = 1 d x i σ = 1 d ∑ i = 1 d ( x i − μ ) 2 \begin{aligned} A^{\prime}(X)&=\mathcal{N}\circ\mathcal{A}(X) \\ LayerNorm:y&=\frac{x-\mu}{\sqrt{\sigma}}\gamma+\beta \\ \mu&=\dfrac{1}{d}\sum\limits_{i=1}^{d}x_{i} \\ \sigma&=\sqrt{\dfrac{1}{d}\sum\limits_{i=1}^{d}(x_{i}-\mu)^{2}} \end{aligned} A′(X)LayerNorm:yμσ=N∘A(X)=σx−μγ+β=d1i=1∑dxi=d1i=1∑d(xi−μ)2
LayerNorm 和 BatchNorm 比较像,区别是一个是在行上面做归一化一个是在列上面做归一化,而 LayerNorm 可以看作是作用在行向量上的算子。在 NLP 的序列建模里面一般使用 LayerNorm,而在 CV 里面一般使用 BatchNorm
这主要是因为 padding 的影响,以下面的输入矩阵为例,不同 batch 中 <pad> 个数不同,沿着 token 方向做归一化并没有意义,而每个位置做独立的归一化更有意义
( hello world pad pad pad ) → X = ( [ 0.59 , 0.20 , 0.04 , 0.96 ] [ 0.96 , 0.30 , 0.16 , 0.63 ] [ 0.02 , 0.19 , 0.34 , 0.25 ] [ 0.02 , 0.19 , 0.34 , 0.25 ] [ 0.02 , 0.19 , 0.34 , 0.25 ] ) \begin{pmatrix} \text{hello} \\ \text{world} \\ \text{pad} \\ \text{pad} \\ \text{pad} \end{pmatrix} \rightarrow X= \begin{pmatrix} [0.59, 0.20, 0.04, 0.96] \\ [0.96, 0.30, 0.16, 0.63] \\ [0.02, 0.19, 0.34, 0.25] \\ [0.02, 0.19, 0.34, 0.25] \\ [0.02, 0.19, 0.34, 0.25] \end{pmatrix} helloworldpadpadpad →X= [0.59,0.20,0.04,0.96][0.96,0.30,0.16,0.63][0.02,0.19,0.34,0.25][0.02,0.19,0.34,0.25][0.02,0.19,0.34,0.25]
其他的可能选择 RMSNorm 归一化方法,例如 LLaMA 中使用的就是 RMSNorm,RMSNorm 是 LayerNorm 的变体,RMSNorm 省去了求均值的过程,也没有了偏置 β \beta β,公式如下:
R M S N o r m : y = x M e a n ( x 2 ) + ϵ ∗ γ M e a n ( x 2 ) = 1 N ∑ i = 1 N x i 2 \begin{aligned} RMSNorm:y & =\frac{x}{\sqrt{Mean(x^{2})+\epsilon}}*\gamma \\ Mean(x^{2}) & =\frac{1}{N}\sum_{i=1}^Nx_i^2 \end{aligned} RMSNorm:yMean(x2)=Mean(x2)+ϵx∗γ=N1i=1∑Nxi2
其中 γ \gamma γ 为可学习的参数
大家感兴趣的可以看看:RMSNorm算子的CUDA实现
2.6 整体的变换
最后我们看下 transformer 的整体变换 Y = T ( X ) Y=\mathcal{T}(X) Y=T(X),它主要分为以下几个部分:
- Attention Z = N ∘ ( X + A ( X ) ) Z=N\circ(X+\mathcal{A}(X)) Z=N∘(X+A(X))
- 位置无关的全连接 Y = N ∘ ( X + F ( Z ) ) Y=\mathcal{N}\circ(X+\mathcal{F}(Z)) Y=N∘(X+F(Z))
- 残差网络
- A ′ ( X ) = N ∘ ( X + A ( X ) ) \mathcal{A}^{\prime}(X)=\mathcal{N}\circ(X+\mathcal{A}(X)) A′(X)=N∘(X+A(X))
- F ′ = N ∘ ( X + F ( X ) ) \mathcal{F}^{\prime}=\mathcal{N}\circ(X+\mathcal{F}(X)) F′=N∘(X+F(X))
前面我们主要是解释了 transformer 中一层的网络结构,实际上我们是多层,可以任意的去嵌套,对于一个 L L L 层的 transformer 模型表达如下:
T ( X ) = T L ∘ … T 2 ∘ T 1 ( X ) \begin{equation*} \begin{split} \mathcal{T}(X) & = \mathcal{T}_L \circ \ldots \mathcal{T}_{2}\circ \mathcal{T}_{1}(X) \end{split} \end{equation*} T(X)=TL∘…T2∘T1(X)
代码实现如下:
import torch.nn as nn
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.attn = SelfAttention(config)
self.norm_1 = nn.LayerNorm(config["hidden_dim"])
self.mlp = PWiseFeedForward(config)
self.norm_2 = nn.LayerNorm(config["hidden_dim"])
def forward(self, x):
x = self.norm_1(x + self.attn(x))
x = self.norm_2(x + self.mlp(x))
return x
3. transformer参数和计算量
最后我们来回顾下 transformer 的参数量
3.1 关于参数量
一般模型增加复杂度的方式包括:
- 增加深度,增加宽度
- 增加 embedding 的维度
- 增加词典的大小
各种 DNN 主要参数的位置:
- CNN: Y i = ( X i − 1 , X i , X i + 1 ) W Y_{i}=(X_{i-1},X_i, X_{i+1}) W Yi=(Xi−1,Xi,Xi+1)W
- RNN: Y i = tanh ( X i W + Y i − 1 U ) Y_{i}=\tanh(X_{i}W + Y_{i-1}U) Yi=tanh(XiW+Yi−1U)
3.2 参数的分布
我们来看下 transformer 它的参数分布是什么样子的:
1. 多头注意力(Multi-Head Attention)
- 每个头有
- 3 个投影矩阵 W Q , W K , W V W_Q,W_K,W_V WQ,WK,WV
- 1 个投影 concat 结果的矩阵 W O W_O WO
- 假设投影到的子空间维度是 D D D,有 H H H 个子空间且 D × H = F D \times H=F D×H=F
- 参数量:
- F × D × 3 × H = 3 F 2 F \times D \times 3 \times H = 3F^2 F×D×3×H=3F2
- F 2 F^2 F2
2. 前馈网络层(Feedforward)
- 两个矩阵,先从 F F F 变宽到 4 F 4F 4F,再收窄回到 F F F
- 参数量: F × 4 F + 4 F × F = 8 F 2 F\times 4F+4F\times F=8F^2 F×4F+4F×F=8F2
3. word embedding
- E E E 是 token 字典的大小
- E × F E \times F E×F
总共:
- L ( 12 F 2 ) + E F L(12F^2)+EF L(12F2)+EF
- L L L 表示模型的层数
例如:
model | 维度 | 层数 | 头数 | 字典大小 | 参数量 |
---|---|---|---|---|---|
bertBase | 768 | 12 | 12 | 30000 | 110M |
bertLarge | 1024 | 24 | 12 | 30000 | 340M |
3.3 linear transformer
transformer 中两个算子的计算量分别是:
- A ( X ) \mathcal{A}(X) A(X) 计算量 O ( N 2 ) O(N^2) O(N2)
- F ( X ) \mathcal{F}(X) F(X) 计算量 O ( N ) O(N) O(N)
softmax 的存在导致 A ( X ) \mathcal{A}(X) A(X) 计算量是 O ( N 2 ) O(N^2) O(N2),我们知道 attention 核心的计算量在 Q K T V QK^TV QKTV 三个矩阵的相乘上,而乘法的计算量密切依赖于矩阵组合的方式
有 softmax 的存在的话只能先计算 H = Q K T H=QK^T H=QKT,对 H H H 做 softmax 变换后再计算 H V HV HV 乘法,这个计算量是 N 2 D + N 2 M N^2D+N^2M N2D+N2M,整体的复杂度是 O ( N 2 ) O(N^2) O(N2)
Q K T V = ( Q K T ) V = ( H 11 , H 12 , … , H 1 N ⋮ H N 1 , H N 2 , … , H N N ) V QK^TV=(QK^T)V=\begin{pmatrix} H_{11},H_{12},\ldots,H_{1N} \\ \vdots\\ H_{N1},H_{N2},\ldots,H_{NN} \\ \end{pmatrix}V QKTV=(QKT)V= H11,H12,…,H1N⋮HN1,HN2,…,HNN V
如果没有 softmax 的话,可以先计算后两个矩阵相乘 H = K T V H=K^TV H=KTV,再计算 Q H QH QH 乘法,这时计算量是 N D M + D M N = 2 N D M NDM+DMN=2NDM NDM+DMN=2NDM,而当 N ≫ D N\gg D N≫D 的时候,计算量可以是 O ( N ) O(N) O(N),因为 K T V K^TV KTV 可以提前算出来缓存,大致如下面这个表达式所示:
Q ( K T V ) = ( Q 1 Q 2 ⋮ Q N ) ( K T V ) Q(K^TV)=\begin{pmatrix} Q_1 \\ Q_2 \\ \vdots\\ Q_{N} \end{pmatrix}(K^TV) Q(KTV)= Q1Q2⋮QN (KTV)
接着我们看下 kernel 的表达形式,前面我们提到过很多次 attention 可以表示成下面这种加权平均的形式
A ( X i ) = ∑ j = 1 N s i m ( Q i , K j ) V j ∑ j = 1 N s i m ( Q i , K j ) \mathcal{A}(X_i)=\dfrac{\sum_{j=1}^{N} sim(Q_i,K_j) V_j}{\sum_{j=1}^N sim(Q_i,K_j)} A(Xi)=∑j=1Nsim(Qi,Kj)∑j=1Nsim(Qi,Kj)Vj
这里的 s i m sim sim 其实是可以用非负的 kernel 来替换掉,对于 kernel 函数可以映射到其他空间 k ( x , y ) = < ϕ ( x ) , ϕ ( y ) > k(x,y)=<\phi(x),\phi(y)> k(x,y)=<ϕ(x),ϕ(y)>,从而将 s i m sim sim 变成内积的形式 k ( x , y ) = ( x ⋅ z ) 2 , ϕ ( x ) = ( x 1 2 , x 2 2 , 2 x 1 x 2 ) k(x,y)=(x\cdot z)^2, \phi(x)=(x_{1}^{2},x_{2}^2,\sqrt{2}x_1x_{2}) k(x,y)=(x⋅z)2,ϕ(x)=(x12,x22,2x1x2)
当前的 sim 函数 s i m ( x , y ) = e x p ( x y T / D ) sim(x,y)=\mathrm{exp}(xy^{T}/\sqrt{D}) sim(x,y)=exp(xyT/D)
Note:kernel 对应一个 feature map
linear transformer 其实就是用 kernel 来替换掉 sim,公式如下:
A ( X i ) = ∑ j = 1 N s i m ( Q i , K j ) V j ∑ j = 1 N s i m ( Q i , K j ) = ∑ j = 1 N ϕ ( Q i ) ϕ ( K j ) T V j ∑ j = 1 N ϕ ( Q i ) ϕ ( K j ) T = ϕ ( Q i ) ∑ j = 1 N ϕ ( K j ) T V j ϕ ( Q i ) ∑ j = 1 N ϕ ( K j ) T \begin{aligned}\mathcal{A}(X_i) &= \frac{\sum_{j=1}^{N} sim(Q_i,K_j) V_j}{\sum_{j=1}^N sim(Q_i,K_j)} \\ &=\frac{\sum_{j=1}^{N} \phi(Q_i)\phi(K_j)^T V_j}{\sum_{j=1}^N \phi(Q_i)\phi(K_j)^T} \\ &=\frac{ \phi(Q_i) \sum_{j=1}^{N}\phi(K_j)^T V_j}{\phi(Q_i)\sum_{j=1}^N \phi(K_j)^T} \end{aligned} A(Xi)=∑j=1Nsim(Qi,Kj)∑j=1Nsim(Qi,Kj)Vj=∑j=1Nϕ(Qi)ϕ(Kj)T∑j=1Nϕ(Qi)ϕ(Kj)TVj=ϕ(Qi)∑j=1Nϕ(Kj)Tϕ(Qi)∑j=1Nϕ(Kj)TVj
- ∑ j = 1 N ϕ ( K j ) T V , ∑ j = 1 N ϕ ( K j ) T \sum_{j=1}^{N}\phi(K_j)^T V, \sum_{j=1}^N \phi(K_j)^T ∑j=1Nϕ(Kj)TV,∑j=1Nϕ(Kj)T 可以提前算好
- O ( N ) O(N) O(N) 复杂度,Linear Transformer
- ϕ ( x ) = e l u ( x ) + 1 \phi(x)=\mathrm{elu}(x)+1 ϕ(x)=elu(x)+1
更多细节大家可以参考原始论文:《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》
4. 补充—线性Attention的探索
在前面的 3.3 小节我们介绍了 linear transformer,用 kernel 核函数代替原有 s i m sim sim 中的 softmax,博主一头雾水,怎么就突然提到了 kernel 核函数呢?🤔
在苏神的 线性Attention的探索:Attention必须有个Softmax吗? 文章中就有详细介绍,我们来简单了解下
原始的 Attention 机制是 Scaled-Dot Attention,形式为:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ ) V Attention(\bm{Q,K,V})=softmax(\bm{Q}\bm{K}^{\top})\bm{V} Attention(Q,K,V)=softmax(QK⊤)V
其中 Q ∈ R n × d k , K ∈ R m × d k , V ∈ R m × d v \bm{Q}\in\mathbb{R}^{n\times d_{k}},\bm{K}\in\mathbb{R}^{m\times d_{k}},\bm{V} \in\mathbb{R}^{m\times d_{v}} Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv,简单起见这里我们省略了缩放因子 scaling
这里我们只关注 self-attention 场景,为了介绍方便统一设 Q , K , V ∈ R n × d \bm{Q,K,V}\in\mathbb{R}^{n\times d} Q,K,V∈Rn×d,一般场景下都有 n > d n>d n>d 甚至 n ≫ d n \gg d n≫d
前面我们提到过制约 attention 性能的关键因素是定义里面边的 softmax, Q K ⊤ \bm{QK}^{\top} QK⊤ 这一步我们得到一个 n × n n \times n n×n 的矩阵,就是这一步决定了 attention 的复杂度是 O ( n 2 ) O(n^2) O(n2);如果没有 softmax,那么就是三个矩阵连乘 ( Q K ⊤ ) V (\bm{Q}\bm{K}^{\top})\bm{V} (QK⊤)V,而矩阵乘法是满足结合律的,所以我们可以先算 K ⊤ V \bm{K}^{\top}\bm{V} K⊤V 得到一个 d × d d \times d d×d 的矩阵,然后再用 Q \bm{Q} Q 左乘它,由于 d ≪ n d \ll n d≪n,所以这样算的复杂度只有 O ( n ) O(n) O(n)
也就是说,去掉 softmax 的 attention 的复杂度可以降到最理想的线性级别 O ( n ) O(n) O(n)!这显然就是我们的终极追求:Linear Attention,复杂度为线性级别的 attention
问题是,直接去掉 softmax 还能算是 attention 吗?它还能有标准 attention 的效果吗?为了回答这个问题,我们先将 Scaled-Dot Attention 的定义等价地改写为:
A t t e n t i o n ( Q , K , V ) i = ∑ j = 1 n e q i ⊤ k j v j ∑ j = 1 n e q i ⊤ k j Attention\bm{(Q,K,V)}_{i}=\frac{\sum\limits_{j=1}^{n}{e^{\bm{q}_{i} ^{\top}\bm{k}_{j}}\bm{v}_{j}}}{\sum\limits_{j=1}^{n}{e^{\bm{q}_{i}^{\top}\bm{k}_ {j}}}} Attention(Q,K,V)i=j=1∑neqi⊤kjj=1∑neqi⊤kjvj
Note:苏神文章中提到的向量都是列向量
所以 Scaled-Dot Attention 其实就是以 e q i ⊤ k j e^{\bm{q}_{i} ^{\top}\bm{k}_{j}} eqi⊤kj 为权重对 v j \bm{v}_j vj 做加权平均,因此我们可以提出一个 Attention 的一般化定义:
A t t e n t i o n ( Q , K , V ) i = ∑ j = 1 n s i m ( q i , k j ) v j ∑ j = 1 n s i m ( q i , k j ) Attention\bm{(Q,K,V)}_{i}=\frac{\sum\limits_{j=1}^{n}{sim(\bm{q}_i,\bm{k}_j)\bm{v}_j}}{\sum\limits_{j=1}^{n}{sim(\bm{q}_i,\bm{k}_j)}} Attention(Q,K,V)i=j=1∑nsim(qi,kj)j=1∑nsim(qi,kj)vj
也就是把 e q i ⊤ k j e^{\bm{q}_{i} ^{\top}\bm{k}_{j}} eqi⊤kj 换成 q i , k j \bm{q}_i,\bm{k}_j qi,kj 的一般函数 s i m ( q i , k j ) sim(\bm{q}_i,\bm{k}_j) sim(qi,kj),为了保留 attention 相似的分布特性,我们要求 s i m ( q i , k j ) ≥ 0 sim(\bm{q}_i,\bm{k}_j)\geq 0 sim(qi,kj)≥0 恒成立。也就是说,如果我们要定义新式的 attention,那么要保留上述公式的形式,并且满足 s i m ( q i , k j ) ≥ 0 sim(\bm{q}_i,\bm{k}_j)\geq 0 sim(qi,kj)≥0
这种一般形式的 attention 在 CV 中也被称为 Non-Local 网络,出自论文《Non-local Neural Networks》
如果直接去掉 softmax,那么就是 s i m ( q i , k j ) = q i ⊤ k j sim(\bm{q}_i,\bm{k}_j)={\bm{q}_i^{\top}}\bm{k}_j sim(qi,kj)=qi⊤kj,问题是内积无法保证非负性,所以这还不是一个合理的选择,下面我们简单介绍几种可取的方案
一个自然的想法是:如果 q i , k j {\bm{q}_i},\bm{k}_j qi,kj 的每个元素都是非负的,那么内积自然也就是非负的。为了完成这点,我们可以给 q i , k j {\bm{q}_i},\bm{k}_j qi,kj 各自加个激活函数 ϕ , φ \phi,\varphi ϕ,φ,即
s i m ( q i , k j ) = ϕ ( q i ) ⊤ φ ( k j ) sim(\bm{q}_i,\bm{k}_j)=\phi(\bm{q}_i)^{\top}\varphi(\bm{k}_j) sim(qi,kj)=ϕ(qi)⊤φ(kj)
其中 ϕ ( ⋅ ) , φ ( ⋅ ) \phi(\cdot),\varphi(\cdot) ϕ(⋅),φ(⋅) 是值域非负的激活函数,在论文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》中选择的是 ϕ ( x ) = φ ( x ) = elu ( x ) + 1 \phi(x)=\varphi(x)=\text{elu}(x)+1 ϕ(x)=φ(x)=elu(x)+1
另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在 Q K ⊤ \bm{QK}^{\top} QK⊤ 中, Q , K , ∈ R n × d \bm{Q},\bm{K},\in \mathbb{R}^{n \times d} Q,K,∈Rn×d,如果 Q \bm{Q} Q 在 d d d 那一维是归一化的、并且 K \bm{K} K 在 n n n 那一维是归一化的,那么 Q K ⊤ \bm{QK}^{\top} QK⊤ 就是自动满足归一化了,所以它给出的选择是:
A t t e n t i o n ( Q , K , V ) = s o f t m a x 2 ( Q ) s o f t m a x 1 ( K ) ⊤ V Attention(\bm{Q,K,V})=softmax_2(\bm{Q})softmax_1(\bm{K})^{\top}\bm{V} Attention(Q,K,V)=softmax2(Q)softmax1(K)⊤V
其中 s o f t m a x 1 softmax_1 softmax1、 s o f t m a x 2 softmax_2 softmax2 分别指在第一个( n n n)、第二个维度( d d d)进行 softmax 运算。也就是说,这时候我们是各自给 Q , K \bm{Q,K} Q,K 加 softmax,而不是 Q K ⊤ \bm{QK}^{\top} QK⊤ 算完之后才加 softmax
如果直接取 ϕ ( q i ) = s o f t m a x ( q i ) , φ ( k j ) = s o f t m a x ( k j ) \phi(\bm{q}_i)=softmax(\bm{q}_i),\varphi(\bm{k}_j)=softmax(\bm{k}_j) ϕ(qi)=softmax(qi),φ(kj)=softmax(kj),那么很显然这个形式也是前面我们说的核函数形式的一个特例。
最后,苏神给出了他自己的一个构思,这个构思的出发点源于对原始 attention 公式的近似,由泰勒展开我们有:
e q i ⊤ k j ≈ 1 + q i ⊤ k j e^{{\bm{q}_i^{\top}}\bm{k}_j}\approx {1+{\bm{q}_i^{\top}}\bm{k}_j} eqi⊤kj≈1+qi⊤kj
如果 q i ⊤ k j ≥ − 1 {{\bm{q}_i^{\top}}\bm{k}_j} \geq -1 qi⊤kj≥−1,那么就可以保证右端的非负性,从而可以让 s i m ( q i , k j ) = 1 + q i ⊤ k j sim(\bm{q}_i,\bm{k}_j)={1+{\bm{q}_i^{\top}}\bm{k}_j} sim(qi,kj)=1+qi⊤kj。想要保证 q i ⊤ k j ≥ − 1 {{\bm{q}_i^{\top}}\bm{k}_j} \geq -1 qi⊤kj≥−1,只需要分别对 q i , k j \bm{q}_i,\bm{k}_j qi,kj 做 l 2 l_2 l2 归一化,所以苏神最终提出的方案就是:
s i m ( q i , k j ) = 1 + ( q i ∥ q i ∥ ) ⊤ ( k j ∥ k j ∥ ) sim \left( \bm{q}_{i},\bm{k}_{j} \right)=1+ \left( \frac{\bm{q}_{i}}{ \left \| \bm{q}_{i} \right \|} \right)^{\top} \left( \frac{\bm{k}_{j}}{ \left \| \bm{k}_{j} \right \|} \right) sim(qi,kj)=1+(∥qi∥qi)⊤(∥kj∥kj)
这不同于核函数形式,但理论上它更加接近原始的 Scaled-Dot Attention
结语
本篇文章从矩阵乘法的角度来探究 transformer,首先从 encoder 的角度去观察 attention,其本质是将低阶语义向量转换为高阶语义向量的形式,其动机是要对低阶语义向量 X i X_i Xi 周围的每一个低阶语义向量做 similarity(相似性),然后再把它们的低阶语义向量基于相似性做一个加权平均,如果直接做的话是没有参数的,也无法进行学习,因此我们给它投影到多个子空间去做 attention 最后做一个拼接
transformer 的核心变换是两次,先做一次 self-attention,最后做 Feedforward。self-attention 核心的计算量在 Q K T V QK^TV QKTV 三个矩阵的相乘上,先计算 Q K T QK^T QKT 计算量是 O ( N 2 ) O(N^2) O(N2),先计算 K T V K^TV KTV 计算量是 O ( N ) O(N) O(N)
最后我们介绍了下 linear transformer,利用 kernel 表达可以将原来的 Q , K Q,K Q,K 映射到新空间 ϕ ( Q i ) , ϕ ( K j ) \phi(Q_i),\phi(K_j) ϕ(Qi),ϕ(Kj) 上,把 ϕ ( Q i ) \phi(Q_i) ϕ(Qi) 提取出来从而使得 attention 的计算量只有 O ( N ) O(N) O(N)
大家可以多看看连博的讲解,非常的不错🤗
参考
- 从矩阵乘法探秘transformer+代码讲解
- 深入理解transformer
- 《Attention Is All Your Need》
- https://github.com/karpathy/minGPT/tree/master/mingpt
- 《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》
- 线性Attention的探索:Attention必须有个Softmax吗?