当前位置: 首页 > article >正文

什么是循环神经网络?

一、概念

        循环神经网络(Recurrent Neural Network, RNN)是一类用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN具有循环连接,可以利用序列数据的时间依赖性。正因如此,RNN在自然语言处理、时间序列预测、语音识别等领域有广泛的应用。

二、核心算法

        RNN的核心思想是通过循环连接来记忆和处理序列数据中的上下文信息。RNN在每个时间步(time step)上接收当前输入和前一个时间步的隐藏状态(hidden state),并输出当前时间步的隐藏状态和输出。RNN的基本结构包括输入层、隐藏层和输出层。隐藏层的状态在每个时间步都会更新,并传递到下一个时间步。具体来说,RNN在时间步 t 的计算过程如下:

1、输入

        获取当前时间步的输入x_{t}

2、隐藏状态更新

        根据当前的输入x_{t}和前一个时间步的隐藏状态h_{t-1},计算当前时间步的隐藏状态h_{t}

h_{t} = \sigma(W_{xh}x_{t}+W_{hh}h_{t-1}+b_{h})

        其中,W_{xh}是输入x到隐藏层的权重矩阵,W_{hh}是隐藏层到隐藏层的权重矩阵,b_{h}是隐藏层的偏置向量,\sigma是激活函数。

3、输出

        根据当前时间步的隐藏状态h_{t}计算当前时间步的输出y_{t}

y_{t} = \phi (W_{hy}h_{t}+b_{y})

        其中,W_{hy}是隐藏层到输出层的权重矩阵,b_{y}是输出层的偏置向量,\phi是输出层的激活函数。

三、python实现

        这里,我们使用一个简单的正弦波数据来演示RNN在时间序列预测中的应用。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(0)
np.random.seed(0)

# 生成正弦波数据
timesteps = 100
sin_wave = np.array([np.sin(2 * np.pi * i / timesteps) for i in range(timesteps)])

# 创建数据集
def create_dataset(data, time_step=1):
    dataX, dataY = [], []
    for i in range(len(data) - time_step - 1):
        a = data[i:(i + time_step)]
        dataX.append(a)
        dataY.append(data[i + time_step])
    return np.array(dataX), np.array(dataY)

time_step = 10
X, y = create_dataset(sin_wave, time_step)

# 数据预处理
X = X.reshape(X.shape[0], time_step, 1)
y = y.reshape(-1, 1)

# 转换为Tensor
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# 划分训练集和测试集
train_size = int(len(X) * 0.7)
test_size = len(X) - train_size
trainX, testX = X[:train_size], X[train_size:]
trainY, testY = y[:train_size], y[train_size:]

# 定义RNN模型
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out

input_size = 1
hidden_size = 50
output_size = 1
model = RNNModel(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(trainX)
    loss = criterion(outputs, trainY)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 预测
model.eval()
train_predict = model(trainX)
test_predict = model(testX)
train_predict = train_predict.detach().numpy()
test_predict = test_predict.detach().numpy()

# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(sin_wave, label='Original Data')
plt.plot(np.arange(time_step, time_step + len(train_predict)), train_predict, label='Training Predict')
plt.plot(np.arange(time_step + len(train_predict), time_step + len(train_predict) + len(test_predict)), test_predict, label='Test Predict')
plt.legend()
plt.show()

四、总结

        RNN能够处理任意长度的序列数据,这使得它在自然语言处理、时间序列预测、语音识别等领域非常有用。然而,标准RNN在捕捉长时间依赖关系方面表现不佳。为了解决这个问题,大佬们相继提出了长短期记忆网络(LSTM)和门控循环单元(GRU)等RNN的变体,在增加模型的复杂度的同时提高了RNN对长时间依赖关系的捕捉能力。此外,由于RNN的时间步之间存在依赖关系,因此难以进行并行化处理。


http://www.kler.cn/a/522915.html

相关文章:

  • “AI视频智能分析系统:让每一帧视频都充满智慧
  • 抖音上线打车服务?抖音要大规模杀入网约车了吗?
  • Linux之详谈——权限管理
  • Ceph:关于Ceph 中使用 RADOS 块设备提供块存储的一些笔记整理(12)
  • 开源项目Umami网站统计MySQL8.0版本Docker+Linux安装部署教程
  • MiniHack:为强化学习研究提供丰富而复杂的环境
  • python.tkinter设计标记语言(渲染7-动态呈现标签) - 副本
  • 1.2第1章DC/DC变换器的动态建模-1.2Buck-Boost 变换器的交流模型--电力电子系统建模及控制 (徐德鸿)--读书笔记
  • game101 环节搭建 windows 平台 vs2022
  • doris:STRUCT
  • 【阅读笔记】New Edge Diected Interpolation,NEDI算法,待续
  • 跨域问题解释及前后端解决方案(SpringBoot)
  • 接口技术-第2次作业
  • Gradle配置指南:深入解析settings.gradle.kts(Kotlin DSL版)
  • AAAI2024论文合集解读|Multi-dimensional Fair Federated Learning-water-merged
  • IBMSamllPower服务器监控指标解读
  • 【数据库初阶】表的查询语句和聚合函数
  • leetcode 2920. 收集所有金币可获得的最大积分
  • 10 款《医学数据库和期刊》查阅网站
  • Lesson 119 A true story
  • 蓝桥杯模拟算法:多项式输出
  • 【Prometheus】Prometheus如何监控Haproxy
  • 菜鸟之路Day09一一集合进阶(二)
  • 【公因数匹配——暴力、(质)因数分解、哈希】
  • Github 2025-01-27 开源项目周报 Top15
  • 第 4 章:游戏逻辑与状态管理