当前位置: 首页 > article >正文

一个超级简单的清晰的LSTM模型的例子

废话不多说,把代码贴上去,就可以运行。然后看注释,自己慢慢品,细细品。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# 1. 生成时间序列数据,这里使用正弦函数模拟
def generate_time_series():
    time_steps = np.linspace(0, 10 * np.pi, 500, dtype=np.float32)
    data = np.sin(time_steps)
    data = np.expand_dims(data, axis=-1)
    return data


# 2. 划分训练集和测试集
def prepare_data(data):
    train_data = data[:400]
    test_data = data[400:]
    return train_data, test_data


# 3. 创建数据集
def create_dataset(data, time_steps):
    Xs, ys = [], []
    for i in range(len(data) - time_steps):
        v = data[i:(i + time_steps)]
        Xs.append(v)
        ys.append(data[i + time_steps])
    return np.array(Xs), np.array(ys)


# 4. 定义 LSTM 模型
def build_model(input_shape):
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(units=50, return_sequences=True, input_shape=input_shape),
        tf.keras.layers.LSTM(units=50),
        tf.keras.layers.Dense(units=1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model


# 5. 训练模型
def train_model(model, X_train, y_train, epochs=20):
    history = model.fit(
        X_train, y_train,
        epochs=epochs,
        batch_size=32,
        validation_split=0.1,
        shuffle=False
    )
    return history


# 6. 预测和可视化结果
def predict_and_visualize(model, X_train, y_train, X_test, y_test):
    train_predict = model.predict(X_train)
    test_predict = model.predict(X_test)

    plt.figure(figsize=(10, 6))
    plt.plot(y_train, label='True Train')
    plt.plot(train_predict, label='Predicted Train')
    plt.plot(range(len(y_train), len(y_train) + len(y_test)), y_test, label='True Test')
    plt.plot(range(len(y_train), len(y_train) + len(y_test)), test_predict, label='Predicted Test')
    plt.legend(loc='upper left')
    plt.show()


def plot_loss(history):
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc='upper right')
    plt.show()


if __name__ == "__main__":
    # 生成数据
    data = generate_time_series()
    train_data, test_data = prepare_data(data)
    time_steps = 10
    X_train, y_train = create_dataset(train_data, time_steps)
    X_test, y_test = create_dataset(test_data, time_steps)

    # 构建模型
    input_shape = (X_train.shape[1], X_train.shape[2])
    model = build_model(input_shape)

    # 训练模型
    history = train_model(model, X_train, y_train, epochs=20)

    # 显示训练的loss,val_loss
    plot_loss(history)

    # 预测和可视化
    predict_and_visualize(model, X_train, y_train, X_test, y_test)

http://www.kler.cn/a/506411.html

相关文章:

  • 【Flink系列】4. Flink运行时架构
  • “扣子”开发之四:与千帆AppBuilder比较
  • qml LevelAdjust详解
  • ClickHouse-CPU、内存参数设置
  • OpenCV实现Kuwahara滤波
  • Jmeter进行http接口并发测试
  • (双系统)Ubuntu+Windows解决grub引导问题和启动黑屏问题
  • 记录一次RPC服务有损上线的分析过程
  • 2025年01月14日Github流行趋势
  • Elasticsearch容器启动报错:AccessDeniedException[/usr/share/elasticsearch/data/nodes];
  • 栈算法篇——LIFO后进先出,数据与思想的层叠乐章(下)
  • MATLAB自带函数,使用遗传算法,求函数最小值,附代码
  • 用python进行大恒相机的调试
  • SpringSecurity-前后端分离
  • 码编译安装httpd 2.4,测试
  • CryptoMamba:利用状态空间模型实现精确的比特币价格预测
  • 基于多个边缘盒子部署的综合视频安防系统的智慧地产开源了
  • Python如何在指定行追加内容
  • IDEA测试报错java.lang.NullPointerException空指针异常解决办法
  • Jetbrains 官方微信小程序插件已上线!
  • 数据存取:存取方式、操作、技术、挑战、相关学术分享
  • Docker 的安装和基本使用[SpringBoot之Docker实战系列] - 第535篇
  • vue中使用OpenLayer加载Geoserver的WMS
  • javascript基础从小白到高手系列一十二:JSON
  • 麦田物语学习笔记:构建游戏的时间系统
  • 常见链表专题相关算法