当前位置: 首页 > article >正文

20250301_代码笔记_函数class CVRPEnv: def step(self, selected)

文章目录

  • 前言
  • 1. 时间步数控制(time_step < 4)
  • 2. 时间步数3的特定操作
  • 3. 时间步数4的特定操作
  • 4. 更新当前节点和已选择节点列表
  • 5. 更新负载
  • 6. 更新访问标记
  • 7. 更新负无穷掩码
  • 8. 更新步骤状态
  • 9. 时间步大于等于 4 的复杂操作
    • 9.1. 动作模式分类(action0, action1, action2, action3)
    • 9.2. 动作模式的索引提取
    • 9.3. 更新选择计数
    • 9.4. 节点更新
    • 9.5. 更新已选择节点列表
    • 9.6. 更新负载(load)
    • 9. 7. 更新访问标记(visited_ninf_flag)
    • 9.8. 更新负无穷掩码(ninf_mask)
    • 9.9. 更新完成状态(finished)
    • 9.10. 更新模式(mode)
    • 9.11. 更新完成后的掩码调整
  • 10. 完成状态与奖励
  • 11. 返回值
  • 附录
    • 函数代码


前言

细读函数class CVRPEnv: def step(self, selected)

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py

注:selectedstep()函数的输入参数。


1. 时间步数控制(time_step < 4)

if self.time_step < 4:
    self.time_step = self.time_step + 1
    self.selected_count = self.selected_count + 1
    self.at_the_depot = (selected == 0)

功能:

  • 判断当前的time_step是否小于4。如果是,则执行该时间段的逻辑。
    • self.time_step增加1,表示进入下一个时间步。
    • self.selected_count表示已经选择的节点数量,递增。
    • self.at_the_depot记录当前的节点是否为配送中心(depot)。如果selected == 0(即选择了配送中心),则设置为True

2. 时间步数3的特定操作

if self.time_step == 3:
    self.last_current_node = self.current_node.clone()
    self.last_load = self.load.clone()

功能:

  • time_step == 3时,记录当前节点和负载的状态,用于后续的操作。
    • self.last_current_nodeself.last_load保存当前状态,方便后续在步骤4时进行比较或更新。

3. 时间步数4的特定操作

if self.time_step == 4:
    self.last_current_node = self.current_node.clone()
    self.last_load = self.load.clone()
    self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

功能:

  • time_step == 4时,保存当前节点和负载的状态。
    • 进一步处理visited_ninf_flag,标记已访问的节点,特别是那些不在配送中心并且之前没有访问过的节点。

4. 更新当前节点和已选择节点列表

self.current_node = selected
self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

功能:
selected(当前选择的节点)设置为self.current_node
更新self.selected_node_list,将当前选择的节点添加到已选择的节点列表中。


5. 更新负载

demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
self.load -= selected_demand
self.load[self.at_the_depot] = 1  # refill loaded at the depot

功能:

  • 获取当前选择节点的需求量,并更新self.load(负载):
    • demand_list:扩展depot_node_demand以便按批次、POMO索引进行操作。
    • gathering_index:用于从demand_list中获取对应于selected的需求量。
    • self.load -= selected_demand:根据选择的节点减少当前负载。
    • 如果当前位置在配送中心(at_the_depot),则将负载重置为1,表示负载已满。

6. 更新访问标记

self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

功能:

  • 更新visited_ninf_flag以防止重复访问已选择的节点:
    • 对于已选择的节点,标记为负无穷(表示已访问)。
    • 对于配送中心(at_the_depot),设置为未访问(0),除非当前就在配送中心。

7. 更新负无穷掩码

self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
_2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
demand_too_large = torch.cat((demand_too_large, _2), dim=2)
self.ninf_mask[demand_too_large] = float('-inf')

功能:

  • 更新ninf_mask,用于掩盖负载不足以承载需求量的节点:
    • demand_too_large表示负载不足以承载节点的需求,如果某个节点的需求量大于当前负载,则将其标记为True
    • 使用self.ninf_mask将这些节点掩盖为负无穷(-inf)。

8. 更新步骤状态

