ActivationType, Pool, ModelType(helpers文件中的classes.py)

该代码实现了一系列类与方法,主要用于图神经网络(GNN)模型的编码、激活函数、池化操作等。`EnvArgs`、`ActionNetArgs` 类用于根据配置参数生成网络结构,并通过 `ModelType`、`ActivationType`、`Pool` 等控制模型的组件与行为。代码的主要目的是构建可配置的图神经网络模型,并在这些模型中实现不同的特征编码器、激活函数、池化策略等。

from helpers.classes import ActivationType, Pool, ModelType

from enum import Enum, auto
from torch.nn import Linear, ModuleList, Module, Dropout, ReLU, GELU, Sequential
from torch import Tensor
from typing import NamedTuple, Any, Callable
import torch.nn.functional as F
from torch_geometric.nn.pool import global_mean_pool, global_add_pool

from helpers.metrics import MetricType
from helpers.model import ModelType
from helpers.encoders import DataSetEncoders, PosEncoder
from lrgb.encoders.composition import Concat2NodeEncoder

class ActivationType(Enum):
    """
    The activation functions that can be selected by name.
    """
    RELU = auto()
    GELU = auto()

    @staticmethod
    def from_string(s: str) -> "ActivationType":
        """
        Return the member whose name is exactly ``s`` (e.g. ``"RELU"``).

        :raises ValueError: if ``s`` is not the name of a member.
        """
        try:
            return ActivationType[s]
        except KeyError as err:
            raise ValueError(f'ActivationType {s} not supported') from err

    def get(self) -> Callable:
        """
        Return the functional form of the activation (``F.relu`` / ``F.gelu``).

        :raises ValueError: if the member has no functional form registered here.
        """
        if self is ActivationType.RELU:
            return F.relu
        elif self is ActivationType.GELU:
            return F.gelu
        else:
            # only reachable if a member is added without extending this method
            raise ValueError(f'ActivationType {self.name} not supported')

    def nn(self) -> Module:
        """
        Return a new module instance of the activation (``ReLU()`` / ``GELU()``).

        :raises ValueError: if the member has no module form registered here.
        """
        if self is ActivationType.RELU:
            return ReLU()
        elif self is ActivationType.GELU:
            return GELU()
        else:
            # only reachable if a member is added without extending this method
            raise ValueError(f'ActivationType {self.name} not supported')

class GumbelArgs(NamedTuple):
    """
    Immutable bundle of Gumbel-related settings.

    This class only declares the fields; how they are consumed is not visible
    here, so the per-field notes below are assumptions to be confirmed.
    """
    # presumably: whether the temperature is learned rather than fixed - TODO confirm
    learn_temp: bool
    # presumably: architecture of the network that predicts the temperature - TODO confirm
    temp_model_type: ModelType
    # NOTE(review): looks like a base/offset temperature constant - confirm against the consumer
    tau0: float
    # presumably: the temperature used when it is not learned - TODO confirm
    temp: float
    # callable stored alongside the settings; EnvArgs and ActionNetArgs carry a field of the same name
    gin_mlp_func: Callable

class Pool(Enum):
    """
    The graph-level pooling strategies that can be selected by name.
    """
    NONE = auto()
    MEAN = auto()
    SUM = auto()

    @staticmethod
    def from_string(s: str) -> "Pool":
        """
        Return the member whose name is exactly ``s`` (e.g. ``"MEAN"``).

        :raises ValueError: if ``s`` is not the name of a member.
        """
        try:
            return Pool[s]
        except KeyError as err:
            raise ValueError(f'Pool {s} not supported') from err

    def get(self) -> Callable:
        """
        Return the pooling callable for this member.

        MEAN and SUM return the ``global_mean_pool`` / ``global_add_pool``
        functions; NONE returns a new ``BatchIdentity`` module, whose forward
        takes ``(x, batch)`` and returns ``x`` unchanged.

        :raises ValueError: if the member has no pooling callable registered here.
        """
        if self is Pool.MEAN:
            return global_mean_pool
        elif self is Pool.SUM:
            return global_add_pool
        elif self is Pool.NONE:
            return BatchIdentity()
        else:
            # only reachable if a member is added without extending this method
            raise ValueError(f'Pool {self.name} not supported')

