当前位置: 首页 > article >正文

pytorch图神经网络处理图结构数据

图神经网络(Graph Neural Networks,GNNs)是一类能够处理图结构数据的深度学习模型。图结构数据由节点(vertices)和边(edges)组成,其中节点表示实体,边表示实体之间的关系或连接。GNNs 通过在图的结构上进行信息传递和节点嵌入(node embedding)来学习节点或图的特征表示。

GNN的关键思想是通过消息传递机制(message passing)更新每个节点的表示,通常是基于其邻居节点的特征信息。GNNs 可以广泛应用于许多领域,如社交网络分析、推荐系统、知识图谱、分子图表示等。

以下是GNN的基本组成部分和工作原理:

  1. 节点表示更新:每个节点的表示通过其邻居节点的表示进行更新。常见的做法是通过聚合邻居节点的特征,然后与节点本身的特征进行结合

GNN的变种

  1. GCN(Graph Convolutional Networks):一种基于图卷积的GNN,通过聚合邻居节点的特征来更新节点表示,适用于无向图。

  2. GraphSAGE(Graph Sample and Aggregation):通过随机采样邻居节点来提高计算效率,尤其适用于大规模图。

  3. GAT(Graph Attention Networks):引入了注意力机制,使得不同邻居对节点更新的贡献不同,能够动态调整每个邻居的权重。

  4. Graph Isomorphism Network (GIN):通过强大的表征能力增强了图的判别性。

GNN的应用

  • 社交网络分析:预测用户之间的关系或用户的兴趣。
  • 推荐系统:基于用户和物品之间的图结构进行个性化推荐。
  • 生物信息学:如分子图表示,用于药物发现、蛋白质结构预测等。
  • 图像分割与语义分析:在视觉任务中处理图形数据,捕捉图像之间的关系。

例子:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt

# 1. 生成随机图数据
num_nodes = 100
x = torch.rand((num_nodes, 2))  # 100 个节点,每个节点有 2 维特征
y = (x[:, 0] + x[:, 1] > 1).long()  # 二分类标签(0 或 1)

# 2. 生成图结构(邻接关系)
edge_index = []
for i in range(num_nodes):
    for j in range(i + 1, num_nodes):
        if (y[i] == y[j] and torch.rand(1).item() > 0.6) or (y[i] != y[j] and torch.rand(1).item() > 0.9):
            edge_index.append([i, j])
            edge_index.append([j, i])
edge_index = torch.tensor(edge_index, dtype=torch.long).t()

# 3. 训练集和测试集
train_mask = torch.rand(num_nodes) < 0.8  # 80% 训练,20% 测试
test_mask = ~train_mask

# 4. 构造 PyG 数据对象
data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)


# 5. 定义 4 层 GCN 模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(2, 16)
        self.conv2 = GCNConv(16, 16)
        self.conv3 = GCNConv(16, 16)  # 将 conv3 输出改为与输入维度相同
        self.conv4 = GCNConv(16, 2)  # 输出类别数 2

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index)) + x  # 跳跃连接,维度一致
        x = self.conv4(x, edge_index)
        return F.log_softmax(x, dim=1)  # 输出对数概率


# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)  # 学习率衰减

data = data.to(device)
num_epochs = 2000  # 增加训练轮数

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()  # 逐步降低学习率

    if epoch % 200 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 7. 评估模型
model.eval()
out = model(data)
pred = out.argmax(dim=1)  # 取最大值的索引作为类别
test_pred = pred[data.test_mask]
test_true = data.y[data.test_mask]

# 8. 过滤低置信度预测
proba = torch.exp(out)  # 转换为 softmax
test_pred[proba[data.test_mask].max(dim=1)[0] < 0.6] = -1  # 低置信度设为 -1

# 9. 可视化测试结果
test_mask_np = torch.arange(num_nodes)[data.test_mask].cpu().numpy()
test_pred_np = test_pred.cpu().numpy()
test_true_np = test_true.cpu().numpy()

plt.figure(figsize=(10, 5))
plt.scatter(test_mask_np, test_pred_np, color='blue', alpha=0.5, label='Predicted')
plt.scatter(test_mask_np, test_true_np, color='red', alpha=0.5, label='True')
plt.xlabel('Test Node Index')
plt.ylabel('Node Class')
plt.title('Test Results vs True Results')
plt.legend()
plt.show()

 


http://www.kler.cn/a/527124.html

相关文章:

  • CSS核心
  • 学习数据结构(5)单向链表的实现
  • 【自学嵌入式(7)天气时钟:WiFi模块、OLED模块、NTP模块开发】
  • 一文讲解Java中的异常处理机制
  • Day29(补)-【AI思考】-精准突围策略——从“时间贫困“到“效率自由“的逆袭方案
  • 基于Python的人工智能患者风险评估预测模型构建与应用研究(下)
  • Git进阶之旅:分支管理策略
  • 【华为OD-E卷 - 字符串化繁为简 100分(python、java、c++、js、c)】
  • 计算机网络一点事(23)
  • minimind - 从零开始训练小型语言模型
  • 树莓派入门笔记(二)最常用的树莓派 Linux 命令及说明_树莓派系统命令
  • PostgreSQL TRUNCATE TABLE 操作详解
  • AVL搜索树
  • 商品列表及商品详情展示
  • 通过想像,见证奇迹
  • 【gRPC-gateway】初探grpc网关,插件安装,默认实现,go案例
  • Mysql进阶学习
  • 最新 Android 热门开源项目公布
  • 稀疏混合专家架构语言模型(MoE)
  • 【4Day创客实践入门教程】Day4 迈向高手之路——进一步学习!
  • .cc扩展名是什么语言?C语言必须用.c为扩展名吗?主流编程语言扩展名?Java为什么不能用全数字的文件名?
  • 七、深入了解SpringBoot的配置文件
  • 代随(138):单调栈:一维接雨水
  • 如何将IP切换到海外:详细指南
  • WebSocket使用及优化(心跳机制与断线重连)_websocket timeout
  • IT运维的365天--025 H3C交换机用NTP同步正确的时间