当前位置: 首页 > article >正文

ray.rllib-入门实践-11: 自定义模型/网络

在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:

1. 定义自己的模型。

2. 向ray注册自定义的模型

3. 在config中配置使用自定义的模型

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 定义自己的模型 

需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法, 其代码框架如下:

import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class My_Model(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...
    
    def forward(self, input_dict, state, seq_lens): ...
    
    def value_function(self): ...

示例如下:

## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

class My_Model(TorchModelV2, nn.Module):
    def __init__(self, obs_space:gym.spaces.Space, 
                 action_space:gym.spaces.Space, 
                 num_outputs:int, 
                 model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数
                 name:str
                 ,*, custom_arg1, custom_arg2):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)
        nn.Module.__init__(self)
        ## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值
        print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")

        ## 定义网络层
        obs_dim = int(np.product(obs_space.shape))
        action_dim = int(np.product(action_space.shape))
        ## shareNet
        self.shared_fc = nn.Linear(obs_dim,128)
        ## actorNet
        self.actorNet = nn.Linear(128, action_dim)
        ## criticNet
        self.criticNet = nn.Linear(128,1)

        self._feature = None 

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        self._feature = self.shared_fc.forward(obs)
        action_logits = self.actorNet.forward(self._feature)
        return action_logits, state 
    
    def value_function(self):
        value = self.criticNet.forward(self._feature).squeeze(1)
        return value 

        在rllib中,每个算法的所有网络都被汇集到同一个 ModelV2 类下,供算法调用。actor 网络和critic网络可以在外面定义,也可以在model内部直接定义。 model的forward用于返回actor网络的输出, value_function函数用于返回critic网络的输出。 网络结构和网络层共享可以自定义设置。输入输出接口,需要与上面保持严格一致。

二、 向ray注册自定义模型

        ray.rllib.model.ModelCatalog 类,用于向ray注册自定义的model, 还可以用于获取env的 preprocessors 和 action distributions。

import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 

ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)

三、 在算法中配置并使用自定义的模型

主要是在 config.training() 模块中的 model 子模块中传入两个配置信息:

        1)"custom_model":"my_torch_model" ,                      
         2)"custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  

两个关键字固定不变,填入自己注册的模型名和对应的模型参数即可。

可以有以下三种配置代码的编写方式:

配置方法1:

## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print 

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

## 配置使用自定义的模型
config = config.training(model= {"custom_model":"my_torch_model" ,                      
                                 "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
## 主要在上面两行配置使用自己的模型
##    配置 model 的 "custom_model" 项,用于指定rllib算法所使用的模型
##    配置 model 的 "custom_model_config" 项,用于传入自定义的网络参数,供自定义的model使用。
##    这两个关键词不可更改。

algo = config.build()
## 4. 执行训练
result = algo.train()
print(pretty_print(result))

与以上配置内容一样,还可以用以下两种配置写法:

配置方法2:

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

## 配置自定义模型
model_config_dict = {}
model_config_dict["custom_model"] = "my_torch_model" 
model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
config = config.training(model= model_config_dict)  

algo = config.build()

 配置方法3(推荐):

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

## 配置自定义模型
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}

algo = config.build()

 代码汇总:

"""
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。 
    需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法
    import torch.nn as nn
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    class CustomTorchModel(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。 
        def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...
        def forward(self, input_dict, state, seq_lens): ...
        def value_function(self): ...

2. 向ray注册自定义的模型
    from ray.rllib.models import ModelCatalog
    ModelCatalog.register_custom_model("wzg_torch_model", CustomTorchModel)

3. 在config中配置使用自定义的模型
    model_config_dict = {
        "custom_model":"wzg_torch_model",
        "custom_model_config":{
            "custom_arg1": 1,
            "custom_arg2": 2}
    }
    config = PPOConfig()
    # config = config.training(model = model_config_dict)
    config.model["custom_model"] = "wzg_torch_model"
    config.model["custom_model_config"] = {"custom_arg1": 1,
                                    "custom_arg2": 2}
"""

## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

class My_Model(TorchModelV2, nn.Module):
    def __init__(self, obs_space:gym.spaces.Space, 
                 action_space:gym.spaces.Space, 
                 num_outputs:int, 
                 model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数
                 name:str
                 ,*, custom_arg1, custom_arg2):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)
        nn.Module.__init__(self)
        ## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值
        print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")

        ## 定义网络层
        obs_dim = int(np.product(obs_space.shape))
        action_dim = int(np.product(action_space.shape))
        ## shareNet
        self.shared_fc = nn.Linear(obs_dim,128)
        ## actorNet
        self.actorNet = nn.Linear(128, action_dim)
        ## criticNet
        self.criticNet = nn.Linear(128,1)

        self._feature = None 

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        self._feature = self.shared_fc.forward(obs)
        action_logits = self.actorNet.forward(self._feature)
        return action_logits, state 
    
    def value_function(self):
        value = self.criticNet.forward(self._feature).squeeze(1)
        return value 

## 2. 向ray注册自定义模型
import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 

ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)
ray.init()

## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print 

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

# ## 配置自定义模型:方法 1
# config = config.training(model= {"custom_model":"my_torch_model" ,                      
#                                  "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
# ## 配置自定义模型:方法 2
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config = config.training(model= model_config_dict) 

## 配置自定义模型: 方法 3 (个人更喜欢, 因为嵌套层次少)
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}

## 错误方法:
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config.model = model_config_dict # 会清空 model 里面的其他默认配置,导致报错

algo = config.build()

## 4. 执行训练
result = algo.train()
print(pretty_print(result))


 


http://www.kler.cn/a/519607.html

相关文章:

  • 一文讲解Java中的接口和抽象类
  • BLE透传方案,IoT短距无线通信的“中坚力量”
  • 工业数据分析:解锁工厂数字化的潜力
  • 人格分裂(交互问答)-我想懂Elasticsearch
  • vite环境变量处理
  • 我谈区域偏心率
  • 第22章 走进xUnit:测试驱动开发的关键工具(持续探索)
  • 凝“华”聚智,“清”创未来-----华清远见教育科技集团成都中心2024年度总结大会暨2025新春盛典
  • 【论文阅读】HumanPlus: Humanoid Shadowing and Imitation from Humans
  • 蓝桥杯之c++入门(一)【第一个c++程序】
  • 27. 【.NET 8 实战--孢子记账--从单体到微服务】--简易报表--报表服务
  • Docker 系列之 docker-compose 容器编排详解
  • 【信息系统项目管理师-选择真题】2017上半年综合知识答案和详解
  • Transfoemr的解码器(Decoder)与分词技术
  • QT:控件属性及常用控件(4)-----多元素控件、容器类控件、布局管理器
  • 3.numpy练习(2)
  • RabbitMQ 分布式高可用
  • 【Linux】Linux编译器-g++、gcc、动静态库
  • 7、知识库内容更新与自动化
  • 系统编程(线程互斥)
  • 牛角棋项目实践1:牛角棋的定义和用python实现简单功能
  • 大模型开发 | RAG在实际开发中可能遇到的坑
  • rewrite规则
  • STL中的list容器
  • 汇编的使用总结
  • CSS:跑马灯