当前位置: 首页 > article >正文

什么是门控循环单元?

一、概念

        门控循环单元(Gated Recurrent Unit,GRU)是一种改进的循环神经网络(RNN),由Cho等人在2014年提出。GRU是LSTM的简化版本,通过减少门的数量和简化结构,保留了LSTM的长时间依赖捕捉能力,同时提高了计算效率。GRU通过引入两个门(重置门和更新门)来控制信息的流动。与LSTM不同,GRU没有单独的细胞状态,而是将隐藏状态直接作为信息传递的载体,因此结构更简单,计算效率更高。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,r_{t}为当前时间步的重置门向量,z_{t}为当前时间步的更新门向量,\bar{h_{t}}为当前时间步的候选隐藏状态向量,W_{r},W_{z},W_{h}分别为各门的权重矩阵,b_{r},b_{z},b_{h}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。

1、重置门

        重置门控制前一个时间步的隐藏状态对当前时间步的影响。通过sigmoid激活函数,重置门的输出在0到1之间,表示前一个隐藏状态元素被保留的比例。

r_{t} = \sigma(W_{r} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{r})

2、更新门

        更新门控制前一个时间步的隐藏状态和当前时间步的候选隐藏状态的混合比例。通过sigmoid激活函数,更新门的输出在0到1之间,表示前一个隐藏状态元素被保留的比例。

z_{t} = \sigma(W_{z} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{z})

3、候选隐藏状态

        候选隐藏状态结合当前输入和前一个时间步的隐藏状态生成。重置门的输出与前一个隐藏状态相乘,表示保留的旧信息。然后与当前输入一起通过tanh激活函数生成候选隐藏状态。

\bar{h_{t}} = tanh(W_{h} \cdot \left [ r_{t} \ast h_{t-1}, x_{t} \right ] + b_{h})

4、隐藏状态更新

        隐藏状态结合更新门的结果进行更新。更新门的输出与前一个隐藏状态相乘,表示保留的旧信息。更新门的补数与候选隐藏状态相乘,表示写入的新信息。两者相加得到当前时间步的隐藏状态。

h_{t} = (1-z_{t}) \ast h_{t-1} + z_{t} \ast \bar{h_{t}}

三、python实现

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
 
# 设置随机种子
torch.manual_seed(0)
np.random.seed(0)
 
# 生成正弦波数据
timesteps = 1000
sin_wave = np.array([np.sin(2 * np.pi * i / timesteps) for i in range(timesteps)])
 
# 创建数据集
def create_dataset(data, time_step=1):
    dataX, dataY = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        dataX.append(a)
        dataY.append(data[i + time_step])
    return np.array(dataX), np.array(dataY)
 
time_step = 10
X, y = create_dataset(sin_wave, time_step)
 
# 数据预处理
X = X.reshape(X.shape[0], time_step, 1)
y = y.reshape(-1, 1)
 
# 转换为Tensor
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)
 
# 划分训练集和测试集
train_size = int(len(X) * 0.7)
test_size = len(X) - train_size
trainX, testX = X[:train_size], X[train_size:]
trainY, testY = y[:train_size], y[train_size:]
 
# 定义RNN模型
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
 
    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, _ = self.gru(x, h0)
        out = self.fc(out[:, -1, :])
        return out
 
input_size = 1
hidden_size = 50
output_size = 1
model = GRUModel(input_size, hidden_size, output_size)
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(trainX)
    loss = criterion(outputs, trainY)
    loss.backward()
    optimizer.step()
 
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
 
# 预测
model.eval()
train_predict = model(trainX)
test_predict = model(testX)
train_predict = train_predict.detach().numpy()
test_predict = test_predict.detach().numpy()
 
# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(sin_wave, label='Original Data')
plt.plot(np.arange(time_step, time_step + len(train_predict)), train_predict, label='Training Predict')
plt.plot(np.arange(time_step + len(train_predict), time_step + len(train_predict) + len(test_predict)), test_predict, label='Test Predict')
plt.legend()
plt.show()

四、总结

        GRU的结构比LSTM更简单,只有两个门(重置门和更新门),没有单独的细胞状态。这使得GRU的计算复杂度较低,训练和推理速度更快。通过引入重置门和更新门,GRU也有效地解决了标准RNN在处理长序列时的梯度消失和梯度爆炸问题。然而,在需要更精细的门控制和信息流动的任务中,LSTM的性能可能优于GRU。因此在我们实际的建模过程中,可以根据数据特点选择合适的RNN系列模型,并没有哪个模型能在所有任务中都具有优势。


http://www.kler.cn/a/528641.html

相关文章:

  • 单细胞-第四节 多样本数据分析,下游画图
  • 【算法设计与分析】实验8:分支限界—TSP问题
  • git基础使用--1--版本控制的基本概念
  • AI学习指南HuggingFace篇-Hugging Face 的环境搭建
  • React 的 12 个核心概念
  • JxBrowser 7.41.7 版本发布啦!
  • 爬取鲜花网站数据
  • 使用 Docker(Podman) 部署 MongoDB 数据库及使用详解
  • 白话DeepSeek-R1论文(三)| DeepSeek-R1蒸馏技术:让小模型“继承”大模型的推理超能力
  • 为AI聊天工具添加一个知识系统 之82 详细设计之23 符号逻辑 正则表达式规则 之1
  • 如何实现滑动列表功能
  • 智慧园区综合管理系统如何实现多个维度的高效管理与安全风险控制
  • c++ list的front和pop_front的概念和使用案例
  • 【3】阿里面试题整理
  • http 请求类型及其使用场景
  • python学习——函数的返回值
  • 【python】tkinter实现音乐播放器(源码+音频文件)【独一无二】
  • error: RPC failed; curl 56 OpenSSL SSL_read: SSL_ERROR_SYSCALL, errno 10054
  • C#面向对象(封装)
  • C语言:整型提升
  • 前端知识速记:节流与防抖
  • Vue2.x简介
  • MongoDB快速上手(包会用)
  • 浅析DDOS攻击及防御策略
  • Linux系统部署Python项目持续运行配置
  • 数据结构【单链表操作大全详解】【c语言版】(只有输入输出为了方便用的c++)