当前位置: 首页 > article >正文

informer辅助笔记:exp/exp_informer.py

0 导入库

from data.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred
from exp.exp_basic import Exp_Basic
from models.model import Informer, InformerStack

from utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metric

import numpy as np

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

import os
import time

import warnings
warnings.filterwarnings('ignore')

1 Exp_Informer

class Exp_Informer(Exp_Basic):
    def __init__(self, args):
        super(Exp_Informer, self).__init__(args)

1.1 build_model

'''
用于构建模型。它根据提供的参数来实例化特定类型的模型
'''
def _build_model(self):
        model_dict = {
            'informer':Informer,
            'informerstack':InformerStack,
        }

        if self.args.model=='informer' or self.args.model=='informerstack':
            e_layers = self.args.e_layers if self.args.model=='informer' else self.args.s_layers

            model = model_dict[self.args.model](
                self.args.enc_in,
                self.args.dec_in, 
                self.args.c_out, 
                self.args.seq_len, 
                self.args.label_len,
                self.args.pred_len, 
                self.args.factor,
                self.args.d_model, 
                self.args.n_heads, 
                e_layers, # self.args.e_layers,
                self.args.d_layers, 
                self.args.d_ff,
                self.args.dropout, 
                self.args.attn,
                self.args.embed,
                self.args.freq,
                self.args.activation,
                self.args.output_attention,
                self.args.distil,
                self.args.mix,
                self.device
            ).float()
            #用提供的参数实例化模型
        
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
            #如果设置为使用多 GPU,那么模型将被包装在 nn.DataParallel 中,以便在多个 GPU 上并行运行。
        return model

1.2 get_data

'''
根据指定的模式(如训练、测试或预测)获取数据
'''
def _get_data(self, flag):
        args = self.args

        data_dict = {
            'ETTh1':Dataset_ETT_hour,
            'ETTh2':Dataset_ETT_hour,
            'ETTm1':Dataset_ETT_minute,
            'ETTm2':Dataset_ETT_minute,
            'WTH':Dataset_Custom,
            'ECL':Dataset_Custom,
            'Solar':Dataset_Custom,
            'custom':Dataset_Custom,
        }
        '''
        定义了一个字典,映射不同的数据集名称到相应的数据集类。

        例如,'ETTh1' 和 'ETTh2' 映射到 Dataset_ETT_hour 类。
        '''

        Data = data_dict[self.args.data]
        #根据参数中指定的数据集名称选择相应的数据集类

        timeenc = 0 if args.embed!='timeF' else 1    
        #设置时间编码标志。如果嵌入类型不是 'timeF',则 timeenc 设置为 0,否则设置为 1。

        if flag == 'test':
            shuffle_flag = False; drop_last = True; batch_size = args.batch_size; freq=args.freq
        elif flag=='pred':
            shuffle_flag = False; drop_last = False; batch_size = 1; freq=args.detail_freq
            Data = Dataset_Pred
        else:
            shuffle_flag = True; drop_last = True; batch_size = args.batch_size; freq=args.freq
        '''
        根据 flag 参数(指示数据集用途,如 'test', 'pred', 或其他)设置不同的参数:
                shuffle_flag:是否打乱数据。
                drop_last:在数据批次不足时是否丢弃最后一批数据。
                batch_size:每批数据的大小。
                freq:数据频率,用于确定数据处理的时间间隔。
        '''

        data_set = Data(
            root_path=args.root_path,
            data_path=args.data_path,
            flag=flag,
            size=[args.seq_len, args.label_len, args.pred_len],
            features=args.features,
            target=args.target,
            inverse=args.inverse,
            timeenc=timeenc,
            freq=freq,
            cols=args.cols
        )
        '''
        使用指定参数实例化数据集。这里包括了
                数据路径
                标志(如 'train', 'test')
                序列长度、标签长度、预测长度
                特征类型 (M,S,MS)
                目标列
                时间编码标志
                频率
                需要使用的列
        '''
        print(flag, len(data_set))

        data_loader = DataLoader(
            data_set,
            batch_size=batch_size,
            shuffle=shuffle_flag,
            num_workers=args.num_workers,
            drop_last=drop_last)
        '''
        使用 DataLoader 创建一个数据加载器,用于批量加载数据
        同时指定是否打乱、是否丢弃最后一个批次、使用的工作进程数量等。
        '''

        return data_set, data_loader
        #返回数据集和数据加载器的实例

