当前位置: 首页 > article >正文

用LSTM模型预测股价的例子(1)

文章目录

  • 基本说明
  • 直接上代码
  • 预测结果

基本说明

本实例用的是单个参数“收盘价”,学习后。用10天的收盘价预测后面1天的收盘价。
数据如下图:
在这里插入图片描述
后续我们还要采用这个数据,进一步添加其他的影响因子进行预测。

直接上代码

代码中包含注释,我就不多说了

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# 1. 加载数据
def load_stock_data(file_path):
    data = pd.read_csv(file_path)
    return data['close'].values


# 2. 划分训练集和测试集
def prepare_data(data, train_ratio=0.8):
    train_size = int(len(data) * train_ratio)
    train_data = data[:train_size]
    test_data = data[train_size:]
    return train_data, test_data


# 3. 创建数据集
def create_dataset(data, time_steps):
    Xs, ys = [], []
    for i in range(len(data) - time_steps):
        v = data[i:(i + time_steps)]
        Xs.append(v)
        ys.append(data[i + time_steps])
    return np.array(Xs), np.array(ys)


# 4. 定义 LSTM 模型
def build_model(input_shape):
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(units=50, return_sequences=True, input_shape=input_shape),
        tf.keras.layers.LSTM(units=50),
        tf.keras.layers.Dense(units=1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model


# 5. 训练模型
def train_model(model, X_train, y_train, epochs=20):
    history = model.fit(
        X_train, y_train,
        epochs=epochs,
        batch_size=32,
        validation_split=0.1,
        shuffle=False
    )
    return history


# 6. 预测和可视化结果
def predict_and_visualize(model, X_train, y_train, X_test, y_test):
    train_predict = model.predict(X_train)
    test_predict = model.predict(X_test)

    plt.figure(figsize=(10, 6))
    plt.plot(y_train, label='True Train')
    plt.plot(train_predict, label='Predicted Train')
    plt.plot(range(len(y_train), len(y_train) + len(y_test)), y_test, label='True Test')
    plt.plot(range(len(y_train), len(y_train) + len(y_test)), test_predict, label='Predicted Test')
    plt.legend(loc='upper left')
    plt.show()


if __name__ == "__main__":
    file_path ='d:/test.csv'
    data = load_stock_data(file_path)
    train_data, test_data = prepare_data(data)
    time_steps = 10
    X_train, y_train = create_dataset(train_data, time_steps)
    X_test, y_test = create_dataset(test_data, time_steps)

    # 数据归一化
    from sklearn.preprocessing import MinMaxScaler
    scaler = MinMaxScaler(feature_range=(0, 1))
    X_train = scaler.fit_transform(X_train.reshape(-1, 1)).reshape(X_train.shape)
    y_train = scaler.transform(y_train.reshape(-1, 1)).reshape(y_train.shape)
    X_test = scaler.transform(X_test.reshape(-1, 1)).reshape(X_test.shape)
    y_test = scaler.transform(y_test.reshape(-1, 1)).reshape(y_test.shape)

    #print(X_train.shape)
    input_shape = (X_train.shape[1], 1)
    model = build_model(input_shape)
    history = train_model(model, X_train, y_train, epochs=20)
    predict_and_visualize(model, X_train, y_train, X_test, y_test)

预测结果

在这里插入图片描述
红色的是预测的,绿色的是实际的。


http://www.kler.cn/a/510543.html

相关文章:

  • 第5章:Python TDD定义Dollar对象相等性
  • HighCharts 交互式图表-01-入门介绍
  • 数据可视化大屏设计与实现
  • 【深度学习】Huber Loss详解
  • 21天学通C++——11多态(引入多态的目的)
  • pytest-instafail:让测试失败信息即时反馈
  • Chromium 132 编译指南 Linux 篇 - 安装 Chromium 官方工具(三)
  • 河北省乡镇界面图层shp格式arcgis数据乡镇名称和编码2020年wgs84坐标无偏移内容测评
  • 山西省乡镇界面图层shp格式arcgis数据乡镇名称和编码2020年wgs84坐标无偏移测评
  • HRNet,Deep High-Resolution Representation Learning for Visual Recognition解读
  • 缓存、数据库双写一致性解决方案
  • 计算机毕业设计PySpark+Hadoop+Hive机票预测 飞机票航班数据分析可视化大屏 航班预测系统 机票爬虫 飞机票推荐系统 大数据毕业设计
  • Object常用的方法及开发中的使用场景
  • T-SQL语言的数据库交互
  • MYSQL数据库基础-01.数据库的基本操作
  • Windows图形界面(GUI)-QT-C/C++ - Qt控件与布局系统详解
  • 汇旺财支付PHP代码
  • 服务化架构 IM 系统之应用 MQ
  • 数据库服务体系结构
  • 基于机器学习的用户健康风险分类及预测分析
  • 数据结构 (C语言) 链表
  • C#里await Task.Run死锁的分析与解决
  • 【错误解决方案记录】spine3.8.75导出的数据使用unity-spine3.8插件解析失败报错的解决方案
  • 知识库管理系统的用户体验之道:便捷、高效、智能
  • PyTorch 基础数据集:从理论到实践的深度学习基石
  • 洛谷P1807 最长路(拓扑排序)