当前位置: 首页 > article >正文

PyTorch 深度学习实战(19):离线强化学习与 Conservative Q-Learning (CQL) 算法

在上一篇文章中,我们探讨了分布式强化学习与 IMPALA 算法,展示了如何通过并行化训练提升强化学习的效率。本文将聚焦 离线强化学习(Offline RL) 这一新兴方向,并实现 Conservative Q-Learning (CQL) 算法,利用 Minari 提供的静态数据集训练安全的强化学习策略。


一、离线强化学习与 CQL 原理

1. 离线强化学习的特点
  • 无需环境交互:直接从预收集的静态数据集学习

  • 数据效率高:复用历史经验(如人类演示、日志数据)

  • 安全风险低:避免在线探索中的危险行为

2. CQL 核心思想

CQL 通过保守策略评估防止价值函数高估,其目标函数为:

3. 算法优势
  • 防止分布偏移导致的策略退化

  • 支持混合质量数据集(专家数据 + 随机数据)

  • 适用于真实世界场景(如医疗、金融)


二、CQL 实现步骤(基于 Minari 数据集)

我们将使用 Minari 库中的 D4RL/door/human-v2 数据集训练策略:

  1. 安装 Minari 并加载数据集

  2. 定义保守 Q 网络

  3. 实现保守正则化损失

  4. 策略优化与评估


三、代码实现

以下是 CQL 算法的完整实现代码:

import torch
import minari
import numpy as np
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from collections import deque
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
​
# 1. 增强型配置类(带维度校验)
class SafeConfig:
    # 训练参数
    batch_size = 1024
    lr = 3e-5
    tau = 0.007
    gamma = 0.99
    total_epochs = 500
    
    # 网络架构
    hidden_dim = 768
    num_layers = 3
    dropout_rate = 0.1
    activation_fn = 'Mish'  # 支持Mish/SiLU/ReLU
    
    # 正则化参数
    conservative_init = 2.5
    conservative_decay = 0.995
    min_conservative = 0.3
    reward_scale = 4.0
    
    # 探索参数
    noise_scale = 0.2
    noise_clip = 0.5
    candidate_samples = 400
    imitation_ratio = 0.15
​
# 2. 安全数据加载系统
class SafeDataset(Dataset):
    def __init__(self, dataset_name):
        # 加载原始数据
        dataset = minari.load_dataset(dataset_name, download=True)
        
        # 获取维度信息
        first_ep = dataset[0]
        self.state_dim = first_ep.observations[0].shape[0]
        self.action_dim = first_ep.actions[0].shape[0]
        
        # 数据存储
        self.obs, self.acts, self.rews, self.dones, self.next_obs = [], [], [], [], []
        for ep in dataset:
            self._store_episode(
                ep.observations[:-1],
                ep.actions,
                ep.rewards,
                np.logical_or(ep.terminations, ep.truncations),
                ep.observations[1:]
            )
        
        # 标准化
        self._normalize()
        self.priorities = np.ones(len(self.obs)) * 1e-5
    
    def _store_episode(self, obs, acts, rews, dones, next_obs):
        self.obs.extend(obs)
        self.acts.extend(acts)
        self.rews.extend(rews)
        self.dones.extend(dones)
        self.next_obs.extend(next_obs)
    
    def _normalize(self):
        # 状态标准化
        self.obs_mean = np.mean(self.obs, axis=0)
        self.obs_std = np.std(self.obs, axis=0) + 1e-8
        self.obs = (self.obs - self.obs_mean) / self.obs_std
        self.next_obs = (self.next_obs - self.obs_mean) / self.obs_std
        
        # 动作标准化
        self.act_mean = np.mean(self.acts, axis=0)
        self.act_std = np.std(self.acts, axis=0) + 1e-8
        self.acts = (self.acts - self.act_mean) / self.act_std
    
    def update_priorities(self, indices, priorities):
        self.priorities[indices] = np.abs(priorities.flatten()) + 1e-5
    
    def __len__(self):
        return len(self.obs)
    
    def __getitem__(self, idx):
        return (
            idx,
            torch.FloatTensor(self.obs[idx]),
            torch.FloatTensor(self.acts[idx]),
            torch.FloatTensor(self.next_obs[idx]),
            torch.FloatTensor([self.rews[idx]]),
            torch.FloatTensor([bool(self.dones[idx])])
        )
