当前位置: 首页 > article >正文

ray.rllib-入门实践-10:自定义环境

        前面介绍的入门实践都是基于 ray.rllib 内置的环境、模型和算法执行的,在应对具体任务时, 需要自定义交互环境、改进网络模型或者算法的损失函数。从本博客开始将逐个介绍。

        在ray.rllib中使用自定义的环境,主要分为三步:

        1) 创建自定义的环境类

        2) 向 ray 注册自定义的环境

        3) 在算法配置和训练中使用环境

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 创建自定义的环境类

import gymnasium as gym
from gymnasium import spaces
import ray
import numpy as np
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.registry import register_env

## 1. 定义环境
class MyEnv(gym.Env):                               ## 注意1: 需要继承自gym.Env
    def __init__(self,env_config):
        self.worker_index = env_config.worker_index ## worker_index是
        self.vector_index = env_config.vector_index
        self.action_space = spaces.Box(low=-1,high=1,shape=(5,)) ## 一般需要定义动作空间作为成员变量
        self.observation_space = spaces.Box(low=-1,high=1,shape=(6,)) ## 一般需要定义观测空间作为成员变量。
        self.step_count = 0

    def reset(self, seed=None, options=None): ## 重构reset函数, 需要包含 seed=None, options=None 两个变量
                                              ## seed 和 options后面要附上默认值None,否则报错
        self.step_count = 0
        obs = self.observation_space.sample()
        info = {}
        print(f"========== reset called ========")
        print(f"========== worker_index = {self.worker_index} ========")
        return obs,info                     ## 返回的数据格式是指定的,需要严格一致,否则报错
    
    def step(self,action):   ## 重构step函数, 输入参数是指定的
        self.step_count += 1
        obs = self.observation_space.sample()
        reward = 0
        terminated = False
        truncated = False 
        info = {}
        if self.step_count > 10:
            terminated = True
        return obs,reward,terminated,truncated,info   ## 输出变量的数据格式必须是指定的

注意事项:

        1. 主要是重构 reset 和 step 函数,严格按照 gym.Env的相关接口规则。

        2. 在 __init__里面,最好定义self.action_space 和 self.observation_space两个成员变量

        3. env_config.worker_index 是AlgorithmConfig 在调用env_config时额外附带的参数, 不是DIY者手工定义的:无需自定义,直接使用即可。

二、 向 ray 注册自定义的环境

from ray.tune.registry import register_env

def env_creator(env_config):
    return MyEnv(env_config)
register_env("MyEnv",env_creator)

 注册完毕之后,就可以向使用rllib自带的环境(比如前面使用的Cartpole-v1环境)一样使用自己创建的环境了。

三、 使用自定义的环境进行算法配置和训练

## 3. 启动 ray, 创建算法,使用环境
ray.init()
## 配置算法
config = PPOConfig()
config = config.environment(env="MyEnv", env_config={})
algo = config.build()

## 训练
for i in range(3):
    iter_result = algo.train()
    print(f"episode_{i}")

代码汇总: 

import gymnasium as gym
from gymnasium import spaces
import ray
import numpy as np
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.registry import register_env

## 1. 定义环境
class MyEnv(gym.Env):                               ## 注意1: 需要继承自gym.Env
    def __init__(self,env_config):
        self.worker_index = env_config.worker_index ## worker_index是
        self.vector_index = env_config.vector_index
        self.action_space = spaces.Box(low=-1,high=1,shape=(5,)) ## 一般需要定义动作空间作为成员变量
        self.observation_space = spaces.Box(low=-1,high=1,shape=(6,)) ## 一般需要定义观测空间作为成员变量。
        self.step_count = 0

    def reset(self, seed=None, options=None): ## 重构reset函数, 需要包含 seed=None, options=None 两个变量
                                              ## seed 和 options后面要附上默认值None,否则报错
        self.step_count = 0
        obs = self.observation_space.sample()
        info = {}
        print(f"========== reset called ========")
        print(f"========== worker_index = {self.worker_index} ========")
        return obs,info                     ## 返回的数据格式是指定的,需要严格一致,否则报错
    
    def step(self,action):   ## 重构step函数, 输入参数是指定的
        self.step_count += 1
        obs = self.observation_space.sample()
        reward = 0
        terminated = False
        truncated = False 
        info = {}
        if self.step_count > 10:
            terminated = True
        return obs,reward,terminated,truncated,info   ## 输出变量的数据格式必须是指定的
    
## 2. 向ray注册环境
def env_creator(env_config):
    return MyEnv(env_config)
register_env("MyEnv",env_creator)

## 3. 启动 ray, 创建算法,使用环境
ray.init()
## 配置算法
config = PPOConfig()
config = config.environment(env="MyEnv", env_config={})
algo = config.build()

## 训练
for i in range(3):
    iter_result = algo.train()
    print(f"episode_{i}")
    


http://www.kler.cn/a/521351.html

相关文章:

  • 【Elasticsearch】聚合分析:管道聚合
  • 基于 Arduino Uno 和 RFID-RC522 的 RFID 卡号读取技术详解
  • 深度学习模型架构演进:从RNN到新兴技术
  • 数据结构——查找算法和排序算法
  • 【C++】std::prev用法
  • ubuntu下编译openjdk17,依赖的包名有所不同
  • 基于 RAMS 的数据驱动建模与应用实践:从理论到具体操作
  • 1.26 实现文件拷贝的功能
  • 我的2024年年度总结
  • 自然元素有哪些选择?
  • K8S部署DevOps自动化运维平台
  • Arouter详解・常见面试题
  • deepseek各个版本及论文
  • WPS数据分析000007
  • ArcGIS安装动物家域分析插件HRT的方法
  • 为AI聊天工具添加一个知识系统 之72 详细设计之13 图灵机
  • Level DB --- TableBuilder
  • C 或 C++ 中用于表示常量的后缀:1ULL
  • C++从入门到实战(二)C++命名空间
  • 【信息系统项目管理师-选择真题】2016上半年综合知识答案和详解