当前位置: 首页 > article >正文

DGCN论文解读

一、时空特征提取结构

        作者团队所提出的 DGCN 的时空特征提取结构主要包括两个组成部分:拉普拉斯矩阵潜在网络(LMLN)基于GCN的交通预测组件

1 拉普拉斯矩阵潜在网络(LMLN)

        这是用于估计观察到的交通数据的拉普拉斯矩阵。这个矩阵帮助模型理解“哪条路会影响另一条路的交通状况”,从而更好地预测未来的交通流量或速度。 如下图所示,拉普拉斯矩阵预测单元包含三个模块:特征采样、空间注意力和LSTM单元。

1.1 全局拉普拉斯矩阵学习层(Global Laplace matrix Learing Layer)

        首先,道路图(Road map)需要被简化为道路网(Road network),这样才能得到一个关于道路交通网络的邻接矩阵 W,然 后使用对称归一化可以得到拉普拉斯矩阵 L。在论文第二部分相关工作中提到,拉普拉斯矩阵 L 的分解计算复杂,因此提出了快速GCN(Fast-GCN),这种方法的主要思想是在每次训练迭代时随机采样(sampling)一部分节点和邻居,而不是处理整个图。即可以对拉普拉斯矩阵进行进一步缩放。其中 λ 是拉普拉斯矩阵的最大特征值。

# 计算归一化拉普拉斯矩阵
def scaled_Laplacian(W):
    assert W.shape[0] == W.shape[1]
    D = np.diag(np.sum(W, axis=1))
    L = D - W
    lambda_max = eigs(L, k=1, which='LR')[0].real
    return (2 * L) / lambda_max - np.identity(W.shape[0])

        然后再通过全局拉普拉斯矩阵学习层处理。在这个过程中,用一个可训练的参数化矩阵 Lpar 调整邻接关系,即和上一步的缩放拉普拉斯矩阵 L 相加,就可以得到一个残差拉普拉斯矩阵 Lres1。公式中 Dres 表示节点 i 的度矩阵,也就是路段 i 和其他路段的连接数量。对其取逆并和原来的残差拉普拉斯矩阵相乘,那么度数较大的节点的影响力就会变得较小,而度数较小的节点的影响力则会被放大。确保图中每个节点对整个图结构的影响力是公平的,避免度数较大的节点在图卷积计算中占主导地位。这能帮助模型更准确地捕捉到图中节点之间的真实关系。

A = self.h + supports  # self.h即公式中的Lpar
d = 1 / (torch.sum(A, -1) + 0.0001)
D = torch.diag_embed(d)
A = torch.matmul(D, A)
A1 = F.dropout(A, 0.5, self.training)
1.2 特征采样(Feature Sampling) 

        作者提出了一种特征采样方案,根据不同时间间隔的重要性来减少交通特征的数据维度。举个例子,假设在预测某条道路未来一小时的交通流量,最近的1小时内的交通情况对预测未来的交通流量影响最大,所以会给最近的数据更多的权重。然后,把过去4小时内每10分钟的数据合并成一个整体特征,以简化计算过程,同时保持对预测有用的时空信息。

        在实际的代码中,作者用两个不同的空间注意力 SATT_3SATT_2 分别计算不同时间步的特征。