​
# 3. 维度安全网络架构
class SafeQNetwork(torch.nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.input_dim = state_dim + action_dim  # 关键动态计算
        
        # 主网络
        self.feature_net = self._build_network()
        self.q1 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        self.q2 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        
        # 目标网络
        self.target_net = self._build_network()
        self.target_q1 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        self.target_q2 = torch.nn.Linear(SafeConfig.hidden_dim, 1)
        
        # 初始化
        self._init_weights()
        self._update_target(1.0)
    
    def _build_network(self):
        layers = []
        input_dim = self.input_dim  # 使用动态计算值
        for _ in range(SafeConfig.num_layers):
            layers.extend([
                torch.nn.Linear(input_dim, SafeConfig.hidden_dim),
                torch.nn.LayerNorm(SafeConfig.hidden_dim),
                self._activation(),
                torch.nn.Dropout(SafeConfig.dropout_rate),
            ])
            input_dim = SafeConfig.hidden_dim
        return torch.nn.Sequential(*layers)
    
    def _activation(self):
        return {
            'Mish': torch.nn.Mish(),
            'SiLU': torch.nn.SiLU(),
            'ReLU': torch.nn.ReLU()
        }[SafeConfig.activation_fn]
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
                torch.nn.init.normal_(m.bias, 0, 0.1)
    
    def forward(self, state, action):
        # 维度校验
        assert state.shape[-1] == self.state_dim, f"State dim error: {state.shape[-1]} vs {self.state_dim}"
        assert action.shape[-1] == self.action_dim, f"Action dim error: {action.shape[-1]} vs {self.action_dim}"
        
        x = torch.cat([state, action], dim=1)
        features = self.feature_net(x)
        return self.q1(features), self.q2(features)
    
    def target_forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        features = self.target_net(x)
        return self.target_q1(features), self.target_q2(features)
    
    def _update_target(self, tau):
        with torch.no_grad():
            for t_param, param in zip(self.target_net.parameters(), self.feature_net.parameters()):
                t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
            for t_param, param in zip(self.target_q1.parameters(), self.q1.parameters()):
                t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
            for t_param, param in zip(self.target_q2.parameters(), self.q2.parameters()):
                t_param.data.copy_(tau * param.data + (1 - tau) * t_param.data)
​
# 4. 安全训练系统
class SafeTrainer:
    def __init__(self, dataset_name):
        # 数据系统
        self.dataset = SafeDataset(dataset_name)
        self.state_dim = self.dataset.state_dim
        self.action_dim = self.dataset.action_dim
        
        # 网络系统
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.q_net = SafeQNetwork(self.state_dim, self.action_dim).to(self.device)
        
        # 优化系统
        self.optimizer = torch.optim.AdamW(
            self.q_net.parameters(),
            lr=SafeConfig.lr,
            weight_decay=1e-3
        )
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=100,
            eta_min=1e-6
        )
        
        # 数据加载
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=SafeConfig.batch_size,
            sampler=WeightedRandomSampler(
                self.dataset.priorities,
                num_samples=len(self.dataset),
                replacement=True
            ),
            collate_fn=lambda b: {
                'indices': torch.LongTensor([x[0] for x in b]),
                'states': torch.stack([x[1] for x in b]),
                'actions': torch.stack([x[2] for x in b]),
                'next_states': torch.stack([x[3] for x in b]),
                'rewards': torch.stack([x[4] for x in b]),
                'dones': torch.stack([x[5] for x in b])
            },
            num_workers=4
        )
        
        # 训练状态
        self.conservative_weight = SafeConfig.conservative_init
        self.loss_history = deque(maxlen=100)
    
    def train_epoch(self, epoch):
        self.q_net.train()
        total_loss = 0.0
        
        for batch in self.dataloader:
            # 数据准备
            states = batch['states'].to(self.device)
            actions = batch['actions'].to(self.device)
            next_states = batch['next_states'].to(self.device)
            rewards = batch['rewards'].to(self.device) * SafeConfig.reward_scale
            dones = batch['dones'].to(self.device)
            
            # 目标Q值计算
            with torch.no_grad():
                # 带噪声的动作生成
                noise = torch.randn_like(actions) * SafeConfig.noise_scale
                noise = torch.clamp(noise, -SafeConfig.noise_clip, SafeConfig.noise_clip)
                noisy_actions = actions + noise
                
                # 双Q学习
                target_q1, target_q2 = self.q_net.target_forward(next_states, noisy_actions)
                target_q = torch.min(target_q1, target_q2).squeeze(-1)
                y = rewards.squeeze(-1) + (1 - dones.squeeze(-1)) * SafeConfig.gamma * target_q
            
            # 当前Q值预测
            current_q1, current_q2 = self.q_net(states, actions)
            current_q1 = current_q1.squeeze(-1).clamp(-10.0, 50.0)
            current_q2 = current_q2.squeeze(-1).clamp(-10.0, 50.0)
            
            # 损失计算
            bellman_loss = 0.5 * (
                torch.nn.functional.huber_loss(current_q1, y, delta=1.0) +
                torch.nn.functional.huber_loss(current_q2, y, delta=1.0)
            )
            
            # 保守正则项
            rand_acts = torch.randn_like(actions) * SafeConfig.noise_scale
            q1_rand, q2_rand = self.q_net(states, rand_acts)
            conservative_loss = (q1_rand + q2_rand).mean() - (current_q1 + current_q2).mean()
            
            # 总损失
            loss = bellman_loss + self.conservative_weight * conservative_loss
            
            # 反向传播
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 2.0)
            self.optimizer.step()
            
            # 更新目标网络
            self.q_net._update_target(SafeConfig.tau)
            
            # 更新优先级
            td_errors = (current_q1 - y).detach().cpu().numpy()
            self.dataset.update_priorities(batch['indices'].numpy(), td_errors)
            
            total_loss += loss.item()
        
        # 调整保守权重
        self.conservative_weight = max(
            self.conservative_weight * SafeConfig.conservative_decay,
            SafeConfig.min_conservative
        )
        
        # 学习率调度
        self.scheduler.step()
        
        return total_loss / len(self.dataloader)
    
    def get_action(self, state):
        self.q_net.eval()
        state_norm = (state - self.dataset.obs_mean) / self.dataset.obs_std
        state_tensor = torch.FloatTensor(state_norm).unsqueeze(0).to(self.device)
        
        # 候选动作生成
        num_imitation = int(SafeConfig.candidate_samples * SafeConfig.imitation_ratio)
        imitation_idx = np.random.choice(len(self.dataset), num_imitation)
        imitation_acts = self.dataset.acts[imitation_idx]
        noise_acts = np.random.randn(SafeConfig.candidate_samples - num_imitation, self.action_dim)
        candidates = np.concatenate([imitation_acts, noise_acts])
        candidates = (candidates * self.dataset.act_std) + self.dataset.act_mean
        
        # 选择最优动作
        with torch.no_grad():
            state_batch = state_tensor.repeat(SafeConfig.candidate_samples, 1)
            candidate_tensor = torch.FloatTensor(candidates).to(self.device)
            candidate_norm = (candidate_tensor - self.dataset.act_mean) / self.dataset.act_std
            q_values, _ = self.q_net(state_batch, candidate_norm)
            best_idx = torch.argmax(q_values)
        
        return candidates[best_idx.cpu().item()]
​
# 5. 训练执行
if __name__ == "__main__":
    trainer = SafeTrainer("D4RL/door/human-v2")
    print(f"初始化维度检查: state={trainer.state_dim}, action={trainer.action_dim}")
    
    try:
        for epoch in range(SafeConfig.total_epochs):
            loss = trainer.train_epoch(epoch)
            
            if (epoch + 1) % 20 == 0:
                print(f"Epoch {epoch+1:04d} | Loss: {loss:.2f} | "
                      f"Conserv: {trainer.conservative_weight:.2f} | "
                      f"LR: {trainer.scheduler.get_last_lr()[0]:.1e}")
    
    except KeyboardInterrupt:
        print("\n训练中断,保存检查点...")
        torch.save(trainer.q_net.state_dict(), "interrupted.pth")
    
    print("训练完成...")

四、关键代码解析

  1. 数据集加载

    • 使用 minari.load_dataset 加载离线数据集

    • 数据集包含状态、动作、奖励、终止标志等信息

  2. 保守正则化实现

    • 通过随机动作采样计算正则项

    • 超参数 $\alpha$ 控制保守程度

  3. 策略提取技巧

    • 采用基于 Q 值的启发式策略

    • 通过多候选动作采样提升稳定性


五、训练结果

运行代码将观察到:

