
PyTorch Deep Learning in Practice (24): Hierarchical Reinforcement Learning (HRL)

I. Principles of Hierarchical Reinforcement Learning

1. The Core Idea of Hierarchical Learning

Hierarchical Reinforcement Learning (HRL) tackles complex, long-horizon tasks through temporal abstraction and task decomposition. The core idea is summarized in the comparison below:

Dimension          | Traditional RL                            | Hierarchical RL
-------------------|-------------------------------------------|-------------------------------------------------------------
Policy structure   | A single policy outputs actions directly  | A high-level policy selects options (Option)
Time scale         | Decisions are made at every single step   | The high-level policy decides over long time spans; low-level policies execute the individual steps
Typical scenarios  | Simple, short-horizon tasks               | Complex, long-horizon tasks (e.g., maze navigation, robot manipulation)
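To make the time-scale difference concrete, here is a minimal sketch (my own illustration, not part of the original code) of the two-level control loop. The names high_level_policy, low_level_policy, and termination_fn are hypothetical placeholders for the networks implemented later in this article:

# A minimal sketch of the two-level control loop (illustration only).
def hierarchical_episode(env, high_level_policy, low_level_policy, termination_fn,
                         max_option_steps=20):
    """Run one episode: the high level picks an option, the low level acts
    until the option terminates or the episode ends."""
    state, _ = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        option = high_level_policy(state)              # high level: choose an option
        for _ in range(max_option_steps):              # low level: act for several steps
            action = low_level_policy(state, option)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
            if done or termination_fn(state, option):  # hand control back to the high level
                break
    return total_reward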
2. The Option-Critic Algorithm Framework

Option-Critic is a representative HRL algorithm. Its core components include:

  • A policy over options (the high-level policy) that decides which option to run
  • An intra-option policy for each option that outputs primitive actions
  • A termination function for each option that decides when to hand control back to the high level
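For reference, these components appear in the original Option-Critic formulation (Bacon et al., 2017) roughly as follows; the simplified implementation below only approximates these quantities:

    Q_\Omega(s,\omega) = \sum_a \pi_\omega(a \mid s)\, Q_U(s,\omega,a)

    U(\omega, s') = \bigl(1 - \beta_\omega(s')\bigr)\, Q_\Omega(s',\omega) + \beta_\omega(s')\, V_\Omega(s')

    \nabla_\theta J \propto \mathbb{E}\bigl[\nabla_\theta \log \pi_\omega(a \mid s)\, Q_U(s,\omega,a)\bigr], \qquad \nabla_\vartheta J \propto -\,\mathbb{E}\bigl[\nabla_\vartheta \beta_\omega(s')\, A_\Omega(s',\omega)\bigr]

where \pi_\omega are the intra-option policies, \beta_\omega the termination functions, Q_\Omega the option-value function, and A_\Omega the advantage over options.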


II. Option-Critic Implementation Steps (Based on Gymnasium)

We will implement the Option-Critic algorithm on a multi-stage Meta-World robotic-arm task:

  1. Define the option set: three options, reach (approach the target), grasp (pick up the object), and move (transport it); a small naming sketch follows this list

  2. Build the policy networks: a high-level policy, per-option (intra-option) policies, and termination networks

  3. Train with hierarchical interaction: the high level selects options, and the low level executes multi-step actions

  4. Joint gradient updates: optimize the high-level and low-level policies together
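As a small readability aid (my own addition; the code below simply uses the raw integer indices 0–2), the option set from step 1 could be given symbolic names:

from enum import IntEnum

class Option(IntEnum):
    """Symbolic names for the three option indices used by the code below (hypothetical helper)."""
    REACH = 0   # approach the target object
    GRASP = 1   # close the gripper on the object
    MOVE = 2    # transport the object to the goal

# e.g. option_policies[Option.GRASP](state) instead of option_policies[1](state)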


III. Code Implementation

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical, Normal
import gymnasium as gym
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time

# ================== Configuration ==================
class OptionCriticConfig:
    num_options = 3                  # number of options (reach, grasp, move)
    option_length = 20               # maximum steps an option may run
    hidden_dim = 128                 # hidden layer width
    lr_high = 1e-4                   # learning rate of the high-level policy
    lr_option = 3e-4                 # learning rate of the option policies
    gamma = 0.99                     # discount factor
    entropy_weight = 0.01            # entropy regularization weight
    max_episodes = 5000              # maximum number of training episodes
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== High-Level Policy Network ==================
class HighLevelPolicy(nn.Module):
    def __init__(self, state_dim, num_options):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, num_options)
        )
    
    def forward(self, state):
        return self.net(state)

# ================== Intra-Option Policy Network ==================
class OptionPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, action_dim)
        )
    
    def forward(self, state):
        return self.net(state)

# ================== Termination Network ==================
class TerminationNetwork(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, OptionCriticConfig.hidden_dim),
            nn.ReLU(),
            nn.Linear(OptionCriticConfig.hidden_dim, 1),
            nn.Sigmoid()  # outputs the termination probability
        )
    
    def forward(self, state):
        return self.net(state)

# ================== Training System ==================
class OptionCriticTrainer:
    def __init__(self):
        # Initialize the environment
        self.env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['pick-place-v2-goal-observable']()
        # Handle the observation space (Dict vs. flat vector)
        if isinstance(self.env.observation_space, gym.spaces.Dict):
            self.state_dim = sum([self.env.observation_space.spaces[key].shape[0] for key in ['observation', 'desired_goal']])
            self.process_state = self._process_dict_state
        else:
            self.state_dim = self.env.observation_space.shape[0]
            self.process_state = lambda x: x
        self.action_dim = self.env.action_space.shape[0]
        
        # Initialize the networks
        self.high_policy = HighLevelPolicy(self.state_dim, OptionCriticConfig.num_options).to(OptionCriticConfig.device)
        self.option_policies = nn.ModuleList([
            OptionPolicy(self.state_dim, self.action_dim).to(OptionCriticConfig.device)
            for _ in range(OptionCriticConfig.num_options)
        ])
        self.termination_networks = nn.ModuleList([
            TerminationNetwork(self.state_dim).to(OptionCriticConfig.device)
            for _ in range(OptionCriticConfig.num_options)
        ])
        
        # Optimizers
        self.optimizer_high = optim.Adam(self.high_policy.parameters(), lr=OptionCriticConfig.lr_high)
        self.optimizer_option = optim.Adam(
            list(self.option_policies.parameters()) + list(self.termination_networks.parameters()),
            lr=OptionCriticConfig.lr_option
        )
    
    def _process_dict_state(self, state_dict):
        return np.concatenate([state_dict['observation'], state_dict['desired_goal']])
    
    def select_option(self, state):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        logits = self.high_policy(state)
        dist = Categorical(logits=logits)
        option = dist.sample()
        return option.item(), dist.log_prob(option)
    
    def select_action(self, state, option):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        action_mean = self.option_policies[option](state)
        dist = Normal(action_mean, torch.ones_like(action_mean))  # continuous actions, unit-variance Gaussian
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)  # sum over action dimensions to get a scalar log-probability
        # Safeguard: clip the sampled action to the environment's action bounds
        action = np.clip(action.cpu().numpy(), self.env.action_space.low, self.env.action_space.high)
        return action, log_prob  # return the clipped action and its scalar log-probability
    
    def should_terminate(self, state, current_option):
        state = torch.FloatTensor(state).to(OptionCriticConfig.device)
        terminate_prob = self.termination_networks[current_option](state)
        return torch.bernoulli(terminate_prob).item() == 1
    
    def train(self):
        for episode in range(OptionCriticConfig.max_episodes):
            state_dict, _ = self.env.reset()
            state = self.process_state(state_dict)
            episode_reward = 0
            current_option, log_prob_high = self.select_option(state)
            option_step = 0
            
            while True:
                # Execute the current option's intra-option policy
                action, log_prob_option = self.select_action(state, current_option)
                next_state_dict, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                next_state = self.process_state(next_state_dict)
                episode_reward += reward
                
                # Decide whether the current option should terminate
                terminate = self.should_terminate(next_state, current_option) or (option_step >= OptionCriticConfig.option_length)
                
                # Compute gradients and update when the option ends
                if terminate or done:
                    # Estimate the next state's value under no_grad so no gradient flows back;
                    # the max over the high-level outputs serves as the value estimate
                    with torch.no_grad():
                        next_value = self.high_policy(torch.FloatTensor(next_state).to(OptionCriticConfig.device)).max().item()
                    termination_output = self.termination_networks[current_option](torch.FloatTensor(state).to(OptionCriticConfig.device))
                    
                    # Detach the termination output so no gradient flows into the termination network via delta
                    delta = reward + OptionCriticConfig.gamma * next_value - termination_output.detach()

                    # High-level policy gradient step
                    loss_high = -log_prob_high * delta
                    self.optimizer_high.zero_grad()
                    loss_high.backward(retain_graph=True)  # retain the computation graph
                    self.optimizer_high.step()

                    # Intra-option policy gradient step
                    loss_option = -log_prob_option * delta
                    # Entropy-style regularization term: -(log p) * p evaluated at the sampled action
                    entropy = -log_prob_option * torch.exp(log_prob_option.detach())
                    loss_option_total = loss_option + OptionCriticConfig.entropy_weight * entropy
                    self.optimizer_option.zero_grad()
                    loss_option_total.backward()  # the retained graph is still accessible here
                    self.optimizer_option.step()
                    
                    # Re-select an option (or end the episode)
                    if not done:
                        state = next_state  # advance the state before starting the next option
                        current_option, log_prob_high = self.select_option(next_state)
                        option_step = 0
                    else:
                        break
                else:
                    option_step += 1
                    state = next_state
            
            if (episode + 1) % 100 == 0:
                print(f"Episode {episode+1} | Reward: {episode_reward:.1f}")
​
if __name__ == "__main__":
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    print(f"开始时间: {start_str}")
    print("初始化环境...")
    trainer = OptionCriticTrainer()
    trainer.train()
    end = time.time()
    end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    print(f"训练完成时间: {end_str}")
    print(f"训练完成,耗时: {end - start:.2f}秒")

IV. Key Code Analysis

  1. High-level option selection

    select_option: selects an option from the current state and returns the option ID together with the log-probability of that choice.
  2. Intra-option policy execution

    select_action: generates an action for the current option; continuous action spaces are handled with a Gaussian distribution.
  3. Termination check

    should_terminate: samples from the termination network's output probability to decide whether the current option should end.
  4. Gradient update logic

    High-level policy: updated from the option's value difference (a TD-error-style signal).
    Option policies: updated from the same TD error plus an entropy regularization term.
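Written out, the updates performed by the code in Section III are (a simplified stand-in for the full Option-Critic gradients quoted in Section I; here the detached termination probability plays the role of the value baseline):

    \delta = r + \gamma \max_{\omega'} Q_{\text{high}}(s',\omega') - \beta_\omega(s), \qquad L_{\text{high}} = -\log \pi_\Omega(\omega \mid s)\,\delta, \qquad L_{\text{option}} = -\log \pi_\omega(a \mid s)\,\delta + \lambda_{\text{ent}}\,\tilde{H}

where Q_{\text{high}} denotes the high-level network's per-option outputs (its maximum is used as the next-state value estimate), \lambda_{\text{ent}} is entropy_weight, and \tilde{H} is the single-sample entropy term computed from log_prob_option.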

V. Sample Training Output

Start time: 2025-03-24 08:29:46
Initializing environment...
Episode 100 | Reward: 2.7
Episode 200 | Reward: 4.9
Episode 300 | Reward: 2.2
Episode 400 | Reward: 2.8
Episode 500 | Reward: 3.0
Episode 600 | Reward: 3.3
Episode 700 | Reward: 3.2
Episode 800 | Reward: 4.7
Episode 900 | Reward: 5.3
Episode 1000 | Reward: 7.5
Episode 1100 | Reward: 6.3
Episode 1200 | Reward: 3.7
Episode 1300 | Reward: 7.8
Episode 1400 | Reward: 3.8
Episode 1500 | Reward: 2.4
Episode 1600 | Reward: 2.3
Episode 1700 | Reward: 2.5
Episode 1800 | Reward: 2.7
Episode 1900 | Reward: 2.7
Episode 2000 | Reward: 3.9
Episode 2100 | Reward: 4.5
Episode 2200 | Reward: 4.1
Episode 2300 | Reward: 4.7
Episode 2400 | Reward: 4.0
Episode 2500 | Reward: 4.3
Episode 2600 | Reward: 3.8
Episode 2700 | Reward: 3.3
Episode 2800 | Reward: 4.6
Episode 2900 | Reward: 5.2
Episode 3000 | Reward: 7.7
Episode 3100 | Reward: 7.8
Episode 3200 | Reward: 3.3
Episode 3300 | Reward: 5.3
Episode 3400 | Reward: 4.5
Episode 3500 | Reward: 3.9
Episode 3600 | Reward: 4.1
Episode 3700 | Reward: 4.0
Episode 3800 | Reward: 5.2
Episode 3900 | Reward: 8.2
Episode 4000 | Reward: 2.2
Episode 4100 | Reward: 2.2
Episode 4200 | Reward: 2.2
Episode 4300 | Reward: 2.2
Episode 4400 | Reward: 6.9
Episode 4500 | Reward: 5.6
Episode 4600 | Reward: 2.0
Episode 4700 | Reward: 1.6
Episode 4800 | Reward: 1.7
Episode 4900 | Reward: 1.9
Episode 5000 | Reward: 3.1
Training finished at: 2025-03-24 12:41:48
Training complete, elapsed: 15122.31 s

In the next article, we will explore Inverse Reinforcement Learning (Inverse RL) and implement the GAIL algorithm!


Notes

  1. Install the dependencies:

    pip install metaworld gymnasium torch
  2. If your Meta-World setup relies on mujoco-py, it needs a local MuJoCo installation; point mujoco-py at it with:

    export MUJOCO_PY_MUJOCO_PATH=/path/to/mujoco
  3. Training takes a long time (GPU acceleration recommended):

    CUDA_VISIBLE_DEVICES=0 python option_critic.py

