当前位置: 首页 > article >正文

GCN,GraphSAGE 到底在训练什么呢?

根据DGL 来做的,按照DGL 实现来讲述

1. GCN Cora 训练代码:

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print(
                f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})"
            )




if __name__ == "__main__" :
    dataset = dgl.data.CoraGraphDataset()
    # print(f"Number of categories: {dataset.num_classes}")
    g = dataset[0]
    g = g.to('cuda')
    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to('cuda')
    train(g, model)

一些基础python torch.tensor语法概述:

1.  

if __name__ == "__main__" :
    XXXXXXX
    XXXXXXX

当我们直接执行这个脚本时,__name__属性被设置为__main__,因此满足if条件,语句块中的代码被调用。
但如果我们将该脚本作为模块导入到另一个脚本中,则__name__属性会被设置为模块的名称(例如"example"),语句块中的代码不会被执行。

2. 

# Compute prediction
pred = logits.argmax(1)    # 返回沿着第一个维度(即维度索引为1)的最大值的索引。
                           # 即,加入有5个样本,每个样本有3个维度的评分,那么就会给出没个样本3中维度评分最高的哪个维度的索引序号

 

3. numpy 关于 tensor 的一个用法:

在DGL 中使用一串 True 或 False 组成的 一维tensor 来标识 这个节点到底是属于 train test val 哪一类

train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

而后,由于对于torch中的tensor来说:

就可以:select_label_tensor = labels[train_mask] 了

import torch

# 定义一个Tensor
tensor = torch.tensor([1, 2, 3, 4, 5])

# 定义一个布尔数组,选择索引为1和4的元素
mask = torch.tensor([False, True, False, False, True])

# 通过布尔索引选择元素
selected_tensor = tensor[mask]

print(selected_tensor)  # tensor([2, 5])

顺便,查看一个变量到底是什么类型可以使用 type() 函数:

train_mask = g.ndata["train_mask"]
print(type(train_mask))

# 输出为:
# <class 'torch.Tensor'>


http://www.kler.cn/news/157278.html

相关文章:

  • 卷积神经网络(CNN):乳腺癌识别.ipynb
  • ChatGPT使用路径:从新手到专家的指南
  • RedisTemplate序列化的问题
  • ElasticSearch学习笔记(一)
  • Redis Hash数据类型
  • C#的参数数组
  • 使用ES6 async awai t进行异步处理
  • python - abstractmethod作用 - `staticmethod`和`abc.abstractmethod`:它会混合吗?
  • Git和Git小乌龟安装
  • make -c VS make -f
  • 电脑发生0x80070002错误,0x80070002错误代码怎么解决
  • G口大带宽是什么意思?
  • Appium:进行iOS自动化测试遇到的问题与解决方案
  • Learning Normal Dynamics in Videos with Meta Prototype Network 论文阅读
  • 网络安全小白自学
  • 【qml入门教程系列】:qml property使用介绍
  • 【static】关键字静态成员:在类级别上共享数据和方法的机制
  • BFS求树的宽度——结合数组建树思想算距离
  • GPT市场将取代插件商店 openAI已经关闭plugins申请,全部集成到GPTs(Actions)来连接现实世界,可以与物理世界互动了。
  • 不再只是android,华为自爆Harmony将对标iOS
  • C# AES-128-CBC 加密
  • 【电源专题】什么是电源管理
  • OpenCV快速入门:移动物体检测和目标跟踪
  • python 运用pandas 库处理excel 表格数据
  • C++11的互斥量
  • C语言枚举
  • react-native实践日记--3.ui-kitten中的button设置字体颜色无效
  • AI医疗交流平台【Docola】申请823万美元纳斯达克IPO上市
  • json序列化时Long类型转换为String类型
  • Day50力扣打卡