PyTorch Deep Learning in Practice (24): Hierarchical Reinforcement Learning (HRL)
I. Principles of Hierarchical Reinforcement Learning
1. Core Idea of Hierarchical Learning
Hierarchical Reinforcement Learning (HRL) tackles complex, long-horizon tasks through temporal abstraction and task decomposition. The core idea is summarized in the comparison below:
| Dimension | Flat RL | Hierarchical RL |
|---|---|---|
| Policy structure | A single policy outputs actions directly | A high-level policy selects options (Options) |
| Time scale | Decisions at every single step | High-level decisions span long horizons; low-level policies execute the steps |
| Typical scenarios | Simple, short-horizon tasks | Complex, long-horizon tasks (e.g., maze navigation, robot manipulation) |
2. The Option-Critic Framework
Option-Critic is a representative HRL algorithm. Its core components are a high-level policy over options, a set of intra-option policies that output primitive actions, and per-option termination functions that decide when control returns to the high level.
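For background, the standard Option-Critic formulation (notation follows Bacon et al., 2017, and is added here as a reference rather than taken from this post's code) models each option as a triple $\omega = (I_\omega, \pi_\omega, \beta_\omega)$ and ties these components together through the option value $Q_\Omega$ and the value upon arrival $U$:

$$
Q_\Omega(s,\omega) = \sum_{a} \pi_\omega(a \mid s)\Big[ r(s,a) + \gamma \sum_{s'} P(s' \mid s,a)\, U(\omega, s') \Big]
$$

$$
U(\omega, s') = \bigl(1-\beta_\omega(s')\bigr)\, Q_\Omega(s', \omega) + \beta_\omega(s')\, V_\Omega(s')
$$

The high-level policy is trained against $Q_\Omega$, the intra-option policies are pushed toward high-value actions, and the termination functions learn to end an option when continuing it offers no advantage over switching.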
II. Option-Critic Implementation Steps (Based on Gymnasium)
We implement the Option-Critic algorithm on a multi-stage Meta-World robotic-arm task, in four steps:
- Define the option set: three options, `reach` (approach the target), `grasp` (grasp the object), and `move` (move to the goal).
- Build the policy networks: a high-level policy + intra-option policies + termination networks.
- Hierarchical interaction training: the high-level policy selects an option and the low-level policy executes multiple steps inside it (a minimal skeleton of this loop is sketched right after this list).
- Joint gradient updates: optimize the high-level and low-level policies together.
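Before the full implementation, the sketch below shows the hierarchical control loop from step 3 in isolation. It is a toy illustration only: `ToyEnv`, `select_option`, `select_action`, and `should_terminate` here are random stand-ins, not the tutorial's actual classes, which appear in the next section.

```python
import random

# Toy stand-ins so the skeleton runs on its own; the real code below uses
# neural networks and a Meta-World environment instead.
class ToyEnv:
    def reset(self):
        return 0, {}
    def step(self, action):
        s = random.randint(0, 9)
        done = s == 9
        return s, float(done), done, False, {}

NUM_OPTIONS, OPTION_MAX_STEPS = 3, 20
env = ToyEnv()

def select_option(state):              # high level: pick an option (random here)
    return random.randrange(NUM_OPTIONS)

def select_action(state, option):      # low level: pick a primitive action (random here)
    return random.random()

def should_terminate(state, option):   # termination head: end the option early?
    return random.random() < 0.1

state, _ = env.reset()
done = False
while not done:
    option = select_option(state)                      # 1. high-level decision
    for _ in range(OPTION_MAX_STEPS):                  # 2. low-level execution
        action = select_action(state, option)
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        if done or should_terminate(state, option):    # 3. hand control back
            break
```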
III. Code Implementation
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical, Normal
import gymnasium as gym
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
# ================== Configuration ==================
class OptionCriticConfig:
    num_options = 3        # number of options (reach, grasp, move)
    option_length = 20     # maximum steps an option may execute
    hidden_dim = 128       # hidden layer width
    lr_high = 1e-4         # learning rate of the high-level policy
    lr_option = 3e-4       # learning rate of the option policies
    gamma = 0.99           # discount factor
    entropy_weight = 0.01  # entropy regularization weight
    max_episodes = 5000    # maximum number of training episodes
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ================== High-Level Policy Network ==================
class HighLevelPolicy(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, num_options)
        )
    def forward(self, state):
        return self.net(state)
# ================== Intra-Option Policy Network ==================
class OptionPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, action_dim)
        )
    def forward(self, state):
        return self.net(state)
# ================== Termination Network ==================
class TerminationNetwork(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, 1),
            nn.Sigmoid()  # outputs the probability of terminating the option
        )
    def forward(self, state):
        return self.net(state)
# ================== Training System ==================
class OptionCriticTrainer:
    def __init__(self):
        # initialize the environment
        self.env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['pick-place-v2-goal-observable']()
        # handle the observation space
        if isinstance(self.env.observation_space, gym.spaces.Dict):
            self.state_dim = sum([self.env.observation_space.spaces[key].shape[0] for key in ['observation', 'desired_goal']])
            self.process_state = self._process_dict_state
        else:
            self.state_dim = self.env.observation_space.shape[0]
            self.process_state = lambda x: x
        self.action_dim = self.env.action_space.shape[0]
        # initialize the networks
        self.high_policy = HighLevelPolicy(self.state_dim, OptionCriticConfig.num_options).to(OptionCriticConfig.device)
        self.option_policies = nn.ModuleList([
            OptionPolicy(self.state_dim, self.action_dim).to(OptionCriticConfig.device)
            for _ in range(OptionCriticConfig.num_options)
        ])
        self.termination_networks = nn.ModuleList([
            TerminationNetwork(self.state_dim).to(OptionCriticConfig.device)
            for _ in range(OptionCriticConfig.num_options)
        ])
        # optimizers
        self.optimizer_high = optim.Adam(self.high_policy.parameters(), lr=OptionCriticConfig.lr_high)
        self.optimizer_option = optim.Adam(
            list(self.option_policies.parameters()) + list(self.termination_networks.parameters()),
            lr=OptionCriticConfig.lr_option
        )
    def _process_dict_state(self, state_dict):
        return np.concatenate([state_dict['observation'], state_dict['desired_goal']])
    def select_option(self, state):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        logits = self.high_policy(state)
        dist = Categorical(logits=logits)
        option = dist.sample()
        return option.item(), dist.log_prob(option)
    def select_action(self, state, option):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        action_mean = self.option_policies[option](state)
        dist = Normal(action_mean, torch.ones_like(action_mean))  # assumes a continuous action space (fixed unit std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)  # sum over action dimensions to get a scalar log-prob
        return action.cpu().numpy(), log_prob  # return the action and its scalar log-probability
    def should_terminate(self, state, current_option):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        terminate_prob = self.termination_networks[current_option](state)
        return torch.bernoulli(terminate_prob).item() == 1
    def train(self):
        for episode in range(OptionCriticConfig.max_episodes):
            state_dict, _ = self.env.reset()
            state = self.process_state(state_dict)
            episode_reward = 0
            current_option, log_prob_high = self.select_option(state)
            option_step = 0
            while True:
                # execute the intra-option policy
                action, log_prob_option = self.select_action(state, current_option)
                next_state_dict, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                next_state = self.process_state(next_state_dict)
                episode_reward += reward
                # decide whether the current option should terminate
                terminate = self.should_terminate(next_state, current_option) or (option_step >= OptionCriticConfig.option_length)
                # compute gradients and update
                if terminate or done:
                    # value estimates (computed without gradients)
                    with torch.no_grad():
                        next_value = self.high_policy(torch.FloatTensor(next_state).to(OptionCriticConfig.device)).max().item()
                        termination_output = self.termination_networks[current_option](torch.FloatTensor(state).to(OptionCriticConfig.device))
                    # detach the termination output when computing the TD error
                    delta = reward + OptionCriticConfig.gamma * next_value - termination_output.detach()
                    # high-level policy gradient step
                    loss_high = -log_prob_high * delta
                    self.optimizer_high.zero_grad()
                    loss_high.backward(retain_graph=True)  # keep the graph for the option-policy backward pass
                    self.optimizer_high.step()
                    # option-policy gradient step
                    loss_option = -log_prob_option * delta
                    entropy = -log_prob_option * torch.exp(log_prob_option.detach())  # sample-based entropy term
                    loss_option_total = loss_option + OptionCriticConfig.entropy_weight * entropy
                    self.optimizer_option.zero_grad()
                    loss_option_total.backward()  # the retained graph is still available here
                    self.optimizer_option.step()
                    # switch to a new option, or stop if the episode is over
                    if not done:
                        current_option, log_prob_high = self.select_option(next_state)
                        option_step = 0
                    else:
                        break
                else:
                    option_step += 1
                state = next_state
            if (episode + 1) % 100 == 0:
                print(f"Episode {episode+1} | Reward: {episode_reward:.1f}")
if __name__ == "__main__":
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    print(f"Start time: {start_str}")
    print("Initializing environment...")
    trainer = OptionCriticTrainer()
    trainer.train()
    end = time.time()
    end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    print(f"Training finished at: {end_str}")
    print(f"Training complete, elapsed: {end - start:.2f} s")
IV. Key Code Walkthrough
- High-level option selection (`select_option`): picks an option from the current state and returns the option ID together with the log-probability of that choice.
- Intra-option policy execution (`select_action`): generates an action under the current option; continuous action spaces are supported via a Gaussian distribution.
- Termination check (`should_terminate`): uses the termination network's output probability to decide whether to end the current option.
- Gradient update logic:
  - High-level policy: updated from the option value difference (TD error).
  - Option policies: updated from the TD error combined with entropy regularization.

A minimal sanity check that exercises these building blocks on a dummy state is sketched below.
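The sketch is useful for catching shape mismatches before committing to a long run. It rests on assumptions: the dimensions 39 and 4 are guesses for the Meta-World pick-place-v2 observation and action sizes (verify against your environment), and it reuses the `HighLevelPolicy`, `OptionPolicy`, and `TerminationNetwork` classes defined in the code above, so those must already be in scope.

```python
import torch
from torch.distributions import Categorical, Normal

# Assumed dimensions for Meta-World pick-place-v2 (check against your env).
STATE_DIM, ACTION_DIM, NUM_OPTIONS = 39, 4, 3

high = HighLevelPolicy(STATE_DIM, NUM_OPTIONS)
options = [OptionPolicy(STATE_DIM, ACTION_DIM) for _ in range(NUM_OPTIONS)]
terms = [TerminationNetwork(STATE_DIM) for _ in range(NUM_OPTIONS)]

state = torch.randn(STATE_DIM)                  # dummy observation

opt_dist = Categorical(logits=high(state))      # high level: pick an option
option = opt_dist.sample().item()

mean = options[option](state)                   # low level: sample an action
action = Normal(mean, torch.ones_like(mean)).sample()

beta = terms[option](state)                     # termination probability
print(action.shape, float(beta))                # expect torch.Size([4]) and a value in (0, 1)
```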
V. Sample Training Output
Start time: 2025-03-24 08:29:46
Initializing environment...
Episode 100 | Reward: 2.7
Episode 200 | Reward: 4.9
Episode 300 | Reward: 2.2
Episode 400 | Reward: 2.8
Episode 500 | Reward: 3.0
Episode 600 | Reward: 3.3
Episode 700 | Reward: 3.2
Episode 800 | Reward: 4.7
Episode 900 | Reward: 5.3
Episode 1000 | Reward: 7.5
Episode 1100 | Reward: 6.3
Episode 1200 | Reward: 3.7
Episode 1300 | Reward: 7.8
Episode 1400 | Reward: 3.8
Episode 1500 | Reward: 2.4
Episode 1600 | Reward: 2.3
Episode 1700 | Reward: 2.5
Episode 1800 | Reward: 2.7
Episode 1900 | Reward: 2.7
Episode 2000 | Reward: 3.9
Episode 2100 | Reward: 4.5
Episode 2200 | Reward: 4.1
Episode 2300 | Reward: 4.7
Episode 2400 | Reward: 4.0
Episode 2500 | Reward: 4.3
Episode 2600 | Reward: 3.8
Episode 2700 | Reward: 3.3
Episode 2800 | Reward: 4.6
Episode 2900 | Reward: 5.2
Episode 3000 | Reward: 7.7
Episode 3100 | Reward: 7.8
Episode 3200 | Reward: 3.3
Episode 3300 | Reward: 5.3
Episode 3400 | Reward: 4.5
Episode 3500 | Reward: 3.9
Episode 3600 | Reward: 4.1
Episode 3700 | Reward: 4.0
Episode 3800 | Reward: 5.2
Episode 3900 | Reward: 8.2
Episode 4000 | Reward: 2.2
Episode 4100 | Reward: 2.2
Episode 4200 | Reward: 2.2
Episode 4300 | Reward: 2.2
Episode 4400 | Reward: 6.9
Episode 4500 | Reward: 5.6
Episode 4600 | Reward: 2.0
Episode 4700 | Reward: 1.6
Episode 4800 | Reward: 1.7
Episode 4900 | Reward: 1.9
Episode 5000 | Reward: 3.1
Training finished at: 2025-03-24 12:41:48
Training complete, elapsed: 15122.31 s
In the next article, we will explore Inverse Reinforcement Learning (Inverse RL) and implement the GAIL algorithm!
Notes
- Install the dependencies:
  pip install metaworld gymnasium torch
- Meta-World relies on MuJoCo; if your setup uses mujoco-py, point it at your MuJoCo installation:
  export MUJOCO_PY_MUJOCO_PATH=/path/to/mujoco
- Training takes a long time (GPU acceleration recommended):
  CUDA_VISIBLE_DEVICES=0 python option_critic.py
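Before committing to a multi-hour run, a quick check (not part of the original post) that PyTorch can actually see the GPU:

```python
import torch

# Prints True and the device name if CUDA is available, otherwise reports CPU only.
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")
```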