20250301_代码笔记_函数class CVRPEnv: def step(self, selected)
文章目录
- 前言
- 1. 时间步数控制(time_step < 4)
- 2. 时间步数3的特定操作
- 3. 时间步数4的特定操作
- 4. 更新当前节点和已选择节点列表
- 5. 更新负载
- 6. 更新访问标记
- 7. 更新负无穷掩码
- 8. 更新步骤状态
- 9. 时间步大于等于 4 的复杂操作
- 9.1. 动作模式分类(action0, action1, action2, action3)
- 9.2. 动作模式的索引提取
- 9.3. 更新选择计数
- 9.4. 节点更新
- 9.5. 更新已选择节点列表
- 9.6. 更新负载(load)
- 9. 7. 更新访问标记(visited_ninf_flag)
- 9.8. 更新负无穷掩码(ninf_mask)
- 9.9. 更新完成状态(finished)
- 9.10. 更新模式(mode)
- 9.11. 更新完成后的掩码调整
- 10. 完成状态与奖励
- 11. 返回值
- 附录
- 函数代码
前言
细读函数class CVRPEnv: def step(self, selected)
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py
注:selected
为step()
函数的输入参数。
1. 时间步数控制(time_step < 4)
if self.time_step < 4:
self.time_step = self.time_step + 1
self.selected_count = self.selected_count + 1
self.at_the_depot = (selected == 0)
功能:
- 判断当前的
time_step
是否小于4。如果是,则执行该时间段的逻辑。self.time_step
增加1,表示进入下一个时间步。self.selected_count
表示已经选择的节点数量,递增。self.at_the_depot
记录当前的节点是否为配送中心(depot
)。如果selected == 0
(即选择了配送中心),则设置为True
。
2. 时间步数3的特定操作
if self.time_step == 3:
self.last_current_node = self.current_node.clone()
self.last_load = self.load.clone()
功能:
- 当
time_step == 3
时,记录当前节点和负载的状态,用于后续的操作。self.last_current_node
和self.last_load
保存当前状态,方便后续在步骤4时进行比较或更新。
3. 时间步数4的特定操作
if self.time_step == 4:
self.last_current_node = self.current_node.clone()
self.last_load = self.load.clone()
self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0
功能:
- 当
time_step == 4
时,保存当前节点和负载的状态。- 进一步处理
visited_ninf_flag
,标记已访问的节点,特别是那些不在配送中心并且之前没有访问过的节点。
- 进一步处理
4. 更新当前节点和已选择节点列表
self.current_node = selected
self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
功能:
将selected
(当前选择的节点)设置为self.current_node
。
更新self.selected_node_list
,将当前选择的节点添加到已选择的节点列表中。
5. 更新负载
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
self.load -= selected_demand
self.load[self.at_the_depot] = 1 # refill loaded at the depot
功能:
- 获取当前选择节点的需求量,并更新
self.load
(负载):demand_list
:扩展depot_node_demand
以便按批次、POMO索引进行操作。gathering_index
:用于从demand_list
中获取对应于selected
的需求量。self.load -= selected_demand
:根据选择的节点减少当前负载。- 如果当前位置在配送中心(
at_the_depot
),则将负载重置为1,表示负载已满。
6. 更新访问标记
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0 # depot is considered unvisited, unless you are AT the depot
功能:
- 更新
visited_ninf_flag
以防止重复访问已选择的节点:- 对于已选择的节点,标记为负无穷(表示已访问)。
- 对于配送中心(
at_the_depot
),设置为未访问(0
),除非当前就在配送中心。
7. 更新负无穷掩码
self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
_2 = torch.full((demand_too_large.shape[0], demand_too_large.shape[1], 1), False)
demand_too_large = torch.cat((demand_too_large, _2), dim=2)
self.ninf_mask[demand_too_large] = float('-inf')
功能:
- 更新
ninf_mask
,用于掩盖负载不足以承载需求量的节点:demand_too_large
表示负载不足以承载节点的需求,如果某个节点的需求量大于当前负载,则将其标记为True
。- 使用
self.ninf_mask
将这些节点掩盖为负无穷(-inf
)。
8. 更新步骤状态
self.step_state.selected_count = self.time_step
self.step_state.load = self.load
self.step_state.current_node = self.current_node
self.step_state.ninf_mask = self.ninf_mask
功能:
- 更新
self.step_state
,将当前步骤的状态(如selected_count
、load、current_node
和ninf_mask
)同步到步骤状态中。
9. 时间步大于等于 4 的复杂操作
9.1. 动作模式分类(action0, action1, action2, action3)
action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1)) # regret
action2_bool_index = self.mode == 1
action3_bool_index = self.mode == 2
功能:
action0_bool_index
:表示选择了正常节点且当前模式为0
的情况。selected != self.problem_size + 1
是用来排除选择了特殊的节点(即problem_size + 1
,通常是用于后悔机制或终止标记)。action1_bool_index
:表示当前模式为0
并且选择了后悔节点(selected == self.problem_size + 1
)。在这种情况下,智能体执行后悔操作(Regret
)。action2_bool_index
:表示当前模式为1
,意味着智能体正在执行某种特定的操作。action3_bool_index
:表示当前模式为2
,意味着智能体正在执行另一种特定操作。
9.2. 动作模式的索引提取
action1_index = torch.nonzero(action1_bool_index)
action2_index = torch.nonzero(action2_bool_index)
action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))
功能:
action1_index
:获取所有满足action1_bool_index
的索引,即执行后悔操作的智能体。action2_index
:获取所有满足action2_bool_index
的索引,执行操作模式1的智能体。action4_index
:获取所有满足action3_bool_index
并且当前节点不为0
(即不在配送中心)的智能体。
9.3. 更新选择计数
self.selected_count = self.selected_count + 1
self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2
功能:
self.selected_count
增加1
,表示当前有一个新的节点被选择。- 对于执行后悔操作的智能体(
action1_bool_index
),减少选择计数2
,因为它们撤回了之前的选择。
9.4. 节点更新
self.last_is_depot = (self.last_current_node == 0)
_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
self.last_current_node = self.current_node.clone()
self.current_node = selected.clone()
self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()
功能:
self.last_is_depot
:检查之前的节点是否为配送中心(0
)。- 对于执行后悔操作的智能体(
action1_index
),保存它们上一步的current_node
,然后将self.current_node
更新为selected
(当前选择的节点),并恢复后悔节点的选择。 self.last_current_node
:保存当前的current_node
,以便后悔操作或其他模式下的恢复。temp_last_current_node_action2
:用于存储执行操作模式2
的智能体的节点,以便后续更新。
9.5. 更新已选择节点列表
self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)
功能:
- 将当前选择的节点(
selected
)添加到已选择节点列表self.selected_node_list
中。
9.6. 更新负载(load)
self.at_the_depot = (selected == 0)
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
_3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
demand_list = torch.cat((demand_list, _3), dim=2)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
self.last_load = self.load.clone()
self.load -= selected_demand
self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
self.load[self.at_the_depot] = 1 # refill loaded at the depot
功能:
self.at_the_depot
:判断是否选择了配送中心(selected == 0
)。demand_list
:扩展depot_node_demand
以便按批次、POMO索引进行操作。- 对于后悔操作的智能体(
action1_index
),负载被恢复为之前的值。 - 通过
gather
方法获取当前选择节点的需求量,并更新self.load
。 - 如果当前节点是配送中心,负载重置为
1
,表示重新加载。
9. 7. 更新访问标记(visited_ninf_flag)
self.visited_ninf_flag[:, :, self.problem_size + 1][self.last_is_depot] = 0
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
self.visited_ninf_flag[:, :, self.problem_size + 1][self.at_the_depot] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0
功能:
self.visited_ninf_flag
:更新节点的访问状态:- 对于执行后悔操作的智能体,恢复其节点访问状态。
- 对于操作模式
2
的智能体,将其上一步的节点标记为未访问。 - 对于操作模式
3
的智能体,确保结束时访问标记正确。 - 如果节点为配送中心,设置为已访问。
9.8. 更新负无穷掩码(ninf_mask)
self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
self.ninf_mask[demand_too_large] = float('-inf')
功能:
- 更新
ninf_mask
,根据当前负载和节点需求量来掩盖不满足条件的节点(负载不足以承载需求的节点被标记为-inf
)。
9.9. 更新完成状态(finished)
newly_finished = (self.visited_ninf_flag == float('-inf'))[:, :, :self.problem_size + 1].all(dim=2)
self.finished = self.finished + newly_finished
功能:
- 检查是否所有节点都已被访问,对于完成的智能体,更新
self.finished
标志。
9.10. 更新模式(mode)
self.mode[action1_bool_index] = 1
self.mode[action2_bool_index] = 2
self.mode[action3_bool_index] = 0
self.mode[self.finished] = 4
功能:
- 根据不同的动作模式更新
self.mode
,包括后悔操作、执行操作1、执行操作2和完成操作。
9.11. 更新完成后的掩码调整
self.ninf_mask[:, :, 0][self.finished] = 0
self.ninf_mask[:, :, self.problem_size + 1][self.finished] = float('-inf')
功能:
- 如果智能体完成任务,调整掩码,确保完成的节点不会被选择。
10. 完成状态与奖励
done = self.finished.all()
if done:
reward = -self._get_travel_distance() # note the minus sign!
else:
reward = None
功能:
- 检查所有智能体是否已完成任务(即访问所有节点)。
- 如果所有任务完成,则计算奖励(这里通过负的旅行距离来表示,越短的路径越好)。
- 如果未完成任务,则奖励为
None
。
11. 返回值
return self.step_state, reward, done
功能:
- 返回当前步骤状态、奖励和完成标志。
附录
函数代码
def step(self, selected):
# selected.shape: (batch, pomo)
#时间步数控制
if self.time_step<4:
# 控制时间步的递增
self.time_step=self.time_step+1
self.selectex_count = self.selected_count+1
#判断是否在配送中心
self.at_the_depot = (selected == 0)
#特定时间步的操作
if self.time_step==3:
self.last_current_node = self.current_node.clone()
self.last_load = self.load.clone()
if self.time_step == 4:
self.last_current_node = self.current_node.clone()
self.last_load = self.load.clone()
self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
#更新当前节点和已选择节点列表
self.current_node = selected
self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
#更新需求和负载
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
self.load -= selected_demand
self.load[self.at_the_depot] = 1 # refill loaded at the depot
#更新访问标记(防止重复选择已访问的节点)
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0 # depot is considered unvisited, unless you are AT the depot
#更新负无穷掩码(屏蔽需求量超过当前负载的节点)
self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
_2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
demand_too_large = torch.cat((demand_too_large, _2), dim=2)
self.ninf_mask[demand_too_large] = float('-inf')
#更新步骤状态,将更新后的状态同步到 self.step_state
self.step_state.selected_count = self.time_step
self.step_state.load = self.load
self.step_state.current_node = self.current_node
self.step_state.ninf_mask = self.ninf_mask
#时间步大于等于 4 的复杂操作
else:
#动作模式分类
action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1)) # regret
action2_bool_index = self.mode == 1
action3_bool_index = self.mode == 2
action1_index = torch.nonzero(action1_bool_index)
action2_index = torch.nonzero(action2_bool_index)
action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))
#更新选择计数
self.selected_count = self.selected_count+1
#后悔模式
self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2
#节点更新
self.last_is_depot = (self.last_current_node == 0)
_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
self.last_current_node = self.current_node.clone()
self.current_node = selected.clone()
self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()
#更新已选择节点列表
self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)
#更新负载
self.at_the_depot = (selected == 0)
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
# shape: (batch, pomo, problem+1)
_3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
#扩展需求列表 demand_list
demand_list = torch.cat((demand_list, _3), dim=2)
gathering_index = selected[:, :, None]
# shape: (batch, pomo, 1)
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
self.last_load= self.load.clone()
# shape: (batch, pomo)
self.load -= selected_demand
self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
self.load[self.at_the_depot] = 1 # refill loaded at the depot
#更新访问标记
self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0
# 更新负无穷掩码
self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
# shape: (batch, pomo, problem+1)
self.ninf_mask[demand_too_large] = float('-inf')
# 更新完成状态
# 检查哪些智能体已经完成所有节点的访问。
# 更新完成标记 self.finished。
newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
# shape: (batch, pomo)
self.finished = self.finished + newly_finished
# shape: (batch, pomo)
#更新模式
self.mode[action1_bool_index] = 1
self.mode[action2_bool_index] = 2
self.mode[action3_bool_index] = 0
self.mode[self.finished] = 4
# 更新完成后的掩码调整
self.ninf_mask[:, :, 0][self.finished] = 0
self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')
# 更新步骤状态
self.step_state.selected_count = self.time_step
self.step_state.load = self.load
self.step_state.current_node = self.current_node
self.step_state.ninf_mask = self.ninf_mask
# returning values
done = self.finished.all()
if done:
reward = -self._get_travel_distance() # note the minus sign!
else:
reward = None
return self.step_state, reward, done