当前位置: 首页 > article >正文

Paddlets时间序列集成模型回测实战:MLPRegressor、NHiTSModel与RNNBlockRegressor

好的,我们继续深入理解代码的每个部分。以下是每个主要模块的详细解释:

1. 导入模块和库

import json
import os
import glob
import pandas as pd
from tqdm import tqdm
from paddlets.datasets import TSDataset
from paddlets.transform import StandardScaler
from paddlets.models.forecasting import MLPRegressor, NHiTSModel, RNNBlockRegressor
from paddlets.ensemble import WeightingEnsembleForecaster
import ray
  • json: 用于处理JSON文件。
  • os: 用于处理文件和目录操作。
  • glob: 用于获取匹配特定模式的文件路径。
  • pandas: 用于数据处理和分析,尤其是表格数据。
  • tqdm: 用于显示进度条,帮助跟踪循环的进度。
  • paddlets: 时间序列预测相关的库。
  • ray: 用于并行计算的库。

2. 定义和创建目录

dirs = [
    "forecasting_all_result_center",
    "pic_forecasting_center",
    "model_forecasting_center_2048_a_b_5_100",
    "best_forecasting_param_center"
]

for dir_name in dirs:
    os.makedirs(dir_name, exist_ok=True)
  • dirs: 定义了多个用于存储不同类型结果的目录。
  • os.makedirs: 创建目录,如果目录已存在,则不报错。

3. 加载股票映射

with open("./stock_mapping.json", "r") as f:
    stock_mapping = json.load(f)
  • stock_mapping.json文件中加载股票的映射关系,以便后续使用。

4. 加载CSV数据

csv_paths = glob.glob(os.path.join("./tu_share_data_day", "*.csv"))
sum_dam_data = []

for csv_path in tqdm(csv_paths):
    new_data = pd.read_csv(csv_path)
    if len(new_data) < 2048 or new_data.iloc[0, 2] < 5 or new_data.iloc[0, 2] > 100:
        continue
    new_data = new_data[::-1].iloc[:2048]
    new_data['index_new'] = range(1, len(new_data) + 1)
    sum_dam_data.append(new_data)
  • 使用glob获取所有CSV文件路径,并遍历每个文件。
  • 读取数据并进行过滤,确保符合条件(如数据长度、价格区间)。
  • 将数据反转并取最后2048条,添加索引列。

5. 构建时间序列数据集

dam_data = pd.concat(sum_dam_data)

dataset = TSDataset.load_from_dataframe(
    dam_data,
    group_id='ts_code',
    time_col="index_new",
    target_cols=['high', 'low']
)
  • 将所有符合条件的数据合并成一个DataFrame。
  • 使用TSDataset将数据转换为时间序列格式,指定分组、时间列和目标列。

6. 初始化标准化器

scaler = StandardScaler().fit(dataset)
dataset = scaler.transform(dataset)
  • 使用StandardScaler对数据进行标准化处理,使模型训练更加稳定。

7. 初始化Ray进行并行计算

ray.init()
  • 初始化Ray,使得后续的计算能够并行执行。

8. 定义并行处理函数

@ray.remote
def process_csv_file(csv_path, scaler):
    ...
  • 使用@ray.remote装饰器定义一个可以被Ray并行化的函数,处理每个CSV文件的逻辑。

9. 设置模型参数和加载模型

nhits_params = {
   'sampling_stride': 24, 'eval_metrics': ["mse", "mae"], 'batch_size': 32, 'max_epochs': 100, 'patience': 10}
rnn_params = nhits_params.copy()
mlp_params = nhits_params.copy()
mlp_params['use_bn'] =

http://www.kler.cn/news/327529.html

相关文章:

  • # VirtualBox中安装的CentOS 6.5网络设置为NAT模式时,怎么使用SecureCRT连接CentOS6.5系统?
  • 计算机毕业设计 基于Python的广东旅游数据分析系统的设计与实现 Python+Django+Vue Python爬虫 附源码 讲解 文档
  • WPF中的switch选择
  • Visual Studio-X64汇编编写
  • stm32单片机学习 - MDK仿真调试
  • Redis篇(缓存机制 - 多级缓存)(持续更新迭代)
  • MySQL—表优化
  • 平衡二叉搜索树删除的实现
  • Spring Cloud全解析:服务调用之OpenFeign集成OkHttp
  • 一次阿里云ECS免费试用实践
  • leetcode-链表篇4
  • MATLAB编写的RSSI在三维空间上的定位程序,锚点数量无限制(可自定义),带中文注释
  • 如何获取钉钉webhook
  • docker容器mysql数据备份 mysql容器无法启动备份数据
  • 【docker学习】Linux系统离线方式安装docker环境方法
  • 【Linux系列】CMA (Contiguous Memory Allocator) 简单介绍
  • IP地址与5G时代的万物互联
  • 享元模式
  • 【MATLAB源码-第178期】基于matlab的8PSK调制解调系统频偏估计及补偿算法仿真,对比补偿前后的星座图误码率。
  • 智慧农业案例 (一)- 自动化机械
  • vue2圆形标记(Marker)添加点击事件不弹出信息窗体(InfoWindow)的BUG解决
  • 05-函数传值VS传引用
  • 2.点位管理|前后端如何交互——帝可得后台管理系统
  • 基础漏洞——SSTI(服务器模板注入)
  • leetcode-134. 加油站-贪心策略
  • 数据结构与算法学习(2)
  • 汽车灯光系统详细介绍
  • 【机器学习】---深入探讨图神经网络(GNN)
  • 【STM32】 TCP/IP通信协议(3)--LwIP网络接口
  • 将 Intersection Observer 与自定义 React Hook 结合使用