当前位置: 首页 > article >正文

GMAN解读(论文+代码)

一、注意力机制

        注意力机制与传统的卷积神经网络不同的是,前者擅长捕获全局依赖和长程关系,权重会动态调整。而后者对于所有特征都使用同一个卷积核。关于更多注意力机制内容,详见:

注意力机制、自注意力机制、多头注意力机制、通道注意力机制、空间注意力机制超详细讲解-CSDN博客

        而本论文基于空间注意力机制,提出了空间和时间注意力机制,以建模动态的时空相关性。 

二、GMAN 模型

 

1. 时空嵌入(Sptio-Tempoal Embedding,STE)

        STE 的作用是将图结构和时间信息融入多注意力机制中。

        空间嵌入采用Node2Vec方法将N个顶点编码为保留图结构信息的向量,然后将其输入两层全连接神经网络中以进行联合训练。从而可以得到一个 RD 的向量。

        时间嵌入而是将每个时间步编码为向量。首先将将一天划分为 T 个时间步,并对每个时间步的“周几”和“时间段”进行独热编码。然后通过两层全连接网络将时间特征也转换为一个 RD 的向量。

        最后,对于顶点 vi 和时间步 tj ,STE定义为

2.  时空注意力模块(ST-Attention Block)

2.1 空间注意力(Spatial Attention)

        在传统图卷积GCN中,隐藏状态通过以下方式计算:

        其中,A代表图的邻接矩阵,W代表权重。但在实际场景中,不同邻居对目标节点的重要性可能不同。因此,在注意力机制中,通过一个注意力分数 α 表示顶点 v 对顶点 vi 的重要性。那么加入注意力机制后的隐藏状态计算方式如下:

        注意力分数 α 的计算方式如下所示:

                    

        式中,|| 表示拼接,<  ,> 表示左右两个元素点乘。e 也就是上一步的时空嵌入(STE)。 其代码如下所示:

X = tf.concat((X, STE), axis = -1)

query = FC(X, units = D, activations = tf.nn.relu, bn = bn, bn_decay = bn_decay, is_training = is_training)
key = FC(X, units = D, activations = tf.nn.relu, bn = bn, bn_decay = bn_decay, is_training = is_training)

# 多头分解(K代表有几个头)
query = tf.concat(tf.split(query, K, axis = -1), axis = 0)
key = tf.concat(tf.split(key, K, axis = -1), axis = 0)

# 计算注意力分数
attention = tf.matmul(query, key, transpose_b = True)
attention /= (d ** 0.5)
attention = tf.nn.softmax(attention, axis = -1)
2.2 时间注意力(Temporal Attention) 

        时间注意力和空间注意力的实现方法类似。只有一点不同,空间注意力在空间维度(N)中捕捉节点间的依赖。而时间注意力在时间维度(num_step)中捕捉时间步间的依赖。

        作者在这一块的代码部分加入了可选的掩码(mask)。它与因果卷积的作用相同,都是为了解决时间序列建模中的因果性问题,防止未来信息泄露。不同点如下所示:

2.3 门控融合(Gated Fusion)

        在某些情况下,交通状况可能主要受空间因素影响(如附近道路拥堵)。在另一些情况下,时间因素可能更为关键(如高峰期的规律性变化)。为了平衡空间和时间注意力的贡献,使得模型可以在不同时空条件下动态调整两者的重要性,作者在这里使用了门控机制。Hs 代表空间注意力机制,Ht 代表时间注意力机制。

        z 是一个门控权重,z 越接近 1,模型越依赖空间注意力输出。反之越接近 0 ,模型越依赖时间注意力输出。其中 z 通过以下公式计算:

        其算法代码如下, 

XS = FC(
    HS, units = D, activations = None,
    bn = bn, bn_decay = bn_decay,
    is_training = is_training, use_bias = False)
XT = FC(
    HT, units = D, activations = None,
    bn = bn, bn_decay = bn_decay,
    is_training = is_training, use_bias = True)
z = tf.nn.sigmoid(tf.add(XS, XT))
H = tf.add(tf.multiply(z, HS), tf.multiply(1 - z, HT))

3. 编码器-解码器结构

        编码器会接收历史交通数据(比如过去 1 小时的交通流量),将这些时间序列的信息“浓缩”为一个隐藏表示,这个表示概括了所有历史时间步的信息。解码器接收编码器生成的隐藏表示,结合目标预测的要求(比如未来 1 小时的交通流量),逐步生成未来时间步的预测值。

4. 转换注意力(Transform Attention) 

        在长时间交通预测中,我们不仅需要知道历史的交通状况,还要明白历史的哪些时刻对未来的影响更重要。而转换注意力的作用就是把历史信息直接“映射”到未来,建立一种历史时间步和未来时间步之间的直接联系。比如,要想预测明天早上的交通状况,就要先知道今天早上的交通状况和昨天晚上的交通状况。如果昨天晚上有交通事故,那么一定会影响今天早上的交通状况。最后就可以建立映射关系:

        昨天晚上——>今天早上,那么今天晚上——>明天早上。

        转换注意力会计算每个历史时间步和每个未来时间步之间的“相关性分数”。这个分数告诉我们某个历史时刻对未来有多重要。然后根据计算出的相关性分数,为未来时间步选择最重要的历史时间步,提取它们的特征信息。最后把这些选出来的历史特征送到解码器,直接生成未来的预测值。

        在代码上,transformAttention 和 temporalAttention、spatialAttention 的写法类似,只不过传入的参数有所不同。具体来说,也就是 Query、Key、Value 不同。transformAttention 使用历史时间步(STE_P)和预测时间步(STE_Q) 构建 Query-Key 机制,实现时间序列的转换。

 


http://www.kler.cn/a/408406.html

相关文章:

  • 无线电磁波在自由空间的衰减
  • el-table-column自动生成序号在序号前插入图标
  • AMD(Xilinx) FPGA配置Flash大小选择
  • 已阻止加载“http://localhost:8086/xxx.js”的模块,它使用了不允许的 MIME 类型 (“text/plain”)。
  • opencv undefined reference to `cv::noarray()‘ 。window系统配置opencv,找到opencv库,但连接不了
  • Java集合分页
  • 【面向对象】Java处理异常的方式
  • STM32抢占优先级不生效
  • 对于相对速度的重新理解 - 插一句
  • MySQL原理简介—10.SQL语句和执行计划
  • 编程中的字节序问题
  • 海信Java后端开发面试题及参考答案
  • 基于python的长津湖评论数据分析与可视化,使用是svm情感分析建模
  • docker 配置代理
  • 如何在 .gitignore 中仅保留特定文件:以忽略文件夹中的所有文件为例
  • hyperf 配置步骤
  • 深入理解CRC:通信可靠性的关键
  • CSS中flex:1是什么属性
  • Milvus实操
  • Adobe Illustrator 2024 安装教程与下载分享
  • docker拉取镜像问题解决
  • 【Linux】gcc/g++使用
  • Python + 深度学习从 0 到 1(00 / 99)
  • 小公司该如何做好项目管理工作
  • 空安全-模块-并发
  • Go-protobuf consul注册备忘录