Paddlets时间序列集成模型回测实战:MLPRegressor、NHiTSModel与RNNBlockRegressor
好的,我们继续深入理解代码的每个部分。以下是每个主要模块的详细解释:
1. 导入模块和库
import json
import os
import glob
import pandas as pd
from tqdm import tqdm
from paddlets.datasets import TSDataset
from paddlets.transform import StandardScaler
from paddlets.models.forecasting import MLPRegressor, NHiTSModel, RNNBlockRegressor
from paddlets.ensemble import WeightingEnsembleForecaster
import ray
- json: 用于处理JSON文件。
- os: 用于处理文件和目录操作。
- glob: 用于获取匹配特定模式的文件路径。
- pandas: 用于数据处理和分析,尤其是表格数据。
- tqdm: 用于显示进度条,帮助跟踪循环的进度。
- paddlets: 时间序列预测相关的库。
- ray: 用于并行计算的库。
2. 定义和创建目录
dirs = [
"forecasting_all_result_center",
"pic_forecasting_center",
"model_forecasting_center_2048_a_b_5_100",
"best_forecasting_param_center"
]
for dir_name in dirs:
os.makedirs(dir_name, exist_ok=True)
- dirs: 定义了多个用于存储不同类型结果的目录。
- os.makedirs: 创建目录,如果目录已存在,则不报错。
3. 加载股票映射
with open("./stock_mapping.json", "r") as f:
stock_mapping = json.load(f)
- 从
stock_mapping.json
文件中加载股票的映射关系,以便后续使用。
4. 加载CSV数据
csv_paths = glob.glob(os.path.join("./tu_share_data_day", "*.csv"))
sum_dam_data = []
for csv_path in tqdm(csv_paths):
new_data = pd.read_csv(csv_path)
if len(new_data) < 2048 or new_data.iloc[0, 2] < 5 or new_data.iloc[0, 2] > 100:
continue
new_data = new_data[::-1].iloc[:2048]
new_data['index_new'] = range(1, len(new_data) + 1)
sum_dam_data.append(new_data)
- 使用
glob
获取所有CSV文件路径,并遍历每个文件。 - 读取数据并进行过滤,确保符合条件(如数据长度、价格区间)。
- 将数据反转并取最后2048条,添加索引列。
5. 构建时间序列数据集
dam_data = pd.concat(sum_dam_data)
dataset = TSDataset.load_from_dataframe(
dam_data,
group_id='ts_code',
time_col="index_new",
target_cols=['high', 'low']
)
- 将所有符合条件的数据合并成一个DataFrame。
- 使用
TSDataset
将数据转换为时间序列格式,指定分组、时间列和目标列。
6. 初始化标准化器
scaler = StandardScaler().fit(dataset)
dataset = scaler.transform(dataset)
- 使用
StandardScaler
对数据进行标准化处理,使模型训练更加稳定。
7. 初始化Ray进行并行计算
ray.init()
- 初始化Ray,使得后续的计算能够并行执行。
8. 定义并行处理函数
@ray.remote
def process_csv_file(csv_path, scaler):
...
- 使用
@ray.remote
装饰器定义一个可以被Ray并行化的函数,处理每个CSV文件的逻辑。
9. 设置模型参数和加载模型
nhits_params = {
'sampling_stride': 24, 'eval_metrics': ["mse", "mae"], 'batch_size': 32, 'max_epochs': 100, 'patience': 10}
rnn_params = nhits_params.copy()
mlp_params = nhits_params.copy()
mlp_params['use_bn'] =