RNN并行化——《Were RNNs All We Needed?》论文解读
Info | |
---|---|
Paper | https://arxiv.org/abs/2410.01201 |
GitHub | https://github.com/lucidrains/minGRU-pytorch |
个人博客地址 | http://myhz0606.com/article/mini_rnn |
最近在看并行RNN相关的paper,发现很多都利用了Parallel Scanning算法。本文将从Parallel Scanning算法开始,介绍Bengio团队不久前发表的《Were RNNs All We Needed?》
1 Parallel Scanning算法介绍
首先来看定义。Parallel Scanning字面意思,就是对scan操作进行并行化,那么什么是scan
(扫描)操作呢?
1.1 Scan的定义
1.1.1 inclusive scan
scan
(inclusive scan)也称为all-prefix-sum,其定义如下:
若给定:
- 有序集合(order set) A = [ a 0 , a 1 , . . . , a n − 1 ] A=[ a _ { 0 } , a _ { 1 } , . . . , a _ { n - 1 } ] A=[a0,a1,...,an−1] ,
- 二元结合运算符(binary associative operation) ⊕ \oplus ⊕ ,并且 ⊕ \oplus ⊕ 的单位元 I \mathcal{I} I存在
输出一个order set,并满足
B = [ a 0 , ( a 0 ⊕ a 1 ) , . . . , ( a 0 ⊕ a 1 ⊕ . . . ⊕ a n − 1 ) ] B=[ a _ { 0 } , ( a _ { 0 } \oplus a _ { 1 } ) , . . . , ( a _ { 0 } \oplus a _ { 1 } \oplus . . . \oplus a _ { n - 1 } ) ] B=[a0,(a0⊕a1),...,(a0⊕a1⊕...⊕an−1)] .
将满足上述规则的操作称为scan
。
显然上式可以写成递归形式,时间复杂度为 O ( n ) \mathcal{O}(n) O(n)
B [ i ] = { A [ 0 ] i f i = 0 B [ i − 1 ] ⊕ A [ i ] i f 0 < i < n (1) B [ i ] = \left\{ \begin{array} { r c l } { A [ 0 ] } & { \mathrm { i f } } & { i = 0 } \\ { B [ i - 1 ] \oplus A [ i ] } & { \mathrm { i f } } & { 0 \lt i \lt n } \\ \end{array} \right. \tag{1} B[i]={A[0]B[i−1]⊕A[i]ififi=00<i<n(1)
注1:二元结合运算符作用于两个操作数返回一个结果,且运算满足结合率。常见的二元结合运算符包括加法( + + +)、乘法( ∗ * ∗)、逻辑与( & \& &)和逻辑或( ∣ | ∣)等. 注2: ⊕ \oplus ⊕ 的单位元 I \mathcal{I} I:若: a ⊕ I = a a \oplus \mathcal{I} = a a⊕I=a,则称 I \mathcal{I} I是运算 ⊕ \oplus ⊕ 的单位元。例如,加法的单位元是0,乘法的单位元是1,向量点乘的单位元是单位向量。
1.1.2 exclusive scan
实践中,scan
另一种变体prescan
(也叫exclusive scan)也经常用到,输入和scan
一致,输出为:
C = [ I , a 0 , ( a 0 ⊕ a 1 ) , . . . , ( a 0 ⊕ a 1 ⊕ . . . ⊕ a n − 2 ) ] . C =[ \mathcal{I}, a _ { 0 } , ( a _ { 0 } \oplus a _ { 1 } ) , . . . , ( a _ { 0 } \oplus a _ { 1 } \oplus . . . \oplus a _ { n - 2 } ) ] . C=[I,a0,(a0⊕a1),...,(a0⊕a1⊕...⊕an−2)].
其递归形式为
C [ i ] = { I i f i = 0 C [ i − 1 ] ⊕ A [ i − 1 ] i f 0 < i < n (2) C [ i ] = \left\{ \begin{array} { r c l } { \mathcal{I} } & { \mathrm { i f } } & { i = 0 } \\ { C [ i - 1 ] \oplus A [ i - 1 ] } & { \mathrm { i f } } & { 0 \lt i \lt n } \\ \end{array} \right. \tag{2} C[i]={IC[i−1]⊕A[i−1]ififi=00<i<n(2)
inclusive scan与exclusive scan可以很方便的转化,
inclusive scan → exclusive scan,只需将输出序列向右移一个单位,并且在序列第一个元素填充单位元。
exclusive scan → inclusive scan,只需将输出序列向左移一个单位,并且用最后一个输入元素加上最后一个输出元素的结果填充最后一个元素。
1.1.3 例子: prefix sum
已知输入有序集合 A = [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] A= [0, 1, 2, 3, 4, 5, 6, 7] A=[0,1,2,3,4,5,6,7],二元结合运算符为加法 + + +,计算A在 + + +下的inclusive scan和exclusive scan
根据式1,易得inclusive scan的结果为: [ 0 , 1 , 3 , 6 , 10 , 15 , 21 , 28 ] [0, 1, 3, 6, 10, 15, 21, 28] [0,1,3,6,10,15,21,28]
根据式2,易得exclusive scan的结果为: [ 0 , 0 , 1 , 3 , 6 , 10 , 15 , 21 ] [0, 0, 1, 3, 6, 10, 15, 21] [0,0,1,3,6,10,15,21]
代码实现:
def inclusive_scan_recursive(ls: list) -> list:
n = len(ls)
if n <= 1:
return ls
output = [ls[0]] * n
for i in range(1, n):
output[i] = ls[i] + output[i - 1]
return output
1.2 Parallel Scanning
前文所述基于递归式计算scan的算法称之为sequential algorithm,其计算复杂度为 O ( n ) \mathcal{O}(n) O(n),并且无法并行化。那么如何并行化计算scan呢?
1.2.1 Kogge-Stone
Parallel Scanning algorithm[2]
Kogge-Stone
并行扫描算法的基本计算流程如下图所示(从最底部往上看)
总计分为
⌊
log
2
(
N
)
⌋
\lfloor \log_2 (N) \rfloor
⌊log2(N)⌋个阶段,在每一个阶段并行计算
a
[
i
+
2
d
]
=
a
[
i
]
+
a
[
i
+
2
d
]
a[i+2^d] = a[i] + a[i+2^d]
a[i+2d]=a[i]+a[i+2d] (
d
d
d表示阶段, 从0开始取)。该方法的加法运算次数为
∑
d
=
0
⌊
log
2
(
N
)
⌋
−
1
(
N
−
2
d
)
\sum _{d=0}^ {\lfloor \log_2 (N) \rfloor - 1}(N - 2^d)
∑d=0⌊log2(N)⌋−1(N−2d)多于顺序算法的
N
N
N,不考虑并行的情况下时间复杂度为
O
(
n
log
n
)
\mathcal{O}(n\log n)
O(nlogn)。但在processor足够时,Kogge-Stone
的时间复杂度为
O
(
log
n
)
\mathcal{O}(\log n)
O(logn)。
python代码实现如下:
注意由于python原生的多线程存在GIL,无法利用多核优势,故使用numpy实现
def inclusive_scan_kogge_stone(ls: list) -> list:
n = len(ls)
if n <= 1:
return ls
stages = math.floor(math.log2(n))
ls_arr = np.array(ls)
for i in range(stages):
stride = 2 ** i
ls_arr[stride:] += ls_arr[:n-stride] # 并行计算
return ls_arr.tolist()
1.2.2 Brent-Kung
Parallel Scanning algorithm[3]
从上文中,Kogge-Stone
算法虽然在并行的情况下将scan的时间复杂度从
O
(
n
)
\mathcal{O}(n)
O(n)降到了
O
(
log
n
)
\mathcal{O}(\log n)
O(logn),但Kogge-Stone
算法实际的计算量是比顺序执行多不少的。下面来看计算效率更高的Brent-Kung
算法。
Kogge-Stone
算法分为两个阶段
stage1: 上行阶段,计算reduce (up sweep)
上行阶段有 log 2 ( N ) \log_2 (N) log2(N) 个阶段,每个阶段 ( d = 0 , 1 , . . . , log 2 ( N ) ) (d=0,1,...,\log_2(N)) (d=0,1,...,log2(N))执行
a [ i + 2 d + 1 − 1 ] = a [ i + 2 d − 1 ] + a [ i + 2 d + 1 − 1 ] a[i + 2^{d+1}-1] = a[i + 2^{d}-1] + a[i + 2^{d+1}-1] a[i+2d+1−1]=a[i+2d−1]+a[i+2d+1−1]
算法流程:
f o r d f r o m 0 t o ( l o g 2 N ) − 1 i n p a r a l l e l f o r i f r o m 0 t o N − 1 b y 2 d + 1 a [ i + 2 d + 1 − 1 ] ← a [ i + 2 d − 1 ] + a [ i + 2 d + 1 − 1 ] \begin{array} { l } { { \bf { f o r } } \; \; d \; \; { \bf { f r o m } } \; \; 0 \; \; { \bf { t o } } \; \; ( { \bf { log_{2} } } \; N ) - 1 } \\ { \quad { \bf { i n } } \; { \bf { p a r a l l e l } } \; \; { \bf { f o r } } \; \; i \; \; { \bf { f r o m } } \; \; 0 \; \; { \bf { t o } } \; \; N - 1 \; \; { \bf { b y } } \; \; 2 ^ { d + 1 } } \\ { \quad \quad a [ i + 2 ^ { d + 1 } - 1 ] \gets a [ i + 2 ^ { d } - 1 ] + a [ i + 2 ^ { d + 1 } - 1 ] } \\ \end{array} fordfrom0to(log2N)−1inparallelforifrom0toN−1by2d+1a[i+2d+1−1]←a[i+2d−1]+a[i+2d+1−1]
下面来分析一下up sweep的时间复杂度
up sweep的计算量为
∑ d = 0 log 2 N − 1 N 2 d + 1 = ∑ d = 0 log 2 N − 1 1 2 d N 2 = N 2 ∑ ∗ d = 0 log 2 N − 1 1 2 d ⏟ 等比数列求和 = N 2 1 − ( 1 2 ) log 2 N 1 − 1 2 = N ( 1 − ( 1 2 ) log 2 N ) = N ( 1 − 1 N ) = N − 1 (3) \begin{aligned} \sum _{d=0} ^ {\log_2{N} - 1} \frac{N}{2 ^ {d + 1}} & = \sum _{d=0} ^ {\log_2{N} - 1} \frac{1}{2 ^ {d}} \frac{N}{2} \\ & = \frac{N}{2} \underbrace{\sum *{d=0} ^ {\log_2{N} - 1} \frac{1}{2 ^ {d}}}_{等比数列求和} \\ & = \frac{N}{2} \frac{1 - (\frac{1}{2})^{\log_2{N}}}{1 - \frac{1}{2}} \\ & = N(1 - (\frac{1}{2})^{\log_2{N}}) \\ & = N(1 - \frac{1}{N}) \\ & = N - 1 \end{aligned} \tag{3} d=0∑log2N−12d+1N=d=0∑log2N−12d12N=2N等比数列求和 ∑∗d=0log2N−12d1=2N1−211−(21)log2N=N(1−(21)log2N)=N(1−N1)=N−1(3)
不做并行的时间复杂度为 O ( n ) \mathcal{O}(n) O(n),并行时的时间复杂度为 O ( log n ) \mathcal{O}(\log n) O(logn)
python代码如下:
此处为了便于理解,第二个循环没有用并行
def brent_kung_up_sweep(ls: list) -> list:
n = len(ls)
if n <= 1:
return ls
stages = math.floor(math.log2(n))
assert 2 ** stages == n
for d in range(stages):
# parallel
for i in range(0, n, 2 ** (d + 1)):
ls[i + 2 ** (d + 1) - 1] = ls[i + 2 ** d - 1] + ls[i + 2 ** (d + 1) - 1]
return ls
通过up sweep 我们可以得到reduce的结果,但无法得到完整的scan结果,需要继续进行down sweep。
stage2: 下行阶段(down sweep)
算法流程:
p r o c e d u r e d o w n − s w e e p ( A ) a [ n − 1 ] ← 0 f o r d f r o m ( log 2 N ) − 1 d o w n t o 0 i n p a r a l l e l f o r i f r o m 0 t o N − 1 b y 2 d + 1 t ← a [ i + 2 d − 1 ] a [ i + 2 d − 1 ] ← a [ i + 2 d + 1 − 1 ] \begin{array} { r l } & { \mathbf { p r o c e d u r e \ \, d o w n - s w e e p } ( \mathtt { A } ) } \\ & { \quad a [ n - 1 ] \leftarrow 0 } \\ & { \quad \mathbf { f o r } \ \ d \ \mathbf { f r o m } \ \left( \log_2 N \right) - 1 \ \mathbf { \ d o w n t o \ } 0 } \\ & { \quad \quad \mathbf { i n \ p a r a l l e l \ f o r } \ \ i \ \mathbf { f r o m } \ \ 0 \ \mathbf { \ t o } \ \ N - 1 \ \mathbf { \ b y } \ \ 2 ^ { d + 1 } } \\ & { \quad \quad \quad t \gets a [ i + 2 ^ { d } - 1 ] } \\ & { \quad \quad \quad a [ i + 2 ^ { d } - 1 ] \gets a [ i + 2 ^ { d + 1 } - 1 ] } & { \quad \quad } \end{array} procedure down−sweep(A)a[n−1]←0for d from (log2N)−1 downto 0in parallel for i from 0 to N−1 by 2d+1t←a[i+2d−1]a[i+2d−1]←a[i+2d+1−1]
计算复杂度与up-sweep一致
python代码如下:
def brent_kung_down_sweep(ls: list) -> list:
n = len(ls)
if n <= 0:
return ls
ls[-1] = 0
stages = math.floor(math.log2(n))
assert 2 ** stages == n
for d in range(stages - 1, -1, -1):
# parallel
for i in range(0, n, 2 ** (d + 1)):
ls[i + 2 ** d - 1], ls[i + 2 ** (d + 1) - 1] = ls[i + 2 ** (d + 1) - 1], ls[i + 2 ** d - 1] + ls[i + 2 ** (d + 1) - 1]
return ls
综上所述,我们详细介绍了Kogge-Stone
算法,它分为up sweep和down sweep两个阶段,每个阶段的计算量为
N
−
1
N-1
N−1,不做并行的计算时间复杂度为:
O
(
n
)
\mathcal{O}(n)
O(n),并行时的计算复杂度为
O
(
log
n
)
\mathcal{O}(\log{n})
O(logn)
def inclusive_scan_brent_kung(ls):
n = len(ls)
if n <= 1:
return ls
arr = np.array(ls)
logn = int(math.log2(n))
# 确保输入长度是2的幂
if n & (n - 1) != 0:
raise ValueError("Input length must be a power of 2")
# Up-sweep阶段
for d in range(logn):
i = np.array(list(range(0, n, 2 ** (d + 1))))
arr[i + 2 ** (d + 1) - 1] = arr[i + 2 ** d - 1] + arr[i + 2 ** (d + 1) - 1]
# Down-sweep阶段
final_item = arr[-1]
arr[-1] = 0
for d in range(logn - 1, -1, -1):
i = np.array(list(range(0, n, 2 ** (d + 1))))
# numpy based parallel
arr[i + 2 ** d - 1], arr[i + 2 ** (d + 1) - 1] = arr[i + 2 ** (d + 1) - 1], arr[i + 2 ** d - 1] + arr[i + 2 ** (d + 1) - 1]
ls = arr.tolist()
inclusive_scan_res = ls[1:] + [final_item]
return inclusive_scan_res
❓小练习
不妨尝试回答一下几个问题:
- 当输入序列的长度并不是2的N次幂,如何用
Brent-Kung
算法进行并行? - 如果系统的processor有限,此时的时间复杂度是多少?
2 并行RNN
通过上文的介绍我们可以用并行的方法计算递归式 b ( t ) = b ( t − 1 ) ⊕ a ( t ) , g i v e n b ( 0 ) , a ( t ) b(t) = b(t-1) \oplus a(t), \mathrm{given} \, b(0), a(t) b(t)=b(t−1)⊕a(t),givenb(0),a(t)。那如何将其与RNN建立起联系呢?
先来回顾一下两个经典的RNN算法,1)LSTM, 2)GRU
2.1 经典RNN回顾
2.1.1 LSTM
LSTM引入记忆细胞C(t)来存储长期信息,解决传统RNN无法处理长程依赖的问题。并引入3个门(遗忘门、输入门、输出门)来控制新老信息的交互。
下面来详细看其计算流程:
给定
- 输入序列: X = [ X ( 0 ) , X ( 1 ) , ⋯ , X ( t ) , ⋯ , X ( T ) ] , X ∈ R T × d , X ( t ) × R d X=[X(0), X(1), \cdots, X(t), \cdots, X(T)], X \in \mathbb{R} ^{T \times d}, X(t) \times \mathbb{R}^{d} X=[X(0),X(1),⋯,X(t),⋯,X(T)],X∈RT×d,X(t)×Rd
- 初始化隐藏状态 H ( 0 ) ∈ R d H(0) \in \mathbb{R}^{d} H(0)∈Rd
- 初始化记忆细胞 C ( 0 ) ∈ R d C(0) \in \mathbb{R}^{d} C(0)∈Rd
Forget Gate: F ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) Input Gate: I ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) Output Gate: O ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) Candidate Cell: C ~ ( t ) = tanh ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) ⇒ U p d a t e Memory Cell: C ( t ) = C ( t − 1 ) ⊙ F ( t ) + I ( t ) ⊙ C ~ ( t ) Hidden State: H ( t ) = tanh ( C ( t ) ) ⊙ O ( t ) (4) \begin{aligned} &\textbf{Forget Gate:} &F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Input Gate:} &I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Output Gate:} &O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Candidate Cell:} &\widetilde{C}(t) &= \mathrm{\tanh}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ \stackrel{\mathrm{Update}}{\Rightarrow } \\ & \textbf{Memory Cell:} &C(t) & =C(t-1) \odot F(t) + I(t) \odot \widetilde{C}(t) \\ & \textbf{Hidden State:} &H(t) & = \tanh(C(t)) \odot O(t) \end{aligned} \tag{4} ⇒UpdateForget Gate:Input Gate:Output Gate:Candidate Cell:Memory Cell:Hidden State:F(t)I(t)O(t)C (t)C(t)H(t)=Sigmoid(Linear(Cat([H(t−1),X(t)])))=Sigmoid(Linear(Cat([H(t−1),X(t)])))=Sigmoid(Linear(Cat([H(t−1),X(t)])))=tanh(Linear(Cat([H(t−1),X(t)])))=C(t−1)⊙F(t)+I(t)⊙C (t)=tanh(C(t))⊙O(t)(4)
三个门的输出在0~1之间,通过点乘来控制信息的流入量。
2.1.2 GRU
GRU简化了LSTM的门控机制达到和LSTM类似的效果。GRU主要通过两个门(重置门、更新门)来控制信息的交互。
下面来详细看其计算流程:
给定
- 输入序列: X = [ X ( 0 ) , X ( 1 ) , ⋯ , X ( t ) , ⋯ , X ( T ) ] , X ∈ R T × d , X ( t ) × R d X=[X(0), X(1), \cdots, X(t), \cdots, X(T)], X \in \mathbb{R} ^{T \times d}, X(t) \times \mathbb{R}^{d} X=[X(0),X(1),⋯,X(t),⋯,X(T)],X∈RT×d,X(t)×Rd
- 初始化隐藏状态 H ( 0 ) ∈ R d H(0) \in \mathbb{R}^{d} H(0)∈Rd
Reset Gate: R ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) Update Gate: Z ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) Candidate Hidden State: H ~ ( t ) = tanh ( L i n e a r ( C a t ( [ H ( t − 1 ) ⊙ R ( t ) , X ( t ) ] ) ) ) ⇒ U p d a t e Hidden State: H ( t ) = H ( t − 1 ) ⊙ ( 1 − Z ( t ) ) + H ~ ( t ) ⊙ Z ( t ) (5) \begin{aligned} &\textbf{Reset Gate:} &R(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Update Gate:} &Z(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Candidate Hidden State:} &\widetilde {H}(t) &= \mathrm{\tanh} \bigl(\mathrm{Linear}(\mathrm{Cat}([H(t-1) \odot R(t), X(t)] ))\bigr) \\ \stackrel{\mathrm{Update}}{\Rightarrow } \\ & \textbf{Hidden State:} &H(t) & = H(t-1) \odot (1 - Z(t)) + \widetilde {H}(t) \odot Z(t) \end{aligned} \tag{5} ⇒UpdateReset Gate:Update Gate:Candidate Hidden State:Hidden State:R(t)Z(t)H (t)H(t)=Sigmoid(Linear(Cat([H(t−1),X(t)])))=Sigmoid(Linear(Cat([H(t−1),X(t)])))=tanh(Linear(Cat([H(t−1)⊙R(t),X(t)])))=H(t−1)⊙(1−Z(t))+H (t)⊙Z(t)(5)
2.2 经典RNN并行化
2.2.1 理论基础
通过前文介绍,我们回顾了经典RNN的递归更新公式,但显然,无法直接沿用parallel scan算法进行并行
递归更新公式 | |
---|---|
LSTM | Memory Cell: C ( t ) = C ( t − 1 ) ⊙ F ( t ) + I ( t ) ⊙ C ~ ( t ) Hidden State: H ( t ) = tanh ( C ( t ) ) ⊙ O ( t ) \begin{aligned}& \textbf{Memory Cell:} &C(t) & =C(t-1) \odot F(t) + I(t) \odot \widetilde{C}(t) \\ & \textbf{Hidden State:} &H(t) & = \tanh(C(t)) \odot O(t) \\ \end{aligned} Memory Cell:Hidden State:C(t)H(t)=C(t−1)⊙F(t)+I(t)⊙C (t)=tanh(C(t))⊙O(t) |
GRU | Hidden State: H ( t ) = H ( t − 1 ) ⊙ ( 1 − Z ( t ) ) + H ~ ( t ) ⊙ Z ( t ) \begin{aligned} \textbf{Hidden State:} \quad \quad H(t) & = H(t-1) \odot (1 - Z(t)) + \widetilde {H}(t) \odot Z(t) \\ \end{aligned} Hidden State:H(t)=H(t−1)⊙(1−Z(t))+H (t)⊙Z(t) |
- 对于LSTM而言 F ( t ) , I ( t ) , C ~ ( t ) , O ( t ) F(t), I(t), \widetilde{C}(t), O(t) F(t),I(t),C (t),O(t)依赖上一个时间步的 H ( t − 1 ) H(t-1) H(t−1)的计算,且其递归式的形式并非 y ( t ) = y ( t − 1 ) ⊕ a ( t ) , a ( t ) y(t)=y(t-1) \oplus a(t), a(t) y(t)=y(t−1)⊕a(t),a(t)已知。
- 对于GRU而言, Z ( t ) , H ~ ( t ) Z(t), \widetilde {H}(t) Z(t),H (t)同样依赖上一个时间步的 H ( t − 1 ) H(t-1) H(t−1)的计算,且其递归式的形式并非 y ( t ) = y ( t − 1 ) ⊕ a ( t ) , a ( t ) y(t)=y(t-1) \oplus a(t), a(t) y(t)=y(t−1)⊕a(t),a(t)已知。
故他们都无法利用parallel scan算法进行并行化。
如何让LSTM,GRU能够使用parallel scan算法进行并行呢?
不考虑对以往时间步的依赖,LSTM,GRU的递归更新公式形如:
y ( t ) = y ( t − 1 ) ⊙ a ( t ) + b ( t ) (6) y(t)=y(t-1) \odot a(t) + b(t) \tag{6} y(t)=y(t−1)⊙a(t)+b(t)(6)
对 ∀ t , a ( t ) , b ( t ) \forall t, a(t), b(t) ∀t,a(t),b(t)已知。这个式子和标准的scan多了一个偏置项 b ( t ) b(t) b(t)。文献[6]指出,只需对式6进行适当变形,即可用两次parallel scan算法对式6进行并行计算。
推导前,不妨将式(6)简写为: y t = y t − 1 a t + b t y_t=y_{t-1} a_t + b_t yt=yt−1at+bt
t=1 y 1 = y 0 a 1 + b 1 t=2 y 2 = y 1 a 2 + b 2 = ( y 0 a 1 + b 1 ) ⊙ a 2 + b 2 = a 1 a 2 [ y 0 + b 1 a 1 + b 2 a 1 a 2 ] t=3 y 3 = y 2 a 3 + b 3 = a 1 a 2 a 3 [ y 0 + b 1 a 1 + b 2 a 1 a 2 ] + b 3 = a 1 a 2 a 3 [ y 0 + b 1 a 1 + b 2 a 1 a 2 + b 3 a 1 a 2 a 3 ] (7) \begin{aligned} &\textbf{t=1} & y_1&=y_0 a_1 + b_1 \\ &\textbf{t=2} &y_2 &= y_1 a_2 + b_2 \\ &&& = (y_0 a_1 + b_1) \odot a_2 + b_2 \\ &&& = a_1 a_2 [y_0 + \frac {b_1}{a_1} + \frac{b_2}{a_1 a_2}] \\ &\textbf{t=3} &y_3 &= y_2 a_3 + b_3 \\ &&& = a_1 a_2 a_3[y_0 + \frac {b_1}{a_1} + \frac{b_2}{a_1 a_2}] + b_3 \\ &&& = a_1 a_2 a_3[y_0 + \frac {b_1}{a_1} + \frac{b_2}{a_1 a_2} + \frac{b_3}{a_1 a_2 a_3}] \\ \end{aligned} \tag{7} t=1t=2t=3y1y2y3=y0a1+b1=y1a2+b2=(y0a1+b1)⊙a2+b2=a1a2[y0+a1b1+a1a2b2]=y2a3+b3=a1a2a3[y0+a1b1+a1a2b2]+b3=a1a2a3[y0+a1b1+a1a2b2+a1a2a3b3](7)
通过归纳,不难得出
y n = y n − 1 a n + b n = ∏ t = 1 n a t ( y 0 + ∑ t = 1 n b t ∏ i = 1 t a i ) = ∏ t = 1 n a t ( y 0 + ∑ t = 1 n exp ( log b t ∏ i = 1 t a i ) ) = ∏ t = 1 n a t ( y 0 + ∑ t = 1 n exp ( log b t − ∑ i = 1 t log a i ) ) (8) \begin{aligned} y_n &= y_{n-1} a_n + b_n \\ &= \prod_{t=1}^{n}a_t \left(y_0 + \sum_{t=1}^{n}\frac{b_t}{\prod_{i=1}^{t}a_i } \right) \\ &= \prod_{t=1}^{n}a_t \left(y_0 + \sum_{t=1}^{n} \exp \left(\log {\frac{b_t}{\prod_{i=1}^{t}a_i }} \right) \right) \\ &= \prod_{t=1}^{n}a_t \left(y_0 + \sum_{t=1}^{n} \exp \left(\log b_t - {\sum_{i=1}^{t}\log a_i }\right)\right) \\ \end{aligned} \tag{8} yn=yn−1an+bn=t=1∏nat(y0+t=1∑n∏i=1taibt)=t=1∏nat(y0+t=1∑nexp(log∏i=1taibt))=t=1∏nat(y0+t=1∑nexp(logbt−i=1∑tlogai))(8)
对上式子两边取对数,有
log y n = ∑ t = 1 n log a t + log ( y 0 + ∑ t = 1 n exp ( log b t − ∑ i = 1 t log a i ) ) (9) \log y_n = \sum_{t=1}^{n} \log a_t + \log \left(y_0 + \sum_{t=1}^{n} \exp \left(\log b_t - {\sum_{i=1}^{t}\log a_i }\right)\right) \tag{9} logyn=t=1∑nlogat+log(y0+t=1∑nexp(logbt−i=1∑tlogai))(9)
从上述递归式可以看出,有两处可以用两次parallel scan算法
第一次parallel scan计算有序集合 { ∑ t = 1 n log a t ∣ n = 1 , 2 , ⋯ , n } \{\sum_{t=1}^{n} \log a_t | n=1, 2, \cdots , n\} {∑t=1nlogat∣n=1,2,⋯,n},
第二次parallel scan计算有序集合 { ∑ t = 1 n exp ( log b t − ∑ i = 1 t log a i ) ∣ n = 1 , 2 , ⋯ , n } \{\sum_{t=1}^{n} \exp \left(\log b_t - {\sum_{i=1}^{t}\log a_i }\right)|n=1, 2, \cdots , n\} {∑t=1nexp(logbt−∑i=1tlogai)∣n=1,2,⋯,n}
有了他们,我们可以并行计算有序集合 { y n ∣ n = 1 , 2 , ⋯ , n } \{ y_n | n=1, 2, \cdots , n \} {yn∣n=1,2,⋯,n}
下面来看,如何将LSTM,GRU转变为式(6)的形式
2.2.2 LSTM的并行化
Step 1: Drop previous hidden state dependencies from gates
F ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) I ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) O ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) C ~ ( t ) = tanh ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) ⇒ F ( t ) = S i g m o i d ( L i n e a r X ( t ) ) ) I ( t ) = S i g m o i d ( L i n e a r ( X ( t ) ) ) O ( t ) = S i g m o i d ( L i n e a r ( X ( t ) ) ) C ~ ( t ) = tanh ( L i n e a r ( X ( t ) ) ) (10) \boxed{\begin{aligned} F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ \widetilde{C}(t) &= \mathrm{\tanh}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ \end{aligned}} \Rightarrow \begin{aligned} F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}X(t))) \\ I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}( X(t))) \\ O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ \widetilde{C}(t) &= \mathrm{\tanh}(\mathrm{Linear}(X(t))) \\ \end{aligned} \tag{10} F(t)I(t)O(t)C (t)=Sigmoid(Linear(Cat([H(t−1),X(t)])))=Sigmoid(Linear(Cat([H(t−1),X(t)])))=Sigmoid(Linear(Cat([H(t−1),X(t)])))=tanh(Linear(Cat([H(t−1),X(t)])))⇒F(t)I(t)O(t)C (t)=Sigmoid(LinearX(t)))=Sigmoid(Linear(X(t)))=Sigmoid(Linear(X(t)))=tanh(Linear(X(t)))(10)
Step 2: Drop range restriction of candidate states
C ~ ( t ) = tanh ( L i n e a r ( X ( t ) ) ) H ( t ) = tanh ( C ( t ) ) ⊙ O ( t ) ⇒ C ~ ( t ) = L i n e a r ( X ( t ) ) H ( t ) = C ( t ) ⊙ O ( t ) (11) \boxed{\begin{aligned} \widetilde{C}(t) &= \mathrm{\tanh}(\mathrm{Linear}(X(t))) \\ H(t) & = \tanh(C(t)) \odot O(t) \end{aligned}} \Rightarrow \begin{aligned} \widetilde{C}(t) &= \mathrm{Linear}(X(t)) \\ H(t) & = C(t) \odot O(t) \end{aligned} \tag{11} C (t)H(t)=tanh(Linear(X(t)))=tanh(C(t))⊙O(t)⇒C (t)H(t)=Linear(X(t))=C(t)⊙O(t)(11)
Step 3: Ensure output is time-independent in scale
F ( t ) = S i g m o i d ( L i n e a r X ( t ) ) ) I ( t ) = S i g m o i d ( L i n e a r ( X ( t ) ) ) O ( t ) = S i g m o i d ( L i n e a r ( X ( t ) ) ) C ~ ( t ) = L i n e a r ( X ( t ) ) C ( t ) = C ( t − 1 ) ⊙ F ( t ) + I ( t ) ⊙ C ~ ( t ) H ( t ) = C ( t ) ⊙ O ( t ) ⇒ H ~ ( t ) = S i g m o i d ( L i n e a r ( X ( t ) ) ) F ′ ( t ) = F ( t ) F ( t ) + I ( t ) I ′ ( t ) = I ( t ) F ( t ) + I ( t ) H ( t ) = F ′ ( t ) ⊙ H ( t − 1 ) + I ′ ( t ) ⊙ H ~ ( t ) (12) \boxed{\begin{aligned} F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}X(t))) \\ I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}( X(t))) \\ O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ \widetilde{C}(t) &= \mathrm{Linear}(X(t)) \\ \hline C(t) & =C(t-1) \odot F(t) + I(t) \odot \widetilde{C}(t) \\ H(t) & = C(t) \odot O(t) \end{aligned}} \Rightarrow \begin{aligned} \widetilde{H}(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ F'(t) & = \frac{F(t)}{F(t) + I(t)} \\ I'(t) & = \frac{I(t)}{F(t) + I(t)} \\ \hline H(t) & = F'(t) \odot H(t-1) + I'(t) \odot \widetilde{H}(t) \\ \end{aligned} \tag{12} F(t)I(t)O(t)C (t)C(t)H(t)=Sigmoid(LinearX(t)))=Sigmoid(Linear(X(t)))=Sigmoid(Linear(X(t)))=Linear(X(t))=C(t−1)⊙F(t)+I(t)⊙C (t)=C(t)⊙O(t)⇒H (t)F′(t)I′(t)H(t)=Sigmoid(Linear(X(t)))=F(t)+I(t)F(t)=F(t)+I(t)I(t)=F′(t)⊙H(t−1)+I′(t)⊙H (t)(12)
通过上述的操作,结合文献[6]的技巧(式9)完成LSTM的并行化。
2.2.3 GRU的并行化
GRU的并行化的操作和LSTM类似
Step 1: Drop previous hidden state dependencies from gates
R ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) Z ( t ) = S i g m o i d ( L i n e a r ( C a t ( [ H ( t − 1 ) , X ( t ) ] ) ) ) H ~ ( t ) = tanh ( L i n e a r ( C a t ( [ H ( t − 1 ) ⊙ R ( t ) , X ( t ) ] ) ) ) ⇒ Z ( t ) = S i g m o i d ( L i n e a r ( X ( t ) ) ) H ~ ( t ) = tanh ( L i n e a r ( X ( t ) ) ) (13) \boxed{\begin{aligned} R(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ Z(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ \widetilde {H}(t) &= \mathrm{\tanh} \bigl(\mathrm{Linear}(\mathrm{Cat}([H(t-1) \odot R(t), X(t)] ))\bigr) \\ \end{aligned}} \Rightarrow \begin{aligned} Z(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ \widetilde {H}(t) &= \mathrm{\tanh} (\mathrm{Linear}( X(t) ))\\ \end{aligned} \tag{13} R(t)Z(t)H (t)=Sigmoid(Linear(Cat([H(t−1),X(t)])))=Sigmoid(Linear(Cat([H(t−1),X(t)])))=tanh(Linear(Cat([H(t−1)⊙R(t),X(t)])))⇒Z(t)H (t)=Sigmoid(Linear(X(t)))=tanh(Linear(X(t)))(13)
Step 2: Drop range restriction of candidate states
H ~ ( t ) = tanh ( L i n e a r ( X ( t ) ) ) ⇒ H ~ ( t ) = L i n e a r ( X ( t ) ) (14) \boxed{\begin{aligned} \widetilde {H}(t) &= \mathrm{\tanh} (\mathrm{Linear}( X(t) )) \end{aligned}} \Rightarrow \widetilde {H}(t)= \mathrm{Linear}( X(t) ) \tag{14} H (t)=tanh(Linear(X(t)))⇒H (t)=Linear(X(t))(14)
3 小结
本文从parallel scan算法出发,介绍了如何将经典RNN算法——LSTM,GRU进行变换,使其能够并行化。实验结果本文不做介绍,请参考原论文。
Reference:
[1] Prefix Sums and Their Applications
[2] A parallel algorithm for the efficient solution of a general class of recurrence equations
[3] A Regular Layout for Parallel Adders
[4] LONG SHORT-TERM MEMORY
[5] Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling
[6] Efficient Parallelization of a Ubiquitous Sequential Computation
[7] https://www.csd.uwo.ca/~mmorenom/HPC-Slides/Parallel_prefix_sum.pdf
[8] https://people.cs.vt.edu/yongcao/teaching/cs5234/spring2013/slides/Lecture10.pdf
[9] https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda