当前位置: 首页 > article >正文

ray.rllib-入门实践-12:自定义policy

        在本博客开始之前,先厘清一下几个概念之间的区别与联系:env,  agent,  model, algorithm, policy. 

        强化学习由两部分组成: 环境(env)和智能体(agent)。环境(env)提供观测值和奖励; agent读取观测值,输出动作或决策。agent是algorithm的类对象。 policy是algorithm的子类, 比如ppo, dqn等。因此,自定义policy本质上是自定义algorithm.  algorithm 主要由两部分组成: 网络结构(model)和损失函数(loss)。 网络结构(model)的自定义由上一个博客ray.rllib-入门实践-11: 自定义模型/网络 进行了介绍:在alrorithm外创建新的model类, 通过 AlgorithmConfig类传入algorithm。因此, c从实际操作上, 自定义algorithm就变成了自定义algorithm 的 loss.

        因此,本博客所提到的自定义policy, 本质上就是继承一个Algorithm, 并修改它的loss函数。

        与之前介绍的自定义env, 自定义model一样, 自定义policy也包含三个步骤:

        1. 继承某个Policy, 创建一个新Policy类, 修改它的损失函数。

        2. 把自己的Policy封装为一个Algorithm, 使ray可识别

        3. 配置使用自己的Policy.

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 自定义 policy

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch

## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):
    """PyTorch policy class used with PPO."""
    def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): 
        PPOTorchPolicy.__init__(self,observation_space,action_space,config)
        ## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config

    @override(PPOTorchPolicy) 
    def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):
        ## 原始损失
        original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。

        ## 新增自定义损失,这里以正则化损失作为示例
        addiontial_loss = torch.tensor(0.0) ## 自己定义的loss
        addiontial_loss = torch.tensor(0.)
        for param in model.parameters():
            addiontial_loss += torch.norm(param)
        ## 得到更新后的损失
        new_loss = original_loss + 0.01 * addiontial_loss
        return new_loss

二、 把自己的policy封装在一个算法中

## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):
    ## 重写 get_default_policy_class 函数, 使其返回自定义的policy 
    def get_default_policy_class(self, config):
        return MY_PPOTorchPolicy

三、使用自己的策略创建智能体,执行训练

配置方法1:

## 三、使用自己的策略创建智能体,执行训练
## method-1
from ray.tune.logger import pretty_print 
config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = config.build()
result = algo.train()
print(pretty_print(result))

配置方法2:

## 3. 使用新策略执行训练
## method-2
from ray.tune.logger import pretty_print 
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里使用自己的policy
result = algo.train()
print(pretty_print(result))

四、代码汇总:

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.logger import pretty_print 

## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):
    """PyTorch policy class used with PPO."""
    def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): 
        PPOTorchPolicy.__init__(self,observation_space,action_space,config)
        ## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config

    @override(PPOTorchPolicy) 
    def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):
        ## 原始损失
        original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。

        ## 新增自定义损失,这里以正则化损失作为示例
        addiontial_loss = torch.tensor(0.0) ## 自己定义的loss
        addiontial_loss = torch.tensor(0.)
        for param in model.parameters():
            addiontial_loss += torch.norm(param)
        ## 得到更新后的损失
        new_loss = original_loss + 0.01 * addiontial_loss
        return new_loss
    
## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):
    ## 重写 get_default_policy_class 函数, 使其返回自定义的policy 
    def get_default_policy_class(self, config):
        return MY_PPOTorchPolicy

## 三、使用自己的策略创建智能体,执行训练
## method-1
# config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
# config = config.environment("CartPole-v1")
# config = config.rollouts(num_rollout_workers=2)
# config = config.framework(framework="torch") 
# algo = config.build()
# result = algo.train()
# print(pretty_print(result))

## method-2
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里配置使用自己的policy
result = algo.train()
print(pretty_print(result))

五、后记:

        如何既要修改网络结构,又要修改loss函数, 可以结合 上一篇博客 和本博客共同实现。


http://www.kler.cn/a/519462.html

相关文章:

  • python生成图片和pdf,快速
  • GESP2024年3月认证C++六级( 第三部分编程题(1)游戏)
  • vite环境变量处理
  • 电脑无法开机,重装系统后没有驱动且驱动安装失败
  • 全面评测 DOCA 开发环境下的 DPU:性能表现、机器学习与金融高频交易下的计算能力分析
  • 【科研建模】Pycaret自动机器学习框架使用流程及多分类项目实战案例详解
  • Maui学习笔记-SignalR简单介绍
  • MySQL中的读锁与写锁:概念与作用深度剖析
  • 延迟之争:LLM服务的制胜关键
  • Linux系统之gzip命令的基本使用
  • C++ 与机器学习:构建高效推理引擎的秘诀
  • Gary Marcus对2025年AI的25项预测:AGI的曙光仍未到来?
  • C语言I/O请用递归实现计算 :1 + 1/3 - 1/5 + 1/7 - 1/9 + .... 1/n 的值,n通过键盘输入
  • SpringBoot基础概念介绍-数据源与数据库连接池
  • An OpenGL Toolbox
  • mysql 学习6 DML语句,对数据库中的表进行 增 删 改 操作
  • 设计模式的艺术-代理模式
  • 2024-2025年终总结
  • 使用vscode + Roo Code (prev. Roo Cline)+DeepSeek-R1使用一句话需求做了个实验
  • 每日一题-二叉搜索树与双向链表
  • 浏览器IndexedDB占用大
  • HarmonyOS DevEco Studio模拟器点击运行没有反应的解决方法
  • rust并发和golang并发比较
  • 二叉搜索树中的搜索(力扣700)
  • Android HandlerThread
  • 【C++基础】多线程并发场景下的同步方法