# SATT_3中,通道数(c_in)扩大到 12 倍,时间维度被缩小为原来的 1/12
shape = seq.shape
seq = seq.permute(0, 1, 3, 2).contiguous().view(shape[0], shape[1] * 12, shape[3] // 12, shape[2])
seq = seq.permute(0, 1, 3, 2)
x_1 = self.time_conv(x)
x_1 = F.leaky_relu(x_1)
x_tem1 = x_1[:, :, :, 0:48]  # 前48个时间步
x_tem2 = x_1[:, :, :, 48:60]  # 后12个时间步

# 合并两个空间注意力
S_coef1 = self.SATT_3(x_tem1)
S_coef2 = self.SATT_2(x_tem2)
S_coef = torch.cat((S_coef1, S_coef2), 1)  # b,l,n,c
1.3 空间注意力(Spatial Attention)

        空间注意力回答的问题是:在当前时刻,哪些节点(例如交通路段)对某个节点的状态变化有更大的影响。因此,作者通过空间注意力构建每个时刻道路网络的空间关系,计算得到一个动态邻接矩阵 Ld。计算方式如下,F(1:τ)  表示在时间范围 1 到 τ 内的节点特征。W 是可学习的权重矩阵。想象一下,矩阵 F 的某一行是这个路段对其他路段的影响,而某一列是其他路段对这个路段的影响。转置后,点积操作会计算每个节点对其他节点的相似性,从而生成一个节点之间的相互关系矩阵 Ld。最后再通过 sigmoid 函数归一化。

# 输入通道缩小为原来的 1/4
f1 = self.conv1(seq).view(shape[0], self.c_in // 4, 4, shape[2], shape[3]).permute(0, 3, 1, 4, 2).contiguous()
f2 = self.conv2(seq).view(shape[0], self.c_in // 4, 4, shape[2], shape[3]).permute(0, 1, 3, 4, 2).contiguous()

logits = torch.einsum('bnclm,bcqlm->bnqlm', f1, f2)
logits = logits.permute(0, 3, 1, 2, 4).contiguous()
logits = torch.sigmoid(logits)
  1.4 LSTM单元(LSTM Unit)

         之前提到,空间注意力由 SATT_3SATT_2 组成。这会产生两个不同的邻接矩阵 Ld,因此作者将这两个 Ld 合并为一个,并使用 LSTM 分析动态邻接矩阵序列 Ld​ 在不同时间点之间的变化规律。那么这里计算的 Ld 和上面计算出的 Lres 有什么关系?Lres 是优化后的全局拉普拉斯矩阵,用于提供长期的全局图结构约束。它是更稳定、更大范围的空间表示。而 Ld 是通过注意力机制动态生成的邻接矩阵,用于反映节点的短期动态关系。两者相乘,能够同时保留了 Ld 的动态变化性和 Lres 的全局稳定性。

# 通过 LSTM 学习空间注意力的序列依赖
S_coef = S_coef.permute(0, 2, 1, 3).contiguous().view(shape[0] * shape[2], shape[1], shape[3])
S_coef = F.dropout(S_coef, 0.5, self.training)
_, hidden = self.LSTM(S_coef, hidden)

adj_out = hidden[0].squeeze().view(shape[0], shape[2], shape[3]).contiguous()
adj_out1 = (adj_out) * supports
2 基于GCN的交通预测
2.1 图时间卷积层(Graph Temporal Convolutional Layer, GTCL)

        传统的交通预测领域中,GCN 和时间卷积层(TCL)堆叠为一个时空块(Spatial-Temporal Block),但是这样做的话计算量会非常大。因此,作者提出将这两种功能集成为一个单一的图时间卷积层(GTCL)。式中,TC 代表通过一维卷积提取到的局部时间特征,Cm(L) 代表用 M 阶切比雪夫多项式 Cm 去近似拉普拉斯矩阵 L(也就是之前计算出的 Lp),Φm 是每一阶切比雪夫多项式的可训练图卷积核参数。

def forward(self, x, adj):
    nSample, feat_in, nNode, length = x.shape

    Ls = []
    L1 = adj
    L0 = torch.eye(nNode).repeat(nSample, 1, 1).cuda()
    Ls.append(L0)
    Ls.append(L1)

    # 计算 Chebyshev 多项式 T_k
    for k in range(2, self.K):
        L2 = 2 * torch.matmul(adj, L1) - L0
        L0, L1 = L1, L2
        Ls.append(L2)

    # 将所有的 Chebyshev 多项式堆叠起来
    Lap = torch.stack(Ls, 1)  # [B, K,nNode, nNode]
    Lap = Lap.transpose(-1, -2)

    # 图卷积:将输入 x 与 Chebyshev 多项式 Lap 进行张量乘法
    x = torch.einsum('bcnl,bknq->bckql', x, Lap).contiguous()
    x = x.view(nSample, -1, nNode, length)
    out = self.conv1(x)

······

x_1 = self.dynamic_gcn(x_1, adj_out1)  # dynamic_gcn为Chebyshev图卷积,也就是上面的前向传播函数

        除此之外,作者还设计了一种门控机制来探索局部时间特征。把拆分出来的 β1 用 sigmoid 激活函数处理,生成一个介于 [0,1] 的权重值来计算每个特征的重要性。β2 通过非线性处理(Leaky ReLU),会作为 “主特征” 与 β1 相乘。那么模型就不用处理所有特征,而把重点放在分数高的特征上。

# 将图卷积输出拆分为滤波器和门控
filter, gate = torch.split(x_1, [self.c_out, self.c_out], 1)
x_1 = torch.sigmoid(gate) * F.leaky_relu(filter)
x_1 = F.dropout(x_1, 0.5, self.training)
 2.2 时间注意力机制(Temporal Attention Mechanism)

        时间注意力回答的问题是:一个节点在不同的时间点(过去或未来)之间,哪些时刻的状态对当前时刻的预测更重要。这一块儿公式的解读可以参考论文 ASTGCN。

2.3 批归一化(Batch Normalization)

        假如训练一个网络,输入数据的范围很混乱,比如有时是 0~1,有时是 10~100。这会让网络难以找到合适的参数。归一化就像一个“整顿器”,让每一批数据的分布变得均匀有序,让模型更容易学习。

# 归一化
def normalization(train, val, test):
    assert train.shape[1:] == val.shape[1:] and val.shape[1:] == test.shape[1:]

    # xis=0 表示沿第一个轴计算均值,keepdims=True 表示保持维度一致
    mean = train.mean(axis=0, keepdims=True)  # 计算训练集的均值
    std = train.std(axis=0, keepdims=True)  # 计算训练集的标准差

    # 定义归一化函数
    def normalize(x):
        return (x - mean) / std

    train = (train).transpose(0, 2, 1, 3)
    val = (val).transpose(0, 2, 1, 3)
    test = (test).transpose(0, 2, 1, 3)

    return {'mean': mean, 'std': std}, train, val, test
2.4 损失函数(Loss Function)

        如下图所示,最近的一段数据(Th)是连续的,比如过去10分钟的数据,它们是紧密相关的。而每日周期(Td),每周周期(Tw)这些数据并不是连续的,比如每天早上8点的情况,或者每周一的早上8点。这些时间点之间是不连续的。

        因此, 模型用一个 1× i 的卷积操作(像一个“滑动窗口”)把 Th 这样的近期数据总结成一个特征。而用一个 1×1 的卷积操作(类似“取样分析”)分别处理像 Tw 、 Td 的时间点,确保它们的特征不会被遗漏。最后模型把所有卷积的结果加在一起,生成最终的预测值。

class ST_BLOCK_2(nn.Module):
    def __init__(self, c_in, c_out, num_nodes, tem_size, K, Kt):
        super(ST_BLOCK_2, self).__init__()

        # 1×1 的卷积操作
        self.conv1 = Conv2d(c_in, c_out, kernel_size=(1, 1), stride=(1, 1), bias=True)
        # 1×i 的卷积操作
        self.time_conv = Conv2d(c_in, c_out, kernel_size=(1, Kt), padding=(0, 1), stride=(1, 1), bias=True)

        ···

二、实验

  1. DGCN_R:仅使用近期数据的模型版本。
  2. DGCN_Mask: 使用掩码拉普拉斯矩阵。
  3. DGCN_Res: 使用残差拉普拉斯矩阵。
  4. DGCN_GAT: 使用多头注意力机制(GAT)替代切比雪夫多项式。

        


http://www.kler.cn/a/444881.html

相关文章:

  • 每天40分玩转Django:Django部署概述
  • 底层解析v-modle和v-bind在绑定数据时的内存模型上的区别
  • 探索 Java 静态变量(static)的奥秘
  • python流行的web框架对比分析
  • C语言-基因序列转换独热码(one-hot code)
  • 加载Tokenizer和基础模型的解析及文件介绍:from_pretrained到底加载了什么?
  • Python读取Excel批量写入到PPT生成词卡
  • 配置免密登陆服务器
  • python快速接入阿里云百炼大模型
  • 【数据分析】数据分析流程优化:从数据采集到可视化的全面指南
  • 一篇文章理解前端的请求头和响应头含义
  • 打 印 菱 形
  • Gartner发布2025年网络安全主要趋势:实现转型和嵌入弹性两大主题下的9个趋势
  • Linux性能监控命令_nmon 安装与使用以及生成分析Excel图表
  • 基于注意力机制的ResNet优化算法(三种注意力机制+源码+pytorch)
  • 4、交换机IP接口功能
  • git 删除鉴权缓存及账号信息
  • 基于时间情境创造与 AI 智能名片 S2B2C 商城小程序源码的零售创新策略研究
  • 从零开始学习HTML5
  • 【Linux】文件IO--read/write/缓冲区(详)
  • 防火墙规则配置错误导致的网络问题排查
  • 用C#(.NET8)开发一个NTP(SNTP)服务
  • windwos defender实现白名单效果(除了指定应用或端口其它一律禁止)禁止服务器上网
  • pycharm debug
  • 网络安全概论——入侵检测系统IDS
  • 使用python的模块cryptography对文件加密