1.3 optimizer & criterion

def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim
    
def _select_criterion(self):
        criterion =  nn.MSELoss()
        return criterion

#选择优化器和损失函数

1.4 vali

'''
在验证集上评估模型
'''
def vali(self, vali_data, vali_loader, criterion):
        self.model.eval() #将模型设置为评估模式
        total_loss = []
        for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(vali_loader):
            #遍历验证数据加载器中的每个批次

            pred, true = self._process_one_batch(
                vali_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            #调用 _process_one_batch 方法处理一个批次的数据。这个方法会返回预测值(pred)和真实值(true)

            loss = criterion(pred.detach().cpu(), true.detach().cpu())
            #计算预测值和真实值之间的损失

            total_loss.append(loss)
            #将计算出的损失添加到 total_loss 列表中

        total_loss = np.average(total_loss)
        #计算所有批次损失的平均值。这个平均损失表示在验证数据集上模型的整体性能。

        self.model.train()
        #将模型重新设置为训练模式,继续训练模型

        return total_loss
        #返回计算出的平均损失值

1.5 train

'''
训练模型
'''
def train(self, setting):
        train_data, train_loader = self._get_data(flag = 'train')
        vali_data, vali_loader = self._get_data(flag = 'val')
        test_data, test_loader = self._get_data(flag = 'test')
        #使用 _get_data 方法加载训练、验证和测试数据集。

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)
        #创建用于保存模型检查点的目录

        time_now = time.time()
        
        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        #使用EarlyStopping  检查是否应停止训练

        model_optim = self._select_optimizer()
        criterion =  self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()
        '''
        初始化一些变量:

            train_steps:训练数据加载器中的批次总数。
            early_stopping:如果验证损失在一定迭代次数后没有改善,则停止训练。
            model_optim:选择优化器。
            criterion:选择损失函数。
            如果启用了自动混合精度(AMP),则初始化 scaler。
        '''

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            
            self.model.train()
            epoch_time = time.time()
            for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(train_loader):
                #遍历训练数据加载器中的所有批次
                iter_count += 1
                
                model_optim.zero_grad() #清除模型优化器的梯度

                pred, true = self._process_one_batch(
                    train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
                #使用 _process_one_batch 处理批次数据,计算损失

                loss = criterion(pred, true)
                #计算这一个batch预测值和实际值的差距
                train_loss.append(loss.item())
                
                if (i+1) % 100==0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time()-time_now)/iter_count
                    left_time = speed*((self.args.train_epochs - epoch)*train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()
                #每100次迭代打印损失和预计剩余时间
                
                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()
                #损失后向传播和优化器步骤,如果启用了 AMP,则使用 scaler 进行这些步骤

            print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            #对模型进行validation
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch+1, self.args)
            
        best_model_path = path+'/'+'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))
        #在训练结束后,加载表现最好的模型状态
        
        return self.model

1.6 test

