
Parallelizing RNNs: A Walkthrough of "Were RNNs All We Needed?"

Info
Paper: https://arxiv.org/abs/2410.01201
GitHub: https://github.com/lucidrains/minGRU-pytorch
Personal blog: http://myhz0606.com/article/mini_rnn

I have recently been reading papers on parallel RNNs and noticed that many of them rely on the parallel scan algorithm. This post starts from the parallel scan algorithm and then walks through the recent paper from Bengio's team, "Were RNNs All We Needed?".

1 An Introduction to the Parallel Scanning Algorithm

Let's start with the definition. Parallel scanning, as the name suggests, parallelizes the scan operation. So what is a scan?

1.1 Definition of scan

1.1.1 inclusive scan

A scan (inclusive scan), also known as all-prefix-sums, is defined as follows:

Given:

  • an ordered set $A = [a_0, a_1, \dots, a_{n-1}]$,
  • a binary associative operator $\oplus$ whose identity element $\mathcal{I}$ exists,

the output is an ordered set satisfying

$B = [a_0, (a_0 \oplus a_1), \dots, (a_0 \oplus a_1 \oplus \dots \oplus a_{n-1})]$.

Any operation satisfying the rule above is called a scan.

Clearly, the expression above can be written in recursive form, with time complexity $\mathcal{O}(n)$:

$$B[i] = \begin{cases} A[0] & \text{if } i = 0 \\ B[i-1] \oplus A[i] & \text{if } 0 < i < n \end{cases} \tag{1}$$

Note 1: A binary associative operator takes two operands, returns one result, and satisfies the associative law. Common examples include addition ($+$), multiplication ($*$), logical AND ($\&$), and logical OR ($|$).
Note 2: The identity element $\mathcal{I}$ of $\oplus$: if $a \oplus \mathcal{I} = a$, then $\mathcal{I}$ is called the identity of $\oplus$. For example, the identity of addition is 0, the identity of multiplication is 1, and the identity of matrix multiplication is the identity matrix.

1.1.2 exclusive scan

In practice, another variant of scan called prescan (also known as exclusive scan) is frequently used. Its input is the same as scan's, and its output is:

$$C = [\mathcal{I}, a_0, (a_0 \oplus a_1), \dots, (a_0 \oplus a_1 \oplus \dots \oplus a_{n-2})]$$

Its recursive form is

$$C[i] = \begin{cases} \mathcal{I} & \text{if } i = 0 \\ C[i-1] \oplus A[i-1] & \text{if } 0 < i < n \end{cases} \tag{2}$$

Inclusive and exclusive scans convert easily into each other:

inclusive scan → exclusive scan: shift the output sequence right by one position and fill the first element with the identity element.

exclusive scan → inclusive scan: shift the output sequence left by one position and fill the last element with the last input element combined with the last output element.
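The two conversions above can be sketched in a few lines of Python (the helper names are ours; addition with identity 0 is assumed as the operator):

```python
def inclusive_to_exclusive(inc: list, identity=0) -> list:
    # shift right by one position and prepend the identity element
    return [identity] + inc[:-1]

def exclusive_to_inclusive(exc: list, last_input) -> list:
    # shift left by one position; the last slot becomes
    # (last output element) combined with (last input element)
    return exc[1:] + [exc[-1] + last_input]
```

For the prefix-sum example of the next section, `inclusive_to_exclusive([0, 1, 3, 6, 10, 15, 21, 28])` yields `[0, 0, 1, 3, 6, 10, 15, 21]`, and `exclusive_to_inclusive` with last input 7 recovers the original.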

1.1.3 Example: prefix sum

Given the ordered input set $A = [0, 1, 2, 3, 4, 5, 6, 7]$ and addition $+$ as the binary associative operator, compute the inclusive scan and exclusive scan of $A$ under $+$.

By Eq. (1), the inclusive scan is $[0, 1, 3, 6, 10, 15, 21, 28]$.

By Eq. (2), the exclusive scan is $[0, 0, 1, 3, 6, 10, 15, 21]$.

Code implementation:

def inclusive_scan_sequential(ls: list) -> list:
    # sequential O(n) implementation of the recurrence in Eq. (1)
    n = len(ls)
    if n <= 1:
        return ls
    output = [ls[0]] * n
    for i in range(1, n):
        output[i] = ls[i] + output[i - 1]
    return output
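For completeness, a sequential implementation of the exclusive scan of Eq. (2) can be sketched the same way (the function name is ours; addition with identity 0 is assumed):

```python
def exclusive_scan_sequential(ls: list, identity=0) -> list:
    # sequential O(n) implementation of Eq. (2): the output starts with
    # the identity element and excludes the last input element
    n = len(ls)
    if n == 0:
        return ls
    output = [identity] * n
    for i in range(1, n):
        output[i] = output[i - 1] + ls[i - 1]
    return output
```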

1.2 Parallel Scanning

The recursion-based scan described above is a sequential algorithm: its time complexity is $\mathcal{O}(n)$ and it cannot be parallelized. So how can a scan be computed in parallel?

1.2.1 Kogge-Stone Parallel Scanning algorithm [2]

The basic computation flow of the Kogge-Stone parallel scan algorithm is as follows: at each stage, every element adds in the element a fixed stride to its left, and the stride doubles between stages.

There are $\lceil \log_2 N \rceil$ stages in total; in stage $d$ (counting from 0), the updates $a[i + 2^d] = a[i] + a[i + 2^d]$ are all computed in parallel. The total number of additions, $\sum_{d=0}^{\lceil \log_2 N \rceil - 1} (N - 2^d)$, exceeds the $N$ additions of the sequential algorithm, so without parallelism the time complexity is $\mathcal{O}(n \log n)$. With enough processors, however, Kogge-Stone runs in $\mathcal{O}(\log n)$ time.

A Python implementation is as follows:

Note: because Python's native threads are constrained by the GIL and cannot exploit multiple cores, NumPy's vectorized operations are used here to stand in for the parallel step.

import math
import numpy as np

def inclusive_scan_kogge_stone(ls: list) -> list:
    n = len(ls)
    if n <= 1:
        return ls
    stages = math.ceil(math.log2(n))  # ceil handles lengths that are not powers of 2
    ls_arr = np.array(ls)
    for d in range(stages):
        stride = 2 ** d
        # vectorized stand-in for the parallel step a[i + 2^d] += a[i]
        ls_arr[stride:] = ls_arr[stride:] + ls_arr[:n - stride]
    return ls_arr.tolist()


1.2.2 Brent-Kung Parallel Scanning algorithm [3]

As shown above, although the Kogge-Stone algorithm reduces the parallel time complexity of scan from $\mathcal{O}(n)$ to $\mathcal{O}(\log n)$, it performs considerably more work than the sequential algorithm. Let's now look at the more work-efficient Brent-Kung algorithm.

The Brent-Kung algorithm has two phases:

Stage 1: up-sweep phase, computing the reduction

The up-sweep has $\log_2 N$ stages; in each stage $d$ ($d = 0, 1, \dots, \log_2 N - 1$) it performs

$$a[i + 2^{d+1} - 1] = a[i + 2^{d} - 1] + a[i + 2^{d+1} - 1]$$

Algorithm:

    for d from 0 to (log2 N) − 1:
        in parallel for i from 0 to N − 1 by 2^(d+1):
            a[i + 2^(d+1) − 1] ← a[i + 2^d − 1] + a[i + 2^(d+1) − 1]


Let's now analyze the time complexity of the up-sweep.

The number of additions in the up-sweep is

$$\begin{aligned} \sum_{d=0}^{\log_2 N - 1} \frac{N}{2^{d+1}} &= \frac{N}{2} \underbrace{\sum_{d=0}^{\log_2 N - 1} \frac{1}{2^{d}}}_{\text{geometric series}} \\ &= \frac{N}{2} \cdot \frac{1 - \left(\frac{1}{2}\right)^{\log_2 N}}{1 - \frac{1}{2}} \\ &= N \left( 1 - \left( \tfrac{1}{2} \right)^{\log_2 N} \right) = N \left( 1 - \tfrac{1}{N} \right) = N - 1 \end{aligned} \tag{3}$$

Without parallelism the time complexity is $\mathcal{O}(n)$; with parallelism it is $\mathcal{O}(\log n)$.

The Python code is as follows:

For ease of understanding, the inner loop here is not parallelized.

import math

def brent_kung_up_sweep(ls: list) -> list:
    n = len(ls)
    if n <= 1:
        return ls
    stages = math.floor(math.log2(n))
    assert 2 ** stages == n, "input length must be a power of 2"

    for d in range(stages):
        # parallel
        for i in range(0, n, 2 ** (d + 1)):
            ls[i + 2 ** (d + 1) - 1] = ls[i + 2 ** d - 1] + ls[i + 2 ** (d + 1) - 1]
    return ls

The up-sweep gives us the reduction (the total), but not the full scan result; for that we continue with the down-sweep.

Stage 2: down-sweep phase

Algorithm:

    procedure down-sweep(A):
        a[N − 1] ← 0
        for d from (log2 N) − 1 downto 0:
            in parallel for i from 0 to N − 1 by 2^(d+1):
                t ← a[i + 2^d − 1]
                a[i + 2^d − 1] ← a[i + 2^(d+1) − 1]
                a[i + 2^(d+1) − 1] ← t + a[i + 2^(d+1) − 1]

Its computational cost is the same as that of the up-sweep.


The Python code is as follows:

import math

def brent_kung_down_sweep(ls: list) -> list:
    n = len(ls)
    if n <= 0:
        return ls
    stages = math.floor(math.log2(n))
    assert 2 ** stages == n, "input length must be a power of 2"
    ls[-1] = 0
    for d in range(stages - 1, -1, -1):
        # parallel
        for i in range(0, n, 2 ** (d + 1)):
            ls[i + 2 ** d - 1], ls[i + 2 ** (d + 1) - 1] = ls[i + 2 ** (d + 1) - 1], ls[i + 2 ** d - 1] + ls[i + 2 ** (d + 1) - 1]
    return ls

To summarize, the Brent-Kung algorithm consists of an up-sweep and a down-sweep; each phase performs $N - 1$ additions, for $\mathcal{O}(n)$ time without parallelism and $\mathcal{O}(\log n)$ time with parallelism. Combining the two phases (and converting the exclusive result back to an inclusive scan):

import math
import numpy as np

def inclusive_scan_brent_kung(ls):
    n = len(ls)
    if n <= 1:
        return ls
    # ensure the input length is a power of 2
    if n & (n - 1) != 0:
        raise ValueError("Input length must be a power of 2")
    arr = np.array(ls)
    logn = int(math.log2(n))

    # up-sweep phase
    for d in range(logn):
        i = np.arange(0, n, 2 ** (d + 1))
        arr[i + 2 ** (d + 1) - 1] = arr[i + 2 ** d - 1] + arr[i + 2 ** (d + 1) - 1]

    # down-sweep phase
    final_item = arr[-1]  # keep the total for the exclusive -> inclusive conversion
    arr[-1] = 0
    for d in range(logn - 1, -1, -1):
        i = np.arange(0, n, 2 ** (d + 1))
        # NumPy-based stand-in for the parallel step
        arr[i + 2 ** d - 1], arr[i + 2 ** (d + 1) - 1] = (
            arr[i + 2 ** (d + 1) - 1],
            arr[i + 2 ** d - 1] + arr[i + 2 ** (d + 1) - 1],
        )
    # the down-sweep yields the exclusive scan; shift left and append the
    # total to obtain the inclusive scan
    ls = arr.tolist()
    inclusive_scan_res = ls[1:] + [int(final_item)]
    return inclusive_scan_res

❓ Exercises

Try answering the following questions:

  1. When the input sequence length is not a power of two, how can the Brent-Kung algorithm still be parallelized?
  2. If the system has only a limited number of processors, what is the time complexity then?

2 Parallel RNNs

With the machinery above, we can compute recurrences of the form $b(t) = b(t-1) \oplus a(t)$, given $b(0)$ and $a(t)$, in parallel. How does this connect to RNNs?

Let's first review two classic RNN architectures: 1) LSTM and 2) GRU.

2.1 A Review of Classic RNNs

2.1.1 LSTM

LSTM introduces a memory cell $C(t)$ to store long-term information, addressing the inability of vanilla RNNs to capture long-range dependencies, and uses three gates (forget, input, output) to control the interaction between old and new information.


Let's look at its computation flow in detail.

Given:

  • input sequence: $X = [X(0), X(1), \cdots, X(t), \cdots, X(T)]$, $X \in \mathbb{R}^{T \times d}$, $X(t) \in \mathbb{R}^{d}$
  • initial hidden state $H(0) \in \mathbb{R}^{d}$
  • initial memory cell $C(0) \in \mathbb{R}^{d}$

$$\begin{aligned} &\textbf{Forget Gate:} & F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Input Gate:} & I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Output Gate:} & O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Candidate Cell:} & \widetilde{C}(t) &= \tanh(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Memory Cell:} & C(t) &= C(t-1) \odot F(t) + I(t) \odot \widetilde{C}(t) \\ &\textbf{Hidden State:} & H(t) &= \tanh(C(t)) \odot O(t) \end{aligned} \tag{4}$$

The three gates output values between 0 and 1, controlling how much information flows through via elementwise multiplication.
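To make Eq. (4) concrete, here is a minimal single-step LSTM cell in NumPy. This is a toy sketch: the weight matrices are random stand-ins for the Linear layers, and biases are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # toy hidden/input size

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# One illustrative weight matrix per Linear layer; each acts on the
# concatenation Cat([H(t-1), X(t)]) of size 2d. Biases omitted for brevity.
W_f, W_i, W_o, W_c = (rng.normal(size=(2 * d, d)) for _ in range(4))

def lstm_step(h_prev, c_prev, x_t):
    z = np.concatenate([h_prev, x_t])   # Cat([H(t-1), X(t)])
    f = sigmoid(z @ W_f)                # forget gate
    i = sigmoid(z @ W_i)                # input gate
    o = sigmoid(z @ W_o)                # output gate
    c_tilde = np.tanh(z @ W_c)          # candidate cell
    c = c_prev * f + i * c_tilde        # memory cell update
    h = np.tanh(c) * o                  # hidden state
    return h, c
```

Processing a sequence means calling `lstm_step` once per time step, each call consuming the previous step's `h` and `c`; this serial dependency is exactly what the rest of the post works to remove.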

2.1.2 GRU

GRU simplifies LSTM's gating mechanism while achieving similar performance, using two gates (a reset gate and an update gate) to control the information flow.


Let's look at its computation flow in detail.

Given:

  • input sequence: $X = [X(0), X(1), \cdots, X(t), \cdots, X(T)]$, $X \in \mathbb{R}^{T \times d}$, $X(t) \in \mathbb{R}^{d}$
  • initial hidden state $H(0) \in \mathbb{R}^{d}$

$$\begin{aligned} &\textbf{Reset Gate:} & R(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Update Gate:} & Z(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ &\textbf{Candidate Hidden State:} & \widetilde{H}(t) &= \tanh\bigl(\mathrm{Linear}(\mathrm{Cat}([H(t-1) \odot R(t), X(t)]))\bigr) \\ &\textbf{Hidden State:} & H(t) &= H(t-1) \odot (1 - Z(t)) + \widetilde{H}(t) \odot Z(t) \end{aligned} \tag{5}$$
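As with the LSTM above, Eq. (5) can be sketched as a minimal single-step GRU cell in NumPy (a toy sketch: random stand-in weights, biases omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # toy hidden/input size

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative weights; each Linear acts on a concatenation of size 2d.
W_r, W_z, W_h = (rng.normal(size=(2 * d, d)) for _ in range(3))

def gru_step(h_prev, x_t):
    z_in = np.concatenate([h_prev, x_t])
    r = sigmoid(z_in @ W_r)   # reset gate
    z = sigmoid(z_in @ W_z)   # update gate
    # candidate state uses the reset-gated previous hidden state
    h_tilde = np.tanh(np.concatenate([h_prev * r, x_t]) @ W_h)
    return h_prev * (1 - z) + h_tilde * z   # hidden state update
```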

2.2 Parallelizing the Classic RNNs

2.2.1 Theoretical basis

Having reviewed the recurrent update rules of the classic RNNs, it is clear that the parallel scan algorithm cannot be applied to them directly.

The recurrent update rules are:

LSTM:

$$\begin{aligned} \textbf{Memory Cell:} \quad C(t) &= C(t-1) \odot F(t) + I(t) \odot \widetilde{C}(t) \\ \textbf{Hidden State:} \quad H(t) &= \tanh(C(t)) \odot O(t) \end{aligned}$$

GRU:

$$\textbf{Hidden State:} \quad H(t) = H(t-1) \odot (1 - Z(t)) + \widetilde{H}(t) \odot Z(t)$$

  • For LSTM, $F(t), I(t), \widetilde{C}(t), O(t)$ all depend on the previous hidden state $H(t-1)$, so the recurrence is not of the form $y(t) = y(t-1) \oplus a(t)$ with $a(t)$ known.
  • For GRU, $Z(t)$ and $\widetilde{H}(t)$ likewise depend on $H(t-1)$, so the recurrence is not of the form $y(t) = y(t-1) \oplus a(t)$ with $a(t)$ known.

Hence neither model can be parallelized with the parallel scan algorithm as-is.

So how can LSTM and GRU be adapted to use parallel scan?

Setting aside the dependence on previous time steps, the LSTM and GRU update rules have the form

$$y(t) = y(t-1) \odot a(t) + b(t) \tag{6}$$

where $a(t)$ and $b(t)$ are known for all $t$. Compared with a standard scan, this recurrence carries an extra bias term $b(t)$. Reference [6] shows that, after a suitable transformation, Eq. (6) can be computed in parallel with two applications of the parallel scan algorithm.

Before the derivation, abbreviate Eq. (6) as $y_t = y_{t-1} a_t + b_t$ and expand the first few steps:

$$\begin{aligned} \textbf{t=1:}\quad y_1 &= y_0 a_1 + b_1 \\ \textbf{t=2:}\quad y_2 &= y_1 a_2 + b_2 = (y_0 a_1 + b_1) a_2 + b_2 = a_1 a_2 \left[ y_0 + \frac{b_1}{a_1} + \frac{b_2}{a_1 a_2} \right] \\ \textbf{t=3:}\quad y_3 &= y_2 a_3 + b_3 = a_1 a_2 a_3 \left[ y_0 + \frac{b_1}{a_1} + \frac{b_2}{a_1 a_2} \right] + b_3 = a_1 a_2 a_3 \left[ y_0 + \frac{b_1}{a_1} + \frac{b_2}{a_1 a_2} + \frac{b_3}{a_1 a_2 a_3} \right] \end{aligned} \tag{7}$$

By induction, it is straightforward to obtain

$$\begin{aligned} y_n &= y_{n-1} a_n + b_n \\ &= \prod_{t=1}^{n} a_t \left( y_0 + \sum_{t=1}^{n} \frac{b_t}{\prod_{i=1}^{t} a_i} \right) \\ &= \prod_{t=1}^{n} a_t \left( y_0 + \sum_{t=1}^{n} \exp\left( \log \frac{b_t}{\prod_{i=1}^{t} a_i} \right) \right) \\ &= \prod_{t=1}^{n} a_t \left( y_0 + \sum_{t=1}^{n} \exp\left( \log b_t - \sum_{i=1}^{t} \log a_i \right) \right) \end{aligned} \tag{8}$$

Taking the logarithm of both sides,

$$\log y_n = \sum_{t=1}^{n} \log a_t + \log\left( y_0 + \sum_{t=1}^{n} \exp\left( \log b_t - \sum_{i=1}^{t} \log a_i \right) \right) \tag{9}$$

From this expression, two pieces can each be computed with a parallel scan:

the first parallel scan computes the ordered set $\{\sum_{t=1}^{k} \log a_t \mid k = 1, 2, \cdots, n\}$;

the second parallel scan computes the ordered set $\{\sum_{t=1}^{k} \exp(\log b_t - \sum_{i=1}^{t} \log a_i) \mid k = 1, 2, \cdots, n\}$.

With these two scans in hand, the whole ordered set $\{y_k \mid k = 1, 2, \cdots, n\}$ can be computed in parallel.
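As a sanity check of the derivation, here is a sketch of Eq. (9) in NumPy, using `np.cumsum` as a stand-in for the two parallel scans (cumsum itself can be computed with Kogge-Stone or Brent-Kung). The function names are ours, and for simplicity this log-space version assumes $a_t, b_t, y_0 > 0$; reference [6] removes that restriction by working with complex logarithms.

```python
import numpy as np

def parallel_linear_recurrence(y0, a, b):
    """Compute y_t = y_{t-1} * a_t + b_t for all t at once via Eq. (9).

    Assumes a, b > 0 and y0 > 0 so that real logarithms are defined.
    """
    log_a = np.log(a)
    cum_log_a = np.cumsum(log_a)                 # first scan: running sum of log a_t
    # second scan: running sum of exp(log b_t - sum_{i<=t} log a_i)
    inner = np.cumsum(np.exp(np.log(b) - cum_log_a))
    log_y = cum_log_a + np.log(y0 + inner)       # Eq. (9)
    return np.exp(log_y)

def sequential_linear_recurrence(y0, a, b):
    # sequential reference implementation of Eq. (6) for comparison
    y, out = y0, []
    for a_t, b_t in zip(a, b):
        y = y * a_t + b_t
        out.append(y)
    return np.array(out)
```

On random positive inputs the two functions agree to floating-point precision, confirming that Eq. (9) reproduces the recurrence.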

Next, let's see how to cast LSTM and GRU into the form of Eq. (6).

2.2.2 Parallelizing LSTM

Step 1: Drop previous hidden state dependencies from gates

$$\boxed{\begin{aligned} F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ \widetilde{C}(t) &= \tanh(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \end{aligned}} \Rightarrow \begin{aligned} F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ \widetilde{C}(t) &= \tanh(\mathrm{Linear}(X(t))) \end{aligned} \tag{10}$$

Step 2: Drop range restriction of candidate states

$$\boxed{\begin{aligned} \widetilde{C}(t) &= \tanh(\mathrm{Linear}(X(t))) \\ H(t) &= \tanh(C(t)) \odot O(t) \end{aligned}} \Rightarrow \begin{aligned} \widetilde{C}(t) &= \mathrm{Linear}(X(t)) \\ H(t) &= C(t) \odot O(t) \end{aligned} \tag{11}$$

Step 3: Ensure output is time-independent in scale

$$\boxed{\begin{aligned} F(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ I(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ O(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ \widetilde{C}(t) &= \mathrm{Linear}(X(t)) \\ \hline C(t) &= C(t-1) \odot F(t) + I(t) \odot \widetilde{C}(t) \\ H(t) &= C(t) \odot O(t) \end{aligned}} \Rightarrow \begin{aligned} \widetilde{H}(t) &= \mathrm{Linear}(X(t)) \\ F'(t) &= \frac{F(t)}{F(t) + I(t)} \\ I'(t) &= \frac{I(t)}{F(t) + I(t)} \\ \hline H(t) &= F'(t) \odot H(t-1) + I'(t) \odot \widetilde{H}(t) \end{aligned} \tag{12}$$

With the modifications above, combined with the trick from reference [6] (Eq. (9)), the LSTM can be parallelized; the resulting model is the paper's minLSTM.
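Putting Eqs. (10) to (12) together, a toy minLSTM forward pass can be sketched as follows. This is not the paper's implementation: the weights are random stand-ins, and for readability Eq. (8) is evaluated directly with cumulative products instead of the numerically safer log-space form of Eq. (9).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
T, d = 16, 4  # toy sequence length and hidden size
W_f, W_i, W_h = (rng.normal(size=(d, d)) * 0.5 for _ in range(3))

def min_lstm_parallel(X, h0):
    F = sigmoid(X @ W_f)      # forget gate, depends only on X(t)  (Eq. 10)
    I = sigmoid(X @ W_i)      # input gate
    H_tilde = X @ W_h         # candidate state, no tanh            (Eq. 11)
    Fp = F / (F + I)          # normalized gates                    (Eq. 12)
    Ip = I / (F + I)
    # H(t) = Fp(t) * H(t-1) + Ip(t) * H_tilde(t), i.e. Eq. (6) with
    # a(t) = Fp(t) and b(t) = Ip(t) * H_tilde(t).
    a, b = Fp, Ip * H_tilde
    # direct evaluation of Eq. (8); cumprod/cumsum are the two scans
    cum_a = np.cumprod(a, axis=0)
    return cum_a * (h0 + np.cumsum(b / cum_a, axis=0))

def min_lstm_sequential(X, h0):
    # step-by-step reference implementation for comparison
    F, I, H_tilde = sigmoid(X @ W_f), sigmoid(X @ W_i), X @ W_h
    Fp, Ip = F / (F + I), I / (F + I)
    h, out = h0, []
    for t in range(X.shape[0]):
        h = Fp[t] * h + Ip[t] * H_tilde[t]
        out.append(h)
    return np.array(out)
```

The parallel and sequential versions produce the same hidden-state sequence; only the parallel one is expressible with scans.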

2.2.3 Parallelizing GRU

Parallelizing the GRU proceeds similarly to the LSTM.

Step 1: Drop previous hidden state dependencies from gates

$$\boxed{\begin{aligned} R(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ Z(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(\mathrm{Cat}([H(t-1), X(t)]))) \\ \widetilde{H}(t) &= \tanh\bigl(\mathrm{Linear}(\mathrm{Cat}([H(t-1) \odot R(t), X(t)]))\bigr) \end{aligned}} \Rightarrow \begin{aligned} Z(t) &= \mathrm{Sigmoid}(\mathrm{Linear}(X(t))) \\ \widetilde{H}(t) &= \tanh(\mathrm{Linear}(X(t))) \end{aligned} \tag{13}$$

Step 2: Drop range restriction of candidate states

$$\boxed{\widetilde{H}(t) = \tanh(\mathrm{Linear}(X(t)))} \Rightarrow \widetilde{H}(t) = \mathrm{Linear}(X(t)) \tag{14}$$
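Since $Z(t)$ and $1 - Z(t)$ already sum to one, no analogue of LSTM's Step 3 is needed: the update $H(t) = H(t-1) \odot (1 - Z(t)) + \widetilde{H}(t) \odot Z(t)$ is already in the form of Eq. (6) with $a(t) = 1 - Z(t)$ and $b(t) = Z(t) \odot \widetilde{H}(t)$. A toy minGRU sketch follows (random stand-in weights; direct cumulative-product evaluation rather than the log-space form, for readability):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(1)
T, d = 16, 4  # toy sequence length and hidden size
W_z, W_h = (rng.normal(size=(d, d)) * 0.5 for _ in range(2))

def min_gru_parallel(X, h0):
    Z = sigmoid(X @ W_z)   # update gate, depends only on X(t)  (Eq. 13)
    H_tilde = X @ W_h      # candidate state, tanh dropped       (Eq. 14)
    # Eq. (6) with a(t) = 1 - Z(t), b(t) = Z(t) * H_tilde(t)
    a, b = 1.0 - Z, Z * H_tilde
    cum_a = np.cumprod(a, axis=0)                        # first scan
    return cum_a * (h0 + np.cumsum(b / cum_a, axis=0))   # second scan

def min_gru_sequential(X, h0):
    # step-by-step reference implementation for comparison
    Z, H_tilde = sigmoid(X @ W_z), X @ W_h
    h, out = h0, []
    for t in range(X.shape[0]):
        h = (1 - Z[t]) * h + Z[t] * H_tilde[t]
        out.append(h)
    return np.array(out)
```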

3 Summary

This post started from the parallel scan algorithm and showed how to transform the classic RNNs, LSTM and GRU, so that they can be parallelized. Experimental results are not covered here; please refer to the original paper.

References:

[1] Prefix Sums and Their Applications

[2] A parallel algorithm for the efficient solution of a general class of recurrence equations

[3] A Regular Layout for Parallel Adders

[4] Long Short-Term Memory

[5] Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling

[6] Efficient Parallelization of a Ubiquitous Sequential Computation

[7] https://www.csd.uwo.ca/~mmorenom/HPC-Slides/Parallel_prefix_sum.pdf

[8] https://people.cs.vt.edu/yongcao/teaching/cs5234/spring2013/slides/Lecture10.pdf

[9] https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda

