DGCN论文解读
一、时空特征提取结构
作者团队所提出的 DGCN 的时空特征提取结构主要包括两个组成部分:拉普拉斯矩阵潜在网络(LMLN)和基于GCN的交通预测组件。
1 拉普拉斯矩阵潜在网络(LMLN)
这是用于估计观察到的交通数据的拉普拉斯矩阵。这个矩阵帮助模型理解“哪条路会影响另一条路的交通状况”,从而更好地预测未来的交通流量或速度。 如下图所示,拉普拉斯矩阵预测单元包含三个模块:特征采样、空间注意力和LSTM单元。
1.1 全局拉普拉斯矩阵学习层(Global Laplace matrix Learing Layer)
首先,道路图(Road map)需要被简化为道路网(Road network),这样才能得到一个关于道路交通网络的邻接矩阵 W,然 后使用对称归一化可以得到拉普拉斯矩阵 L。在论文第二部分相关工作中提到,拉普拉斯矩阵 L 的分解计算复杂,因此提出了快速GCN(Fast-GCN),这种方法的主要思想是在每次训练迭代时随机采样(sampling)一部分节点和邻居,而不是处理整个图。即可以对拉普拉斯矩阵进行进一步缩放。其中 λ 是拉普拉斯矩阵的最大特征值。
# 计算归一化拉普拉斯矩阵
def scaled_Laplacian(W):
assert W.shape[0] == W.shape[1]
D = np.diag(np.sum(W, axis=1))
L = D - W
lambda_max = eigs(L, k=1, which='LR')[0].real
return (2 * L) / lambda_max - np.identity(W.shape[0])
然后再通过全局拉普拉斯矩阵学习层处理。在这个过程中,用一个可训练的参数化矩阵 Lpar 调整邻接关系,即和上一步的缩放拉普拉斯矩阵 L 相加,就可以得到一个残差拉普拉斯矩阵 Lres1。公式中 Dres 表示节点 i 的度矩阵,也就是路段 i 和其他路段的连接数量。对其取逆并和原来的残差拉普拉斯矩阵相乘,那么度数较大的节点的影响力就会变得较小,而度数较小的节点的影响力则会被放大。确保图中每个节点对整个图结构的影响力是公平的,避免度数较大的节点在图卷积计算中占主导地位。这能帮助模型更准确地捕捉到图中节点之间的真实关系。
A = self.h + supports # self.h即公式中的Lpar
d = 1 / (torch.sum(A, -1) + 0.0001)
D = torch.diag_embed(d)
A = torch.matmul(D, A)
A1 = F.dropout(A, 0.5, self.training)
1.2 特征采样(Feature Sampling)
作者提出了一种特征采样方案,根据不同时间间隔的重要性来减少交通特征的数据维度。举个例子,假设在预测某条道路未来一小时的交通流量,最近的1小时内的交通情况对预测未来的交通流量影响最大,所以会给最近的数据更多的权重。然后,把过去4小时内每10分钟的数据合并成一个整体特征,以简化计算过程,同时保持对预测有用的时空信息。
在实际的代码中,作者用两个不同的空间注意力 SATT_3 和 SATT_2 分别计算不同时间步的特征。
# SATT_3中,通道数(c_in)扩大到 12 倍,时间维度被缩小为原来的 1/12
shape = seq.shape
seq = seq.permute(0, 1, 3, 2).contiguous().view(shape[0], shape[1] * 12, shape[3] // 12, shape[2])
seq = seq.permute(0, 1, 3, 2)
x_1 = self.time_conv(x)
x_1 = F.leaky_relu(x_1)
x_tem1 = x_1[:, :, :, 0:48] # 前48个时间步
x_tem2 = x_1[:, :, :, 48:60] # 后12个时间步
# 合并两个空间注意力
S_coef1 = self.SATT_3(x_tem1)
S_coef2 = self.SATT_2(x_tem2)
S_coef = torch.cat((S_coef1, S_coef2), 1) # b,l,n,c
1.3 空间注意力(Spatial Attention)
空间注意力回答的问题是:在当前时刻,哪些节点(例如交通路段)对某个节点的状态变化有更大的影响。因此,作者通过空间注意力构建每个时刻道路网络的空间关系,计算得到一个动态邻接矩阵 Ld。计算方式如下,F(1:τ) 表示在时间范围 1 到 τ 内的节点特征。W 是可学习的权重矩阵。想象一下,矩阵 F 的某一行是这个路段对其他路段的影响,而某一列是其他路段对这个路段的影响。转置后,点积操作会计算每个节点对其他节点的相似性,从而生成一个节点之间的相互关系矩阵 Ld。最后再通过 sigmoid 函数归一化。
# 输入通道缩小为原来的 1/4
f1 = self.conv1(seq).view(shape[0], self.c_in // 4, 4, shape[2], shape[3]).permute(0, 3, 1, 4, 2).contiguous()
f2 = self.conv2(seq).view(shape[0], self.c_in // 4, 4, shape[2], shape[3]).permute(0, 1, 3, 4, 2).contiguous()
logits = torch.einsum('bnclm,bcqlm->bnqlm', f1, f2)
logits = logits.permute(0, 3, 1, 2, 4).contiguous()
logits = torch.sigmoid(logits)
1.4 LSTM单元(LSTM Unit)
之前提到,空间注意力由 SATT_3 和 SATT_2 组成。这会产生两个不同的邻接矩阵 Ld,因此作者将这两个 Ld 合并为一个,并使用 LSTM 分析动态邻接矩阵序列 Ld 在不同时间点之间的变化规律。那么这里计算的 Ld 和上面计算出的 Lres 有什么关系?Lres 是优化后的全局拉普拉斯矩阵,用于提供长期的全局图结构约束。它是更稳定、更大范围的空间表示。而 Ld 是通过注意力机制动态生成的邻接矩阵,用于反映节点的短期动态关系。两者相乘,能够同时保留了 Ld 的动态变化性和 Lres 的全局稳定性。
# 通过 LSTM 学习空间注意力的序列依赖
S_coef = S_coef.permute(0, 2, 1, 3).contiguous().view(shape[0] * shape[2], shape[1], shape[3])
S_coef = F.dropout(S_coef, 0.5, self.training)
_, hidden = self.LSTM(S_coef, hidden)
adj_out = hidden[0].squeeze().view(shape[0], shape[2], shape[3]).contiguous()
adj_out1 = (adj_out) * supports
2 基于GCN的交通预测
2.1 图时间卷积层(Graph Temporal Convolutional Layer, GTCL)
传统的交通预测领域中,GCN 和时间卷积层(TCL)堆叠为一个时空块(Spatial-Temporal Block),但是这样做的话计算量会非常大。因此,作者提出将这两种功能集成为一个单一的图时间卷积层(GTCL)。式中,TC 代表通过一维卷积提取到的局部时间特征,Cm(L) 代表用 M 阶切比雪夫多项式 Cm 去近似拉普拉斯矩阵 L(也就是之前计算出的 Lp),Φm 是每一阶切比雪夫多项式的可训练图卷积核参数。
def forward(self, x, adj):
nSample, feat_in, nNode, length = x.shape
Ls = []
L1 = adj
L0 = torch.eye(nNode).repeat(nSample, 1, 1).cuda()
Ls.append(L0)
Ls.append(L1)
# 计算 Chebyshev 多项式 T_k
for k in range(2, self.K):
L2 = 2 * torch.matmul(adj, L1) - L0
L0, L1 = L1, L2
Ls.append(L2)
# 将所有的 Chebyshev 多项式堆叠起来
Lap = torch.stack(Ls, 1) # [B, K,nNode, nNode]
Lap = Lap.transpose(-1, -2)
# 图卷积:将输入 x 与 Chebyshev 多项式 Lap 进行张量乘法
x = torch.einsum('bcnl,bknq->bckql', x, Lap).contiguous()
x = x.view(nSample, -1, nNode, length)
out = self.conv1(x)
······
x_1 = self.dynamic_gcn(x_1, adj_out1) # dynamic_gcn为Chebyshev图卷积,也就是上面的前向传播函数
除此之外,作者还设计了一种门控机制来探索局部时间特征。把拆分出来的 β1 用 sigmoid 激活函数处理,生成一个介于 [0,1] 的权重值来计算每个特征的重要性。β2 通过非线性处理(Leaky ReLU),会作为 “主特征” 与 β1 相乘。那么模型就不用处理所有特征,而把重点放在分数高的特征上。
# 将图卷积输出拆分为滤波器和门控
filter, gate = torch.split(x_1, [self.c_out, self.c_out], 1)
x_1 = torch.sigmoid(gate) * F.leaky_relu(filter)
x_1 = F.dropout(x_1, 0.5, self.training)
2.2 时间注意力机制(Temporal Attention Mechanism)
时间注意力回答的问题是:一个节点在不同的时间点(过去或未来)之间,哪些时刻的状态对当前时刻的预测更重要。这一块儿公式的解读可以参考论文 ASTGCN。
2.3 批归一化(Batch Normalization)
假如训练一个网络,输入数据的范围很混乱,比如有时是 0~1,有时是 10~100。这会让网络难以找到合适的参数。归一化就像一个“整顿器”,让每一批数据的分布变得均匀有序,让模型更容易学习。
# 归一化
def normalization(train, val, test):
assert train.shape[1:] == val.shape[1:] and val.shape[1:] == test.shape[1:]
# xis=0 表示沿第一个轴计算均值,keepdims=True 表示保持维度一致
mean = train.mean(axis=0, keepdims=True) # 计算训练集的均值
std = train.std(axis=0, keepdims=True) # 计算训练集的标准差
# 定义归一化函数
def normalize(x):
return (x - mean) / std
train = (train).transpose(0, 2, 1, 3)
val = (val).transpose(0, 2, 1, 3)
test = (test).transpose(0, 2, 1, 3)
return {'mean': mean, 'std': std}, train, val, test
2.4 损失函数(Loss Function)
如下图所示,最近的一段数据(Th)是连续的,比如过去10分钟的数据,它们是紧密相关的。而每日周期(Td),每周周期(Tw)这些数据并不是连续的,比如每天早上8点的情况,或者每周一的早上8点。这些时间点之间是不连续的。
因此, 模型用一个 1× i 的卷积操作(像一个“滑动窗口”)把 Th 这样的近期数据总结成一个特征。而用一个 1×1 的卷积操作(类似“取样分析”)分别处理像 Tw 、 Td 的时间点,确保它们的特征不会被遗漏。最后模型把所有卷积的结果加在一起,生成最终的预测值。
class ST_BLOCK_2(nn.Module):
def __init__(self, c_in, c_out, num_nodes, tem_size, K, Kt):
super(ST_BLOCK_2, self).__init__()
# 1×1 的卷积操作
self.conv1 = Conv2d(c_in, c_out, kernel_size=(1, 1), stride=(1, 1), bias=True)
# 1×i 的卷积操作
self.time_conv = Conv2d(c_in, c_out, kernel_size=(1, Kt), padding=(0, 1), stride=(1, 1), bias=True)
···
二、实验
- DGCN_R:仅使用近期数据的模型版本。
- DGCN_Mask: 使用掩码拉普拉斯矩阵。
- DGCN_Res: 使用残差拉普拉斯矩阵。
- DGCN_GAT: 使用多头注意力机制(GAT)替代切比雪夫多项式。