'''
在测试集上评估模型
'''
def test(self, setting):
        test_data, test_loader = self._get_data(flag='test')
        #加载测试数据集
        
        self.model.eval()
        
        preds = []
        trues = []
        #存储模型的预测和相应的真实值
        
        for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(test_loader):
            pred, true = self._process_one_batch(
                test_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            preds.append(pred.detach().cpu().numpy())
            trues.append(true.detach().cpu().numpy())
        '''
        遍历测试数据加载器中的每个批次。
        使用 _process_one_batch 方法处理每个批次的数据。
        将预测值和真实值添加到各自的列表中。
        '''

        preds = np.array(preds)
        trues = np.array(trues)
        print('test shape:', preds.shape, trues.shape)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        print('test shape:', preds.shape, trues.shape)

        # result save
        folder_path = './results/' + setting +'/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        #创建一个文件夹来存储测试结果

        mae, mse, rmse, mape, mspe = metric(preds, trues)
        #使用自定义的 metric 函数计算各种性能指标,如 MAE(平均绝对误差)、MSE(均方误差)、RMSE(均方根误差)、MAPE(平均绝对百分比误差)和 MSPE(均方百分比误差)。
        print('mse:{}, mae:{}'.format(mse, mae))

        np.save(folder_path+'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))
        np.save(folder_path+'pred.npy', preds)
        np.save(folder_path+'true.npy', trues)

        return

1.7 predict

#在新数据上进行模型预测
def predict(self, setting, load=False):
        pred_data, pred_loader = self._get_data(flag='pred')
        #加载预测数据集
        
        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path+'/'+'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))
        #如果 load 为 True,则从保存的路径加载最佳模型的状态。

        self.model.eval()
        
        preds = []
        
        for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(pred_loader):
            pred, true = self._process_one_batch(
                pred_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            preds.append(pred.detach().cpu().numpy())
        '''
        遍历预测数据加载器中的每个批次。
        使用 _process_one_batch 方法处理每个批次的数据。
        将预测值添加到 preds 列表中。
        '''

        preds = np.array(preds)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        
        # result save
        folder_path = './results/' + setting +'/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        
        np.save(folder_path+'real_prediction.npy', preds)
        #保存预测结果
        
        return

1.8 process_one_batch

'''
处理一个数据批次
'''
def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):
        batch_x = batch_x.float().to(self.device)
        batch_y = batch_y.float()

        batch_x_mark = batch_x_mark.float().to(self.device)
        batch_y_mark = batch_y_mark.float().to(self.device)

        # decoder input
        if self.args.padding==0:
            dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
        elif self.args.padding==1:
            dec_inp = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
        #根据 self.args.padding 的值创建一个全零或全一的张量作为解码器的初始输入

        dec_inp = torch.cat([batch_y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device)
        #将这个张量与 batch_y 的一部分拼接,形成完整的解码器输入

        # encoder - decoder
        if self.args.use_amp:
            with torch.cuda.amp.autocast():
                if self.args.output_attention:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        else:
            if self.args.output_attention:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
            else:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.inverse:
            outputs = dataset_object.inverse_transform(outputs)
        #encoder-decoder的输出

        f_dim = -1 if self.args.features=='MS' else 0
        batch_y = batch_y[:,-self.args.pred_len:,f_dim:].to(self.device)
        #从 batch_y 中选择与预测长度相对应的部分,并移动到指定设备。
        #f_dim 变量用于确定特征维度。

        return outputs, batch_y


http://www.kler.cn/a/156745.html

相关文章:

  • 执行flink sql连接clickhouse库
  • 群控系统服务端开发模式-应用开发-前端个人信息功能
  • WebGIS三维地图框架--Cesium
  • ubuntu中apt-get的默认安装路径。安装、卸载以及查看的方法总结
  • 基于混合配准策略的多模态医学图像配准方法研究
  • 【C++】C++11特性(上)
  • 【2021研电赛】基于EAIDK310的视觉导航自动驾驶小车
  • 【C语言】扫雷小游戏初学者版
  • 网络调试助手 连接Onenet 多协议接入平台 TCP透传协议
  • 专业爬虫框架 -- scrapy初识及基本应用
  • C++继承(详解)
  • 聚焦数据库Serverless创新,就在2023亚马逊云科技re:Invent
  • 科技云报道:AI+PaaS,中国云计算市场迎来新“变量”?
  • React Native expo Android adb 调试出现 device not found 怎么办
  • springMVC实验(二)—调式工具APIFOX的使用
  • 物品领用管理软件哪家的好用?怎么让办公用品管理变得更加轻松高效?
  • 智能优化算法应用:基于海洋捕食者算法无线传感器网络(WSN)覆盖优化 - 附代码
  • Python小项目:葛兰中欧医疗基金数据分析
  • 无公网IP环境固定地址远程SSH访问本地树莓派Raspberry Pi
  • 23种设计模式之C++实践(一)
  • Beta冲刺随笔-DAY6-橘色肥猫
  • Android防破解重签名方案研究
  • Lab 3: Recursion, Tree Recursion(CS61A 2020)
  • JAVA代码优化:随机数字生成(UUID)
  • Unity EventSystem的一些理解和使用
  • 论文阅读:Distributed Initialization for VVIRO with Position-Unknown UWB Network