self.step_state.selected_count = self.time_step
self.step_state.load = self.load
self.step_state.current_node = self.current_node
self.step_state.ninf_mask = self.ninf_mask

功能:

  • 更新self.step_state,将当前步骤的状态(如selected_countload、current_nodeninf_mask)同步到步骤状态中。

9. 时间步大于等于 4 的复杂操作

9.1. 动作模式分类(action0, action1, action2, action3)

action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
action2_bool_index = self.mode == 1
action3_bool_index = self.mode == 2

功能:

  • action0_bool_index:表示选择了正常节点且当前模式为0的情况。selected != self.problem_size + 1是用来排除选择了特殊的节点(即problem_size + 1,通常是用于后悔机制或终止标记)。
  • action1_bool_index:表示当前模式为0并且选择了后悔节点(selected == self.problem_size + 1)。在这种情况下,智能体执行后悔操作(Regret)。
  • action2_bool_index:表示当前模式为1,意味着智能体正在执行某种特定的操作。
  • action3_bool_index:表示当前模式为2,意味着智能体正在执行另一种特定操作。

9.2. 动作模式的索引提取

action1_index = torch.nonzero(action1_bool_index)
action2_index = torch.nonzero(action2_bool_index)
action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

功能:

  • action1_index:获取所有满足action1_bool_index的索引,即执行后悔操作的智能体。
  • action2_index:获取所有满足action2_bool_index的索引,执行操作模式1的智能体。
  • action4_index:获取所有满足action3_bool_index并且当前节点不为0(即不在配送中心)的智能体。

9.3. 更新选择计数

self.selected_count = self.selected_count + 1
self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

功能:

  • self.selected_count增加1,表示当前有一个新的节点被选择。
  • 对于执行后悔操作的智能体(action1_bool_index),减少选择计数2,因为它们撤回了之前的选择。

9.4. 节点更新

self.last_is_depot = (self.last_current_node == 0)
_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
self.last_current_node = self.current_node.clone()
self.current_node = selected.clone()
self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

功能:

  • self.last_is_depot:检查之前的节点是否为配送中心(0)。
  • 对于执行后悔操作的智能体(action1_index),保存它们上一步的current_node,然后将self.current_node更新为selected(当前选择的节点),并恢复后悔节点的选择。
  • self.last_current_node:保存当前的current_node,以便后悔操作或其他模式下的恢复。
  • temp_last_current_node_action2:用于存储执行操作模式2的智能体的节点,以便后续更新。

9.5. 更新已选择节点列表

self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

功能:

  • 将当前选择的节点(selected)添加到已选择节点列表self.selected_node_list中。

9.6. 更新负载(load)

self.at_the_depot = (selected == 0)
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
_3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
demand_list = torch.cat((demand_list, _3), dim=2)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
self.last_load = self.load.clone()
self.load -= selected_demand
self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
self.load[self.at_the_depot] = 1  # refill loaded at the depot

功能:

  • self.at_the_depot:判断是否选择了配送中心(selected == 0)。
  • demand_list:扩展depot_node_demand以便按批次、POMO索引进行操作。
  • 对于后悔操作的智能体(action1_index),负载被恢复为之前的值。
  • 通过gather方法获取当前选择节点的需求量,并更新self.load
  • 如果当前节点是配送中心,负载重置为1,表示重新加载。

9. 7. 更新访问标记(visited_ninf_flag)

self.visited_ninf_flag[:, :, self.problem_size + 1][self.last_is_depot] = 0
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
self.visited_ninf_flag[:, :, self.problem_size + 1][self.at_the_depot] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

功能:

  • self.visited_ninf_flag:更新节点的访问状态:
    • 对于执行后悔操作的智能体,恢复其节点访问状态。
    • 对于操作模式2的智能体,将其上一步的节点标记为未访问。
    • 对于操作模式3的智能体,确保结束时访问标记正确。
    • 如果节点为配送中心,设置为已访问。

9.8. 更新负无穷掩码(ninf_mask)

self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
self.ninf_mask[demand_too_large] = float('-inf')

