ray.rllib-入门实践-10:自定义环境
前面介绍的入门实践都是基于 ray.rllib 内置的环境、模型和算法执行的,在应对具体任务时, 需要自定义交互环境、改进网络模型或者算法的损失函数。从本博客开始将逐个介绍。
在ray.rllib中使用自定义的环境,主要分为三步:
1) 创建自定义的环境类
2) 向 ray 注册自定义的环境
3) 在算法配置和训练中使用环境
环境配置:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
一、 创建自定义的环境类
import gymnasium as gym
from gymnasium import spaces
import ray
import numpy as np
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.registry import register_env
## 1. 定义环境
class MyEnv(gym.Env): ## 注意1: 需要继承自gym.Env
def __init__(self,env_config):
self.worker_index = env_config.worker_index ## worker_index是
self.vector_index = env_config.vector_index
self.action_space = spaces.Box(low=-1,high=1,shape=(5,)) ## 一般需要定义动作空间作为成员变量
self.observation_space = spaces.Box(low=-1,high=1,shape=(6,)) ## 一般需要定义观测空间作为成员变量。
self.step_count = 0
def reset(self, seed=None, options=None): ## 重构reset函数, 需要包含 seed=None, options=None 两个变量
## seed 和 options后面要附上默认值None,否则报错
self.step_count = 0
obs = self.observation_space.sample()
info = {}
print(f"========== reset called ========")
print(f"========== worker_index = {self.worker_index} ========")
return obs,info ## 返回的数据格式是指定的,需要严格一致,否则报错
def step(self,action): ## 重构step函数, 输入参数是指定的
self.step_count += 1
obs = self.observation_space.sample()
reward = 0
terminated = False
truncated = False
info = {}
if self.step_count > 10:
terminated = True
return obs,reward,terminated,truncated,info ## 输出变量的数据格式必须是指定的
注意事项:
1. 主要是重构 reset 和 step 函数,严格按照 gym.Env的相关接口规则。
2. 在 __init__里面,最好定义self.action_space 和 self.observation_space两个成员变量
3. env_config.worker_index 是AlgorithmConfig 在调用env_config时额外附带的参数, 不是DIY者手工定义的:无需自定义,直接使用即可。
二、 向 ray 注册自定义的环境
from ray.tune.registry import register_env
def env_creator(env_config):
return MyEnv(env_config)
register_env("MyEnv",env_creator)
注册完毕之后,就可以向使用rllib自带的环境(比如前面使用的Cartpole-v1环境)一样使用自己创建的环境了。
三、 使用自定义的环境进行算法配置和训练
## 3. 启动 ray, 创建算法,使用环境
ray.init()
## 配置算法
config = PPOConfig()
config = config.environment(env="MyEnv", env_config={})
algo = config.build()
## 训练
for i in range(3):
iter_result = algo.train()
print(f"episode_{i}")
代码汇总:
import gymnasium as gym
from gymnasium import spaces
import ray
import numpy as np
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.tune.registry import register_env
## 1. 定义环境
class MyEnv(gym.Env): ## 注意1: 需要继承自gym.Env
def __init__(self,env_config):
self.worker_index = env_config.worker_index ## worker_index是
self.vector_index = env_config.vector_index
self.action_space = spaces.Box(low=-1,high=1,shape=(5,)) ## 一般需要定义动作空间作为成员变量
self.observation_space = spaces.Box(low=-1,high=1,shape=(6,)) ## 一般需要定义观测空间作为成员变量。
self.step_count = 0
def reset(self, seed=None, options=None): ## 重构reset函数, 需要包含 seed=None, options=None 两个变量
## seed 和 options后面要附上默认值None,否则报错
self.step_count = 0
obs = self.observation_space.sample()
info = {}
print(f"========== reset called ========")
print(f"========== worker_index = {self.worker_index} ========")
return obs,info ## 返回的数据格式是指定的,需要严格一致,否则报错
def step(self,action): ## 重构step函数, 输入参数是指定的
self.step_count += 1
obs = self.observation_space.sample()
reward = 0
terminated = False
truncated = False
info = {}
if self.step_count > 10:
terminated = True
return obs,reward,terminated,truncated,info ## 输出变量的数据格式必须是指定的
## 2. 向ray注册环境
def env_creator(env_config):
return MyEnv(env_config)
register_env("MyEnv",env_creator)
## 3. 启动 ray, 创建算法,使用环境
ray.init()
## 配置算法
config = PPOConfig()
config = config.environment(env="MyEnv", env_config={})
algo = config.build()
## 训练
for i in range(3):
iter_result = algo.train()
print(f"episode_{i}")