循环神经网络RNN-数据流动
1. RNN 的结构概述
RNN 的核心是一个循环单元,它在每个时间步接收两个输入:
- 当前时间步的输入数据 x t x_t xt(例如词向量)。
- 上一个时间步的隐藏状态 h t − 1 h_{t-1} ht−1。
然后,RNN 会输出:
- 当前时间步的隐藏状态 h t h_t ht(传递给下一个时间步)。
- 当前时间步的输出 o t o_t ot(通常用于预测任务)。
2. 数学公式
2.1 RNN 的隐藏状态计算
h t = tanh ( W i h x t + b i h + W h h h t − 1 + b h h ) h_t = \text{tanh}(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh}) ht=tanh(Wihxt+bih+Whhht−1+bhh)
2.2 全连接层的输出计算
o t = W h o h t + b h o o_t = W_{ho} h_t + b_{ho} ot=Whoht+bho
其中:
- x t x_t xt:当前时间步的输入(向量)。
- h t − 1 h_{t-1} ht−1:上一个时间步的隐藏状态(向量)。
- h t h_t ht:当前时间步的隐藏状态(向量)。
- o t o_t ot:当前时间步的输出(向量)。
- W i h W_{ih} Wih:输入到隐藏层的权重矩阵。
- W h h W_{hh} Whh:隐藏层到隐藏层的权重矩阵。
- W h o W_{ho} Who:隐藏层到输出层的权重矩阵。
- b i h b_{ih} bih 和 b h h b_{hh} bhh:隐藏层的偏置项。
- b h o b_{ho} bho:输出层的偏置项。
3. 数据的流动
3.1 输入数据 x t x_t xt
- 形状:
[input_dim]
。 - 例如,
input_dim=100
,表示输入是一个 100 维的向量。
3.2 隐藏状态 h t − 1 h_{t-1} ht−1
- 形状:
[hidden_dim]
。 - 例如,
hidden_dim=256
,表示隐藏状态是一个 256 维的向量。
3.3 权重矩阵
-
W
i
h
W_{ih}
Wih:形状为
[hidden_dim, input_dim]
。- 将输入
x
t
x_t
xt 从
input_dim
映射到hidden_dim
。
- 将输入
x
t
x_t
xt 从
-
W
h
h
W_{hh}
Whh:形状为
[hidden_dim, hidden_dim]
。- 将隐藏状态
h
t
−
1
h_{t-1}
ht−1 从
hidden_dim
映射到hidden_dim
。
- 将隐藏状态
h
t
−
1
h_{t-1}
ht−1 从
-
W
h
o
W_{ho}
Who:形状为
[output_dim, hidden_dim]
。- 将隐藏状态
h
t
h_t
ht 从
hidden_dim
映射到output_dim
。
- 将隐藏状态
h
t
h_t
ht 从
3.4 偏置项
-
b
i
h
b_{ih}
bih 和
b
h
h
b_{hh}
bhh:形状为
[hidden_dim]
。 -
b
h
o
b_{ho}
bho:形状为
[output_dim]
。
3.5 计算过程
- 计算
W
i
h
x
t
W_{ih} x_t
Wihxt:
- 输入
x
t
x_t
xt 的形状是
[input_dim]
。 - 权重
W
i
h
W_{ih}
Wih 的形状是
[hidden_dim, input_dim]
。 - 结果是
[hidden_dim]
。
- 输入
x
t
x_t
xt 的形状是
- 计算
W
h
h
h
t
−
1
W_{hh} h_{t-1}
Whhht−1:
- 隐藏状态
h
t
−
1
h_{t-1}
ht−1 的形状是
[hidden_dim]
。 - 权重
W
h
h
W_{hh}
Whh 的形状是
[hidden_dim, hidden_dim]
。 - 结果是
[hidden_dim]
。
- 隐藏状态
h
t
−
1
h_{t-1}
ht−1 的形状是
- 相加:
-
W
i
h
x
t
+
b
i
h
+
W
h
h
h
t
−
1
+
b
h
h
W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh}
Wihxt+bih+Whhht−1+bhh 的结果形状是
[hidden_dim]
。
-
W
i
h
x
t
+
b
i
h
+
W
h
h
h
t
−
1
+
b
h
h
W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh}
Wihxt+bih+Whhht−1+bhh 的结果形状是
- 应用激活函数:
- tanh \text{tanh} tanh 是逐元素操作的,不会改变形状。
- 最终结果
h
t
h_t
ht 的形状是
[hidden_dim]
。
- 计算全连接层的输出
o
t
o_t
ot:
- 输入
h
t
h_t
ht 的形状是
[hidden_dim]
。 - 权重
W
h
o
W_{ho}
Who 的形状是
[output_dim, hidden_dim]
。 - 结果是
[output_dim]
。
- 输入
h
t
h_t
ht 的形状是
4. 具体例子
假设:
input_dim=100
(输入维度为 100)。hidden_dim=256
(隐藏层维度为 256)。output_dim=10
(输出维度为 10,例如 10 分类任务)。- 输入 x t x_t xt 是一个 100 维的向量。
- 隐藏状态 h t − 1 h_{t-1} ht−1 是一个 256 维的向量。
4.1 输入 x t x_t xt
- 形状:
[100]
。 - 例如: x t = [ x 1 , x 2 , … , x 100 ] x_t = [x_1, x_2, \dots, x_{100}] xt=[x1,x2,…,x100]。
4.2 隐藏状态 h t − 1 h_{t-1} ht−1
- 形状:
[256]
。 - 例如: h t − 1 = [ h 1 , h 2 , … , h 256 ] h_{t-1} = [h_1, h_2, \dots, h_{256}] ht−1=[h1,h2,…,h256]。
4.3 权重矩阵
-
W
i
h
W_{ih}
Wih:形状为
[256, 100]
。- 例如:
W i h = [ w 11 w 12 … w 1 , 100 w 21 w 22 … w 2 , 100 ⋮ ⋮ ⋱ ⋮ w 256 , 1 w 256 , 2 … w 256 , 100 ] W_{ih} = \begin{bmatrix} w_{11} & w_{12} & \dots & w_{1,100} \\ w_{21} & w_{22} & \dots & w_{2,100} \\ \vdots & \vdots & \ddots & \vdots \\ w_{256,1} & w_{256,2} & \dots & w_{256,100} \end{bmatrix} Wih= w11w21⋮w256,1w12w22⋮w256,2……⋱…w1,100w2,100⋮w256,100
- 例如:
-
W
h
h
W_{hh}
Whh:形状为
[256, 256]
。- 例如:
W h h = [ w 11 w 12 … w 1 , 256 w 21 w 22 … w 2 , 256 ⋮ ⋮ ⋱ ⋮ w 256 , 1 w 256 , 2 … w 256 , 256 ] W_{hh} = \begin{bmatrix} w_{11} & w_{12} & \dots & w_{1,256} \\ w_{21} & w_{22} & \dots & w_{2,256} \\ \vdots & \vdots & \ddots & \vdots \\ w_{256,1} & w_{256,2} & \dots & w_{256,256} \end{bmatrix} Whh= w11w21⋮w256,1w12w22⋮w256,2……⋱…w1,256w2,256⋮w256,256
- 例如:
-
W
h
o
W_{ho}
Who:形状为
[10, 256]
。- 例如:
W h o = [ w 11 w 12 … w 1 , 256 w 21 w 22 … w 2 , 256 ⋮ ⋮ ⋱ ⋮ w 10 , 1 w 10 , 2 … w 10 , 256 ] W_{ho} = \begin{bmatrix} w_{11} & w_{12} & \dots & w_{1,256} \\ w_{21} & w_{22} & \dots & w_{2,256} \\ \vdots & \vdots & \ddots & \vdots \\ w_{10,1} & w_{10,2} & \dots & w_{10,256} \end{bmatrix} Who= w11w21⋮w10,1w12w22⋮w10,2……⋱…w1,256w2,256⋮w10,256
- 例如:
4.4 偏置项
-
b
i
h
b_{ih}
bih 和
b
h
h
b_{hh}
bhh:形状为
[256]
。- 例如: b i h = [ b 1 , b 2 , … , b 256 ] b_{ih} = [b_1, b_2, \dots, b_{256}] bih=[b1,b2,…,b256]。
-
b
h
o
b_{ho}
bho:形状为
[10]
。- 例如: b h o = [ b 1 , b 2 , … , b 10 ] b_{ho} = [b_1, b_2, \dots, b_{10}] bho=[b1,b2,…,b10]。
4.5 计算过程
-
计算 W i h x t W_{ih} x_t Wihxt:
W i h x t = [ w 11 x 1 + w 12 x 2 + ⋯ + w 1 , 100 x 100 w 21 x 1 + w 22 x 2 + ⋯ + w 2 , 100 x 100 ⋮ w 256 , 1 x 1 + w 256 , 2 x 2 + ⋯ + w 256 , 100 x 100 ] W_{ih} x_t = \begin{bmatrix} w_{11} x_1 + w_{12} x_2 + \dots + w_{1,100} x_{100} \\ w_{21} x_1 + w_{22} x_2 + \dots + w_{2,100} x_{100} \\ \vdots \\ w_{256,1} x_1 + w_{256,2} x_2 + \dots + w_{256,100} x_{100} \end{bmatrix} Wihxt= w11x1+w12x2+⋯+w1,100x100w21x1+w22x2+⋯+w2,100x100⋮w256,1x1+w256,2x2+⋯+w256,100x100
结果是一个 256 维的向量。 -
计算 W h h h t − 1 W_{hh} h_{t-1} Whhht−1:
W h h h t − 1 = [ w 11 h 1 + w 12 h 2 + ⋯ + w 1 , 256 h 256 w 21 h 1 + w 22 h 2 + ⋯ + w 2 , 256 h 256 ⋮ w 256 , 1 h 1 + w 256 , 2 h 2 + ⋯ + w 256 , 256 h 256 ] W_{hh} h_{t-1} = \begin{bmatrix} w_{11} h_1 + w_{12} h_2 + \dots + w_{1,256} h_{256} \\ w_{21} h_1 + w_{22} h_2 + \dots + w_{2,256} h_{256} \\ \vdots \\ w_{256,1} h_1 + w_{256,2} h_2 + \dots + w_{256,256} h_{256} \end{bmatrix} Whhht−1= w11h1+w12h2+⋯+w1,256h256w21h1+w22h2+⋯+w2,256h256⋮w256,1h1+w256,2h2+⋯+w256,256h256
结果是一个 256 维的向量。 -
相加:
W i h x t + b i h + W h h h t − 1 + b h h W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh} Wihxt+bih+Whhht−1+bhh
结果是一个 256 维的向量。 -
应用激活函数:
h t = tanh ( W i h x t + b i h + W h h h t − 1 + b h h ) h_t = \text{tanh}(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh}) ht=tanh(Wihxt+bih+Whhht−1+bhh)
结果是一个 256 维的向量。 -
计算全连接层的输出 o t o_t ot:
o t = W h o h t + b h o o_t = W_{ho} h_t + b_{ho} ot=Whoht+bho
其中:-
W
h
o
W_{ho}
Who 的形状是
[10, 256]
。 -
h
t
h_t
ht 的形状是
[256]
。 - 结果是
[10]
。
例如:
o t = [ w 11 h 1 + w 12 h 2 + ⋯ + w 1 , 256 h 256 w 21 h 1 + w 22 h 2 + ⋯ + w 2 , 256 h 256 ⋮ w 10 , 1 h 1 + w 10 , 2 h 2 + ⋯ + w 10 , 256 h 256 ] o_t = \begin{bmatrix} w_{11} h_1 + w_{12} h_2 + \dots + w_{1,256} h_{256} \\ w_{21} h_1 + w_{22} h_2 + \dots + w_{2,256} h_{256} \\ \vdots \\ w_{10,1} h_1 + w_{10,2} h_2 + \dots + w_{10,256} h_{256} \end{bmatrix} ot= w11h1+w12h2+⋯+w1,256h256w21h1+w22h2+⋯+w2,256h256⋮w10,1h1+w10,2h2+⋯+w10,256h256
结果是一个 10 维的向量。 -
W
h
o
W_{ho}
Who 的形状是
5. 总结
- 数据:
- 输入
x
t
x_t
xt:形状为
[input_dim]
。 - 隐藏状态
h
t
−
1
h_{t-1}
ht−1:形状为
[hidden_dim]
。 - 输出
o
t
o_t
ot:形状为
[output_dim]
。
- 输入
x
t
x_t
xt:形状为
- 权重:
-
W
i
h
W_{ih}
Wih:形状为
[hidden_dim, input_dim]
。 -
W
h
h
W_{hh}
Whh:形状为
[hidden_dim, hidden_dim]
。 -
W
h
o
W_{ho}
Who:形状为
[output_dim, hidden_dim]
。
-
W
i
h
W_{ih}
Wih:形状为
- 偏置:
-
b
i
h
b_{ih}
bih 和
b
h
h
b_{hh}
bhh:形状为
[hidden_dim]
。 -
b
h
o
b_{ho}
bho:形状为
[output_dim]
。
-
b
i
h
b_{ih}
bih 和
b
h
h
b_{hh}
bhh:形状为
通过以上步骤,RNN 在每个时间步将输入 x t x_t xt 和隐藏状态 h t − 1 h_{t-1} ht−1 转换为新的隐藏状态 h t h_t ht,并通过全连接层得到输出 o t o_t ot。