功能:

  • 更新ninf_mask,根据当前负载和节点需求量来掩盖不满足条件的节点(负载不足以承载需求的节点被标记为-inf)。

9.9. 更新完成状态(finished)

newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size + 1].all(dim=2)
self.finished = self.finished + newly_finished

功能:

  • 检查是否所有节点都已被访问,对于完成的智能体,更新self.finished标志。

9.10. 更新模式(mode)

self.mode[action1_bool_index] = 1
self.mode[action2_bool_index] = 2
self.mode[action3_bool_index] = 0
self.mode[self.finished] = 4

功能:

  • 根据不同的动作模式更新self.mode,包括后悔操作、执行操作1、执行操作2和完成操作。

9.11. 更新完成后的掩码调整

self.ninf_mask[:, :, 0][self.finished] = 0
self.ninf_mask[:, :, self.problem_size + 1][self.finished] = float('-inf')

功能:

  • 如果智能体完成任务,调整掩码,确保完成的节点不会被选择。

10. 完成状态与奖励

done = self.finished.all()
if done:
    reward = -self._get_travel_distance()  # note the minus sign!
else:
    reward = None

功能:

  • 检查所有智能体是否已完成任务(即访问所有节点)。
  • 如果所有任务完成,则计算奖励(这里通过负的旅行距离来表示,越短的路径越好)。
  • 如果未完成任务,则奖励为None

11. 返回值

return self.step_state, reward, done

功能:

  • 返回当前步骤状态、奖励和完成标志。

附录

函数代码

    def step(self, selected):
        # selected.shape: (batch, pomo)

        #时间步数控制
        if self.time_step<4:

            # 控制时间步的递增
            self.time_step=self.time_step+1
            self.selectex_count = self.selected_count+1

            #判断是否在配送中心
            self.at_the_depot = (selected == 0)

            #特定时间步的操作
            if self.time_step==3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
            
            #更新当前节点和已选择节点列表
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            #更新需求和负载
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记(防止重复选择已访问的节点)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            #更新负无穷掩码(屏蔽需求量超过当前负载的节点)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            #更新步骤状态,将更新后的状态同步到 self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask


        #时间步大于等于 4 的复杂操作
        else:
            #动作模式分类
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)

            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            #更新选择计数
            self.selected_count = self.selected_count+1
            #后悔模式
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            #节点更新
            self.last_is_depot = (self.last_current_node == 0)

            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            #更新已选择节点列表
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            #更新负载
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            #扩展需求列表 demand_list 
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load= self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0


            # 更新负无穷掩码
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+1)
            self.ninf_mask[demand_too_large] = float('-inf')

            # 更新完成状态
            # 检查哪些智能体已经完成所有节点的访问。
            # 更新完成标记 self.finished。
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            #更新模式
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # 更新完成后的掩码调整
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # 更新步骤状态
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask



        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done


http://www.kler.cn/a/567891.html

相关文章:

  • 文件描述符与重定向
  • ES批量查询
  • 大模型训练——pycharm连接实验室服务器
  • Python中文自然语言处理库SnowNLP
  • 多通道数据采集和信号生成的模块化仪器如何重构飞机电子可靠性测试体系?
  • 数据结构之各类排序算法代码及其详解
  • 判断按键盘是否好使的开机自启动PowerShell脚本
  • 【MATLAB例程】三维下的IMM(交互式多模型),模型使用CV(匀速)和CA(匀加速)
  • UWB人员定位:精准、高效、安全的智能管理解决方案
  • 使用3090显卡部署Wan2.1生成视频
  • 基于ai技术的视频生成工具
  • Java——String
  • 计算机网络之传输层(传输层提供的服务)
  • DeepSeek 开源狂欢周(五)正式收官|3FS并行文件系统榨干SSD
  • 【漫话机器学习系列】111.指数之和的对数(Log-Sum-Exp)
  • Flink同步数据mysql到doris问题合集
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_init_cycle 函数 - 详解(5)
  • vue3-print-nb的使用,点击回调
  • 《深度揭秘:生成对抗网络如何重塑遥感图像分析精度》
  • PHP的学习