当前位置: 首页 > article >正文

CoGNN(models文件中的CoGNN.py)

CoGNN 类实现了一个图神经网络(GNN),该模型利用 Gumbel-Softmax 技术动态调整边权重(边选择),并根据输入的节点特征和边信息进行图嵌入生成。模型的主要功能包括环境编码、节点和边的编码、边权重的创建和应用、跳跃连接(skip connection)、以及池化操作,用于生成整个图的嵌入。模型中的 Gumbel-Softmax 温度参数可以学习或固定。

import torch
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
from torch.nn import Module, Dropout, LayerNorm, Identity
import torch.nn.functional as F
from typing import Tuple
import numpy as np

from helpers.classes import GumbelArgs, EnvArgs, ActionNetArgs, Pool, DataSetEncoders
from models.temp import TempSoftPlus
from models.action import ActionNet


class CoGNN(Module):
    def __init__(self, gumbel_args: GumbelArgs, env_args: EnvArgs, action_args: ActionNetArgs, pool: Pool):
        super(CoGNN, self).__init__()
        self.env_args = env_args
        self.learn_temp = gumbel_args.learn_temp
        if gumbel_args.learn_temp:
            self.temp_model = TempSoftPlus(gumbel_args=gumbel_args, env_dim=env_args.env_dim)
        self.temp = gumbel_args.temp

        self.num_layers = env_args.num_layers
        self.env_net = env_args.load_net()
        self.use_encoders = env_args.dataset_encoders.use_encoders()

        layer_norm_cls = LayerNorm if env_args.layer_norm else Identity
        self.hidden_layer_norm = layer_norm_cls(env_args.env_dim)
        self.skip = env_args.skip
        self.dropout = Dropout(p=env_args.dropout)
        self.drop_ratio = env_args.dropout
        self.act = env_args.act_type.get()
        self.in_act_net = ActionNet(action_args=action_args)
        self.out_act_net = ActionNet(action_args=action_args)

        # Encoder types
        self.dataset_encoder = env_args.dataset_encoders
        self.env_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=env_args.env_dim, model_type=env_args.model_type)
        self.act_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=action_args.hidden_dim, model_type=action_args.model_type)

        # Pooling function to generate whole-graph embeddings
        self.pooling = pool.get()

    def forward(self, x: Tensor, edge_index: Adj, pestat, edge_attr: OptTensor = None, batch: OptTensor = None,
                edge_ratio_node_mask: OptTensor = None) -> Tuple[Tensor, Tensor]:
        result = 0

        calc_stats = edge_ratio_node_mask is not None
        if calc_stats:
            edge_ratio_edge_mask = edge_ratio_node_mask[edge_index[0]] & edge_ratio_node_mask[edge_index[1]]
            edge_ratio_list = []

        # bond encode
        if edge_attr is None or self.env_bond_encoder is None:
            env_edge_embedding = None
        else:
            env_edge_embedding = self.env_bond_encoder(edge_attr)
        if edge_attr is None or self.act_bond_encoder is None:
            act_edge_embedding = None
        else:
            act_edge_embedding = self.act_bond_encoder(edge_attr)

        # node encode  
        x = self.env_net[0](x, pestat)  # (N, F) encoder
        if not self.use_encoders:
            x = self.dropout(x)
            x = self.act(x)

        for gnn_idx in range(self.num_layers):
            x = self.hidden_layer_norm(x)

            # action
            in_logits = self.in_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding,
                                        act_edge_attr=act_edge_embedding)  # (N, 2)
            
            out_logits = self.out_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding,
                                          act_edge_attr=act_edge_embedding)  # (N, 2)

            t

http://www.kler.cn/news/367579.html

相关文章:

  • 银河麒麟相关
  • 【WPF】中Dispatcher的DispatcherPriority参数使用
  • 面包种类图像分割系统:多层面改进
  • 改进YOLOv8系列:引入低照度图像增强网络Retinexformer | 优化低光照目标检测那题
  • AI智能爆发:从自动驾驶到智能家居,科技如何改变我们的日常?
  • <Project-11 Calculator> 计算器 0.3 年龄计算器 age Calculator HTML JS
  • 【AI大模型】ChatGPT模型原理介绍
  • Mybatis-plus-入门
  • 2024年10月第3个交易周收盘总结
  • 工具_OpenSSL
  • 【微软商店平台】如何将exe打包上传微软商店
  • SpringCloud学习(补漏)
  • 哈希表之哈希数组、HashSet
  • 随机变量、取值、样本和统计量之间的关系
  • 智能科学与技术(一级学科)介绍
  • 从0开始深度学习(16)——暂退法(Dropout)
  • C++笔记---位图
  • PHP如何抛出和接收错误
  • C语言[求x的y次方]
  • 7.hyperf安装【Docker】
  • 京东电商下单黄金链路:防止订单重复提交与支付的深度解析
  • Pseudo Multi-Camera Editing 数据集:通过常规视频生成的伪标记多摄像机推荐数据集,显著提升模型在未知领域的准确性。
  • 背包九讲——混合背包问题
  • 虾类图像分割系统:改进亮点优化
  • 前端项目接入sqlite轻量级数据库sql.js指南
  • ffmpeg视频滤镜: 色温- colortemperature