当前位置: 首页 > article >正文

使用LSTM模型进行时间序列数据预测的示例

代码功能

这段代码展示了如何使用 LSTM(长短期记忆网络)模型对复杂的时间序列数据进行训练和预测。代码的主要功能分为以下几个步骤:
生成复杂的时间序列数据:通过将线性趋势、周期性正弦波和随机噪声相结合,生成模拟的时间序列数据。
数据预处理:使用 MinMaxScaler 将数据归一化,转换为适合 LSTM 模型的格式。
数据集准备:将时间序列数据转换为特定的输入输出格式,使用过去的 10 个时间步作为输入,预测下一个时间步的数据。
构建和训练 LSTM 模型:通过 Keras 构建一个两层 LSTM 网络,并使用均方误差损失函数和 Adam 优化器进行模型训练。
模型预测:使用训练好的模型对输入数据进行预测,并将预测值反归一化为原始范围。
可视化:绘制时间序列数据和模型预测结果的对比图,展示模型的预测效果。
最终,该模型可以用于对复杂时间序列数据进行预测,并可视化预测结果与真实数据的对比。
在这里插入图片描述
在这里插入图片描述

代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.layers import Input

# 1. 生成更加复杂的时间序列数据
np.random.seed(42)
time_steps = np.arange(100)
trend = 0.05 * time_steps  # 线性趋势
seasonal = 10 * np.sin(0.2 * time_steps)  # 周期性
noise = np.random.normal(scale=2, size=100)  # 噪声
data = trend + seasonal + noise

# 2. 可视化生成的时间序列数据
plt.figure(figsize=(10, 6))
plt.plot(time_steps, data, label="Complex Time Series Data")
plt.xlabel("Time Steps")
plt.ylabel("Value")
plt.title("Complex Time Series Data")
plt.show()


# 3. 数据预处理
scaler = MinMaxScaler(feature_range=(0, 1))
data_scaled = scaler.fit_transform(data.reshape(-1, 1))

# 4. 准备数据:将时间序列数据转换为 LSTM 可用的形状
def create_dataset(data, time_step=5):
    X, y = [], []
    for i in range(len(data) - time_step):
        X.append(data[i:(i + time_step), 0])
        y.append(data[i + time_step, 0])
    return np.array(X), np.array(y)

time_step = 10  # 使用过去10个时间步来预测下一个时间步
X, y = create_dataset(data_scaled, time_step)

# 重塑 X 使其符合 LSTM 输入格式: [样本数, 时间步长, 特征数]
X = X.reshape(X.shape[0], X.shape[1], 1)

# 5. 构建 LSTM 模型
model = Sequential()
model.add(Input(shape=(time_step, 1)))  # 使用 Input 层来定义输入形状
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(50))
model.add(Dense(1))

model.compile(loss='mean_squared_error', optimizer='adam')

# 6. 训练模型
model.fit(X, y, epochs=100, batch_size=16, verbose=1)

# 7. 用模型进行预测
train_predict = model.predict(X)

# 将预测值反归一化
train_predict = scaler.inverse_transform(train_predict.reshape(-1, 1))

# 8. 可视化真实数据和预测数据
plt.figure(figsize=(10, 6))
plt.plot(time_steps[time_step:], data[time_step:], label="True Data")
plt.plot(time_steps[time_step:], train_predict, label="Predicted Data", color="red", linestyle="--")
plt.xlabel("Time Steps")
plt.ylabel("Value")
plt.title("LSTM Time Series Prediction")
plt.legend()
plt.show()



http://www.kler.cn/news/322469.html

相关文章:

  • 代码随想录算法训练营Day10
  • 611. 有效三角形的个数
  • 【d52】【Java】【力扣】19.删除链表的倒数第N个节点
  • Python | Leetcode Python题解之第432题全O(1)的数据结构
  • windows端后端运行python程序,类似nohup
  • 大数据Flink(一百二十四):案例实践——淘宝母婴数据加速查询
  • 优青博导团队携手提供组学技术服务、表观组分析、互作组分析、遗传转化实验、单细胞检测等全方位生物医学支持
  • 微服务--ES(Elasticsearch)
  • 如何在谷歌浏览器上玩大型多人在线游戏
  • 【软考】结构化分析方法概述
  • 车载视频监控:安全生产与管理的新趋势
  • 笔记整理—linux进程部分(1)进程终止函数注册、进程环境、进程虚拟地址
  • 基于顺序表的通讯录(纯代码)
  • 「漏洞复现」誉龙视音频综合管理平台 RelMedia/FindById SQL注入漏洞
  • 【大模型-驯化】成功解决载cuda-11.8配置下搭建swift框架
  • VSCode rust文件中的api点击无法跳转问题
  • Request 原理
  • 的使用和内联函数
  • 【Spring Cloud】Spring Cloud 概述
  • 【计算机网络 - 基础问题】每日 3 题(十七)
  • 《JKTECH柔性振动盘:原理与多行业应用》东莞市江坤自动化科技有限公司
  • TOF系列—深度图滤波
  • 手搓一个Agent#Datawhale 组队学习Task3
  • Android常用C++特性之std::move
  • 【机器学习(九)】分类和回归任务-多层感知机(Multilayer Perceptron,MLP)算法-Sentosa_DSML社区版
  • 华为HarmonyOS灵活高效的消息推送服务(Push Kit) -- 6 撤回通知消息
  • tomcat 文件上传 (CVE-2017-12615)
  • 计算机知识科普问答--21(101-105)
  • 【FE】NPM——概述
  • 13年408计算机考研-计算机网络