初始化维度检查: state=39, action=28
Epoch 0020 | Loss: -46.52 | Conserv: 2.26 | LR: 2.7e-05
Epoch 0040 | Loss: -73.80 | Conserv: 2.05 | LR: 2.0e-05
Epoch 0060 | Loss: -73.50 | Conserv: 1.85 | LR: 1.1e-05
Epoch 0080 | Loss: -64.76 | Conserv: 1.67 | LR: 3.8e-06
Epoch 0100 | Loss: -54.37 | Conserv: 1.51 | LR: 3.0e-05
Epoch 0120 | Loss: -59.95 | Conserv: 1.37 | LR: 2.7e-05
Epoch 0140 | Loss: -60.11 | Conserv: 1.24 | LR: 2.0e-05
Epoch 0160 | Loss: -54.49 | Conserv: 1.12 | LR: 1.1e-05
Epoch 0180 | Loss: -46.11 | Conserv: 1.01 | LR: 3.8e-06
Epoch 0200 | Loss: -37.10 | Conserv: 0.92 | LR: 3.0e-05
Epoch 0220 | Loss: -37.56 | Conserv: 0.83 | LR: 2.7e-05
Epoch 0240 | Loss: -36.40 | Conserv: 0.75 | LR: 2.0e-05
Epoch 0260 | Loss: -31.79 | Conserv: 0.68 | LR: 1.1e-05
Epoch 0280 | Loss: -24.44 | Conserv: 0.61 | LR: 3.8e-06
Epoch 0300 | Loss: -17.06 | Conserv: 0.56 | LR: 3.0e-05
Epoch 0320 | Loss: -17.40 | Conserv: 0.50 | LR: 2.7e-05
Epoch 0340 | Loss: -16.91 | Conserv: 0.45 | LR: 2.0e-05
Epoch 0360 | Loss: -12.76 | Conserv: 0.41 | LR: 1.1e-05
Epoch 0380 | Loss: -7.27 | Conserv: 0.37 | LR: 3.8e-06
Epoch 0400 | Loss: -0.27 | Conserv: 0.34 | LR: 3.0e-05
Epoch 0420 | Loss: -1.47 | Conserv: 0.30 | LR: 2.7e-05
Epoch 0440 | Loss: -2.50 | Conserv: 0.30 | LR: 2.0e-05
Epoch 0460 | Loss: -2.87 | Conserv: 0.30 | LR: 1.1e-05
Epoch 0480 | Loss: -2.64 | Conserv: 0.30 | LR: 3.8e-06
Epoch 0500 | Loss: -2.30 | Conserv: 0.30 | LR: 3.0e-05
训练完成...


六、总结与扩展

本文基于 Minari 实现了 CQL 算法的核心逻辑,展示了离线强化学习在安全关键场景的应用价值。读者可尝试以下扩展:

  1. 添加策略网络实现 Actor-Critic 架构

  2. antmaze 等迷宫类数据集测试导航能力

  3. 实现更精确的 OOD(分布外)动作检测

在下一篇文章中,我们将探索 基于模型的强化学习(Model-Based RL),并实现 PETS 算法!


注意事项

  1. 需先安装 minari 库:

    pip install "minari[all]"
  2. 数据集路径可通过 minari.list_datasets() 查看

  3. 调整 alpha 参数可平衡保守性与探索性

希望本文能帮助您理解离线强化学习的核心范式!欢迎在评论区分享您的实践心得。


http://www.kler.cn/a/594436.html

相关文章:

  • 210、【图论】课程表(Python)
  • 高可用环境下Nginx服务管理脚本优化实践
  • VUE中使用路由router跳转页面
  • 传统金融和分布式金融
  • 五、AIGC大模型_10多模态大语言模型基础知识与示例
  • vue3之写一个aichat---已聊天组件部分功能
  • 文献检索与下指南
  • 神来之笔!Profinet转DeviceNet网关书写OTC焊机机器人连接传奇
  • PHP与Python无缝融合,开启跨语言开发新纪元
  • Web开发-JS应用原生代码前端数据加密CryptoJS库jsencrypt库代码混淆
  • 25.单例模式实现线程池
  • HW华为流程管理体系精髓提炼华为流程运营体系(124页PPT)(文末有下载方式)
  • 【QA】策略模式在QT有哪些应用
  • LabVIEW运动控制(二):EtherCAT运动控制器的多轴示教加工应用(下)
  • Unity音频混合器如何暴露参数
  • 用python制作一个俄罗斯方块小游戏
  • js 力扣100题 非负整数加一
  • 大白话详细解读React框架的diffing算法
  • 《剑指数据库:MySQL玄阶查术秘典·下卷》
  • 【c++】【STL】unordered_set 底层实现(简略版)