当前位置: 首页 > article >正文

RNN股票预测(Pytorch版)

任务:基于zgpa_train.csv数据,建立RNN模型,预测股价
1.完成数据预处理,将序列数据转化为可用于RNN输入的数据
2.对新数据zgpa_test.csv进行预测,可视化结果
3.存储预测结果,并观察局部预测结果
备注:模型结构:单层RNN,输出有5个神经元,每次使用前8个数据预测第9个数据
参考视频:吹爆!3小时搞懂!【RNN循环神经网络+时间序列LSTM深度学习模型】学不会UP主下跪!
up主用的Keras,自己用Pytorch尝试了一下,代码如下:

import pandas as pd
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
data = pd.read_csv('zgpa_train.csv')
# loc 通过行索引 “Index” 中的具体值来取行数据
# 取出开盘价
price = data.loc[:,'close']

# 归一化
price_norm = price/max(price)
# 开盘价折线图
# fig1 = plt.figure(figsize=(10, 6))
# plt.plot(price)
# plt.title('close price')
# plt.xlabel('time')
# plt.ylabel('price')
# plt.show()

# 提取数据 每次使用前8个数据来预测第九个数据
def extract_data(data, time_step):
    x = []
    y = []
    for i in range(len(data)- time_step):
        x.append([a for a in data[i:i+time_step]])
        y.append(data[i + time_step])
    x = np.array(x)
    x = x.reshape(x.shape[0], x.shape[1], 1)
    x = torch.tensor(x, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)
    return x, y
time_step = 8
x, y = extract_data(price_norm,time_step)
# print(x)
# print(y)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(RNN,self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first = True)
        self.fc = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        out, _ = self.rnn(x)
        # print(out)
        out = self.fc(out[:, -1, :])
        out = out.squeeze(1)
        return out
# 定义模型参数
input_size = 1 # 输入特征的维度
hidden_size = 64 # 隐藏层的维度
output_size = 1 # 输出特征的维度
num_layers = 1 # RNN的层数

# 创建模型
model = RNN(input_size, hidden_size, output_size, num_layers)

# 定义损失函数和优化器
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 200
for epoch in range(epochs):
    optimizer.zero_grad()
    # outputs = model(x.unsqueeze(2))
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
# 进行预测 数据很少这里就不先保存模型再预测了
model.eval()
with torch.no_grad():
    y_train_predict = model(x) * max(price)
y_train = [i * max(price) for i in y]
# print(y_train_predict)
y_train_predict = y_train_predict.cpu().numpy()
y_train = np.array(y_train)
fig2 = plt.figure(figsize=(10, 6))
plt.plot(y_train_predict, label='Predicted', color='blue')
plt.plot(y_train, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()

# 测试集
data_test = pd.read_csv('zgpa_test.csv')
price_test = data_test.loc[:,'close']
price_test_norm = price_test/max(price)
x_test,y_test = extract_data(price_test_norm,time_step)
with torch.no_grad():
    y_test_predict = model(x_test) * max(price)
y_test = [i * max(price) for i in y_test]
# print(y_train_predict)
y_test_predict = y_test_predict.cpu().numpy()
y_test = np.array(y_test)
fig3 = plt.figure(figsize=(10, 6))
plt.plot(y_test_predict, label='Predicted', color='blue')
plt.plot(y_test, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values (Test Set)')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()

# 存储数据
result_y_test = np.array(y_test).reshape(-1, 1) # 若干行,1列
result_y_test_predict = y_test_predict.reshape(-1, 1)
print(result_y_test.shape, result_y_test_predict.shape)
result = np.concatenate((result_y_test, result_y_test_predict), axis=1)
print(result.shape)
result = pd.DataFrame(result, columns=['real_price_test', 'predict_price_test'])
result.to_csv('zgpa_predict_test.csv')

http://www.kler.cn/a/307317.html

相关文章:

  • A029-基于Spring Boot的物流管理系统的设计与实现
  • 项目风险管理的3大要素
  • 批量重命名Excel文件并排序
  • Unity 性能优化方案
  • 解锁微前端的优秀库
  • MySQL与Oracle对比及区别
  • 大模型参数高效微调技术原理综述(八)-MAM Adapter、UniPELT
  • Redhat 8,9系(复刻系列) 一键部署Oracle23ai rpm
  • 模型训练的过程中对学习不好的样本怎么处理更合适
  • Qt4Qt5Qt6版本下载(在线和离线)
  • C++ | Leetcode C++题解之第405题数字转换为十六进制数
  • 文本分类实战项目:如何使用NLP构建情感分析模型
  • Element-ui el-table 全局表格排序
  • 腾讯云软件工程师面试问题收集记录-数据库
  • redis简单使用与安装
  • Java并发:互斥锁,读写锁,Condition,StampedLock
  • shopify主题开发之template模板解析
  • C++学习笔记----7、使用类与对象获得高性能(一)---- 书写类(3)
  • 蓝桥杯-基于STM32G432RBT6的LCD进阶(LCD界面切换以及高亮显示界面)
  • 【AIGC】CFG:基于扩散模型分类器差异引导
  • JavaScript 函数 function
  • 用 nextjs 创建 Node+React Demo
  • WebGL入门(048):OES_draw_buffers_indexed 简介、使用方法、示例代码
  • Python---爬虫
  • Leetcode-轮转数组
  • 复现OpenVLA:开源的视觉-语言-动作模型及原理详解