当前位置: 首页 > article >正文

时序预测算法TimeXer代码解析

在时序预测领域,如何有效地利用外部变量(exogenous variables)来提升内部变量(endogenous variables)的预测性能一直是一个挑战。

在上一篇文章中,我结合论文为大家解读了TimeXer框架,今天,我将对TimeXer代码进行解析。

论文精度链接: TimeXer:融合外部变量与内部变量,提升时序预测性能

TimeXer算法框架中包括模块有:Endogenous Embedding,Exogenous Embedding,Endogenous Self-Attention,Exogenous-to-Endogenous Cross-Attention。

1. Endogenous Embedding

Endogenous Embedding 通过Patch级表示精细捕捉内生变量时间变化,针对内生、外生变量嵌入粒度不同产生的信息对齐问题,引入可学习的全局令牌促进外生因果信息向内生时间Patch传递,助力预测。


class EnEmbedding(nn.Module):
	def __init__(self, n_vars, d_model, patch_len, dropout):
	    super(EnEmbedding, self).__init__()
	    # patch_len: 每个补丁的长度
	    # d_model: 嵌入维度
	    self.patch_len = patch_len
	    # 将 patch_len 维投影到 d_model 维,生成时间 token
	    self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
	    # 全局可学习 token,初始化为 [1, n_vars, 1, d_model] 的随机值
	    self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
	    # 位置嵌入,用于补充位置信息
	    self.position_embedding = PositionalEmbedding(d_model)
	    # Dropout 正则化
	    self.dropout = nn.Dropout(dropout)
	def forward(self, x):
	    # 输入维度 [B, V, L]
	    n_vars = x.shape[1]
	    # 全局令牌复制到每个 batch
	    glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))
	    # 将时间序列划分为不重叠补丁 [B, V, N, patch_len]
	    x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
	    # 展平为 [B*V, N, patch_len]
	    x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
	    # 通过线性投影生成时间 token [B*V, N, d_model]
	    x = self.value_embedding(x) + self.position_embedding(x)
	    # 恢复为 [B, V, N, d_model]
	    x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
	    # 拼接全局令牌 [B, V, N+1, d_model]
	    x = torch.cat([x, glb], dim=2)
	    # 展平为 [B*V, N+1, d_model]
	    x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
	    # 应用 dropout
	    return self.dropout(x), n_vars

 完整文章链接: 时序预测算法TimeXer代码解析


http://www.kler.cn/a/427580.html

相关文章:

  • 《深度学习模型的应用与发展:从基础到前沿》
  • 【PID】温控、调速的应用
  • 设计模式c++(二)
  • 深入浅出 Go 语言:理解包管理
  • maven常用知识详解3:聚合与继承
  • 2024年9月GESPC++二级真题解析
  • 基于Matlab卷积神经网络的交通标志识别系统研究与实现
  • AcWing 5843. 染色
  • 怎么获取键值对的键的数值?
  • 数仓技术hive与oracle对比(四)
  • Python有趣小例子:魔法药水制作机
  • SQL注入基础入门篇 注入思路及常见的SQL注入类型总结
  • 在已经有的docker镜像中打包新的组件
  • python selenium 爬虫入门备忘
  • [高阶数据结构七]跳表的深度剖析
  • C# 设计模式--建造者模式 (Builder Pattern)
  • 深度解析 Ansible:核心组件、配置、Playbook 全流程与 YAML 奥秘(上)
  • [C++]构造函数和析构函数
  • 如何“安装Android SDK“?
  • 华为问界M9 [电气架构] 信息梳理