当前位置: 首页 > article >正文

20250226-代码笔记04-class CVRP_Encoder AND class EncoderLayer

文章目录

  • 前言
  • 一、class CVRP_Encoder(nn.Module):__init__(self, **model_params)
    • 函数功能
    • 函数代码
  • 二、class CVRP_Encoder(nn.Module):forward(self, depot_xy, node_xy_demand)
    • 函数功能
    • 函数代码
  • 三、class EncoderLayer(nn.Module):__init__(self, **model_params)
    • 函数功能
    • 函数代码
  • 四、class EncoderLayer(nn.Module):forward(self, input1)
    • 函数功能
    • 函数代码
  • 附录
    • 代码(全)


前言

  • class CVRP_Encoder
  • class EncoderLayer(nn.Module)

以上是CVRPModel.py的类

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRP_Encoder(nn.Module):init(self, **model_params)

函数功能

该代码片段是 CVRP_Encoder 类的 init 构造函数。其主要功能是 ==初始化模型的参数 ==和 构建神经网络的层,用于解决 容量限制的车辆路径问题(CVRP)。

执行流程图链接
在这里插入图片描述

函数代码

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        encoder_layer_num = self.model_params['encoder_layer_num']

        self.embedding_depot = nn.Linear(2, embedding_dim)
        self.embedding_node = nn.Linear(3, embedding_dim)
        self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])


二、class CVRP_Encoder(nn.Module):forward(self, depot_xy, node_xy_demand)

函数功能

forward 方法是 CVRP_Encoder 类的一部分,主要功能是将仓库(depot)和客户节点(node)的位置和需求信息进行嵌入(embedding)并通过编码器进行处理。
它返回处理后的节点和仓库的嵌入表示,用于后续的模型推理。

执行流程图链接
在这里插入图片描述

函数代码

    def forward(self, depot_xy, node_xy_demand):
        # depot_xy.shape: (batch, 1, 2)
        # node_xy_demand.shape: (batch, problem, 3)

        embedded_depot = self.embedding_depot(depot_xy)
        # shape: (batch, 1, embedding)
        embedded_node = self.embedding_node(node_xy_demand)
        # shape: (batch, problem, embedding)

        out = torch.cat((embedded_depot, embedded_node), dim=1)
        # shape: (batch, problem+1, embedding)

        for layer in self.layers:
            out = layer(out)

        return out
        # shape: (batch, problem+1, embedding)


三、class EncoderLayer(nn.Module):init(self, **model_params)

函数功能

init 方法是 EncoderLayer 类的构造函数,主要功能是初始化编码器层的各个子组件。
设置多头注意力机制所需的权重矩阵和相关的正则化层、前馈网络等,以便在 forward 方法中执行实际的操作。

执行流程图链接
在这里插入图片描述

函数代码

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)


四、class EncoderLayer(nn.Module):forward(self, input1)

函数功能

forward 方法是 EncoderLayer 类中的前向传播函数,主要功能是执行 多头自注意力机制前馈神经网络,并进行相应的 残差连接归一化
执行流程图链接
在这里插入图片描述

函数代码

    def forward(self, input1):
        # input1.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        q = reshape_by_heads(self.Wq(input1), head_num=head_num)
        k = reshape_by_heads(self.Wk(input1), head_num=head_num)
        v = reshape_by_heads(self.Wv(input1), head_num=head_num)
        # qkv shape: (batch, head_num, problem, qkv_dim)

        out_concat = multi_head_attention(q, k, v)
        # shape: (batch, problem, head_num*qkv_dim)

        multi_head_out = self.multi_head_combine(out_concat)
        # shape: (batch, problem, embedding)

        out1 = self.add_n_normalization_1(input1, multi_head_out)
        out2 = self.feed_forward(out1)
        out3 = self.add_n_normalization_2(out1, out2)

        return out3
        # shape: (batch, problem, embedding)


附录

代码(全)

########################################
# ENCODER
########################################

class CVRP_Encoder(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        encoder_layer_num = self.model_params['encoder_layer_num']

        self.embedding_depot = nn.Linear(2, embedding_dim)
        self.embedding_node = nn.Linear(3, embedding_dim)
        self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])

    def forward(self, depot_xy, node_xy_demand):
        # depot_xy.shape: (batch, 1, 2)
        # node_xy_demand.shape: (batch, problem, 3)

        embedded_depot = self.embedding_depot(depot_xy)
        # shape: (batch, 1, embedding)
        embedded_node = self.embedding_node(node_xy_demand)
        # shape: (batch, problem, embedding)

        out = torch.cat((embedded_depot, embedded_node), dim=1)
        # shape: (batch, problem+1, embedding)

        for layer in self.layers:
            out = layer(out)

        return out
        # shape: (batch, problem+1, embedding)


class EncoderLayer(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, input1):
        # input1.shape: (batch, problem+1, embedding) 364t
    
        head_num = self.model_params['head_num']

        q = reshape_by_heads(self.Wq(input1), head_num=head_num)
        k = reshape_by_heads(self.Wk(input1), head_num=head_num)
        v = reshape_by_heads(self.Wv(input1), head_num=head_num)
        # qkv shape: (batch, head_num, problem, qkv_dim)

        out_concat = multi_head_attention(q, k, v)
        # shape: (batch, problem, head_num*qkv_dim)

        multi_head_out = self.multi_head_combine(out_concat)
        # shape: (batch, problem, embedding)

        out1 = self.add_n_normalization_1(input1, multi_head_out)
        out2 = self.feed_forward(out1)
        out3 = self.add_n_normalization_2(out1, out2)

        return out3
        # shape: (batch, problem, embedding)


http://www.kler.cn/a/564841.html

相关文章:

  • flutter项目构建常见问题
  • iptables核心和简例[NET]
  • 客户端进程突然结束,服务端read是什么行为?
  • STM32基于HAL库(CUBEMX)MPU6050 DMP的移植(新手一看必会)
  • 蓝桥杯备考1
  • 51c自动驾驶~合集52
  • Oracle 查询表空间使用情况及收缩数据文件
  • 基于Spring Boot的乡村养老服务管理系统设计与实现(LW+源码+讲解)
  • 谷云科技iPaaS×DeepSeek:构建企业智能集成的核心底座
  • 如何使用豆包AI来快速提升编程能力?
  • netcore入门案例:netcore api连接mysql的完整记事本接口示例
  • 纯c#字体处理库(FontParser) -- 轻量、极速、跨平台、具有字体子集化功能
  • 火语言RPA--Word打开文档
  • 一劳永逸解决vsocde模块import引用问题
  • 视频级虚拟试衣技术在淘宝的产品化实践
  • 【编程语言】Bash使用教程
  • 如何更改vim命令创建代码文件时的默认模板
  • 【Eureka 缓存机制】
  • 【js逆向入门】图灵爬虫练习平台 第八题
  • Flutter - StatefulWidget (有状态的 Widget) 和 生命周期