class EnvArgs(NamedTuple):
    """
    Configuration of the environment network; ``load_net`` turns it into modules.

    ``load_net`` reads ``model_type``, ``num_layers``, ``env_dim``, ``dropout``,
    ``act_type``, ``dec_num_layers``, ``pos_enc``, ``dataset_encoders``,
    ``in_dim``, ``out_dim`` and ``gin_mlp_func``. The remaining fields
    (``layer_norm``, ``skip``, ``batch_norm``, ``metric_type``) are only carried
    by this tuple and are not used inside this class.
    """
    model_type: ModelType
    num_layers: int
    env_dim: int

    layer_norm: bool
    skip: bool
    batch_norm: bool
    dropout: float
    act_type: ActivationType
    dec_num_layers: int
    pos_enc: PosEncoder
    dataset_encoders: DataSetEncoders

    metric_type: MetricType
    in_dim: int
    out_dim: int

    gin_mlp_func: Callable

    def load_net(self) -> ModuleList:
        """
        Build the network as ``[encoder] + model components + [decoder]``.

        :return: a ModuleList whose first entry is the node encoder, followed by
            the components produced by ``model_type.get_component_list`` and,
            last, the decoder (a single Linear, or a Sequential MLP when
            ``dec_num_layers > 1``).
        """
        # encoder: dataset encoder only, positional encoder only, or both combined
        if self.pos_enc is PosEncoder.NONE:
            enc_list = [self.dataset_encoders.node_encoder(in_dim=self.in_dim, emb_dim=self.env_dim)]
        elif self.dataset_encoders is DataSetEncoders.NONE:
            enc_list = [self.pos_enc.get(in_dim=self.in_dim, emb_dim=self.env_dim)]
        else:
            # NOTE(review): the tail of this call was missing from the received
            # source. ``enc2_cls`` and ``enc2_dim_pe`` are a reconstruction -
            # confirm both against the Concat2NodeEncoder signature.
            enc_list = [Concat2NodeEncoder(enc1_cls=self.dataset_encoders.node_encoder,
                                           enc2_cls=self.pos_enc.get,
                                           in_dim=self.in_dim, emb_dim=self.env_dim,
                                           enc2_dim_pe=self.pos_enc.DIM_PE())]

        # NOTE(review): the last argument of this call was missing from the
        # received source; forwarding ``gin_mlp_func`` is a reconstruction -
        # confirm against ModelType.get_component_list.
        component_list = \
            self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.env_dim, out_dim=self.env_dim,
                                               num_layers=self.num_layers, bias=True, edges_required=True,
                                               gin_mlp_func=self.gin_mlp_func)

        # decoder
        if self.dec_num_layers > 1:
            mlp_list = []
            for _ in range(self.dec_num_layers - 1):
                # Create fresh modules for every hidden layer. Multiplying a list
                # of modules repeats the same Linear instance, which silently ties
                # the weights of all hidden layers once dec_num_layers > 2.
                mlp_list += [Linear(self.env_dim, self.env_dim), Dropout(self.dropout), self.act_type.nn()]
            mlp_list.append(Linear(self.env_dim, self.out_dim))
            dec_list = [Sequential(*mlp_list)]
        else:
            dec_list = [Linear(self.env_dim, self.out_dim)]

        return ModuleList(enc_list + component_list + dec_list)

class ActionNetArgs(NamedTuple):
    """
    Configuration of the action network; ``load_net`` turns it into modules.

    ``load_net`` reads ``model_type``, ``num_layers``, ``hidden_dim``,
    ``env_dim`` and ``gin_mlp_func``; ``dropout`` and ``act_type`` are only
    carried by this tuple and are not used inside this class.
    """
    model_type: ModelType
    num_layers: int
    hidden_dim: int

    dropout: float
    act_type: ActivationType

    env_dim: int
    gin_mlp_func: Callable

    def load_net(self) -> ModuleList:
        """
        Build the action network components.

        The components take ``env_dim`` input features and produce 2 outputs
        (``out_dim=2``), with ``edges_required=False``.

        :return: a ModuleList wrapping what ``model_type.get_component_list`` returns.
        """
        # NOTE(review): the last argument of this call was missing from the
        # received source; forwarding ``gin_mlp_func`` is a reconstruction -
        # confirm against ModelType.get_component_list.
        net = self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.hidden_dim, out_dim=2,
                                                 num_layers=self.num_layers, bias=True, edges_required=False,
                                                 gin_mlp_func=self.gin_mlp_func)
        return ModuleList(net)

class BatchIdentity(Module):
    """
    Pooling placeholder: takes ``(x, batch)`` and returns ``x`` unchanged.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Accept and ignore any constructor arguments.
        """
        # Module.__init__ must run before the instance can be called as a module
        super().__init__()

    def forward(self, x: Tensor, batch: Tensor) -> Tensor:
        """
        Return ``x`` as-is; ``batch`` is accepted but not used.
        """
        return x




