当前位置: 首页 > article >正文

ray.rllib 入门实践-2:配置算法

前言:

        ray.rllib的算法配置方式有多种,网上的不同教程各不相同,有的互不兼容,本文汇总罗列了多种算法配置方式,给出推荐,并在最后给出可运行代码。

四种配置方式

方法1

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置算法
config = PPOConfig()\
        .rollouts(num_rollout_workers = 2)\
        .resources(num_gpus=0)\
        .environment(env="CartPole-v1")
algo = config.build()

缺点:不能在每行配置后面添加注释, 否则报错。 

方法2

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置算法
algo = (
    PPOConfig()
    .rollouts(num_rollout_workers=1)  ## 注释
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .build()
)

用"()"把配置过程括起来,每行后面可以添加注释,不报错。官方教程使用的该种方式。 

方式3:推荐

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

## 配置算法2
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path  ## 设置过程文件的存储路径
algo = config.build()

优点:每一行是一个完整的命令, 后面可以添加注释,可以直接给config类的成员变量赋值。比如上面代码示例中的:config.output = storage_path , 直接配置存储路径,而不用去寻找output变量属于哪一个PPOConfig子模块。 

方式4:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print

storage_path = "F:/codes/RLlib_study/ray_results/build_method_4"
os.makedirs(storage_path, exist_ok=True)
config = {
    "env":"CartPole-v1",
    "env_config":{}, ## 用于传递给env的信息
    "frame_work":"torch",
    "num_gpus":0,
    "num_workers":2,
    "num_cpus_per_worker":1,
    "num_envs_per_worker":1,
    "num_gpus_per_worker":0,
    "lr":0.001,
    "model":{
        "fcnet_hiddens":[256,256,64],
        "fcnet_activation":"tanh",
        "custom_model_config":{},
        "custom_model":None},
    "output":storage_path
}
algo = PPO(config=config) ## 构建算法

        这种方式在ray1.4版本之前使用较多,是唯一的配置方式。随着ray的更新迭代,用class封装了configDict, 即上面的方法1,方法2,方法3所用的方式。用 PPOConfig 进行配置后,最终也是转成方法4中的字典传递给算法使用, 但是相比方法4的字典, 方法1、2、3可以在编程时有语法提示,告诉你有哪几个成员变量或成员函数可以用于设计config。 

        现在仍旧有很多人用方法4配置rllib算法,我认为这是从老版本传递下来的一种习惯,新上手的人建议使用 AlgorithmConfig的方式配置算法。

汇总代码:

from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print
import os 

## 配置算法1
# config = PPOConfig()\
#         .rollouts(num_rollout_workers = 2)\
#         .resources(num_gpus=0)\
#         .environment(env="CartPole-v1")
# algo = config.build()

# ## 配置算法2
# algo = (
#     PPOConfig()
#     .rollouts(num_rollout_workers=1) 
#     .resources(num_gpus=0)
#     .environment(env="CartPole-v1")
#     .build()
# )

# ## 配置算法3
# storage_path = "F:/codes/RLlib_study/ray_results/build_method_4"
# os.makedirs(storage_path, exist_ok=True)
# config = PPOConfig()
# config = config.rollouts(num_rollout_workers=1) 
# config = config.resources(num_gpus=0)
# config = config.environment(env="CartPole-v1")
# config.output = storage_path
# algo = config.build()

## 配置算法 4
storage_path = "F:/codes/RLlib_study/ray_results/build_method_4"
os.makedirs(storage_path, exist_ok=True)
config = {
    "env":"CartPole-v1",
    "env_config":{}, ## 用于传递给env的信息
    "frame_work":"torch",
    "num_gpus":0,
    "num_workers":2,
    "num_cpus_per_worker":1,
    "num_envs_per_worker":1,
    "num_gpus_per_worker":0,
    "lr":0.001,
    "model":{
        "fcnet_hiddens":[256,256,64],
        "fcnet_activation":"tanh",
        "custom_model_config":{},
        "custom_model":None},
    "output":storage_path
}
algo = PPO(config=config) ## 构建算法
    


## 训练模型. 每个 iter 里重复执行多次 episode. 直到满足条件, 比如新增采样量达到一定体量。
for i in range(2):
    result = algo.train()
    print(pretty_print(result))

## 保存模型
checkpoint_dir = algo.save().checkpoint.path   
## algo.save()用于实现存储checkpoint, 后面跟着的.checkpoint.path用于返回存储路径
print(f"Checkpoint saved in directory {checkpoint_dir}")


http://www.kler.cn/a/518111.html

相关文章:

  • 赚钱的究极认识
  • 基于物联网的智能环境监测系统(论文+源码)
  • 【C语言常见概念详解】
  • C++中左值和右值的概念
  • Java---判断素数的三种方法
  • AI学习指南Ollama篇-Ollama模型的量化与优化
  • Lua 初级教程
  • Android BitmapShader简洁实现马赛克,Kotlin(二)
  • Java设计模式 三十 状态模式 + 策略模式
  • ProfiNet转CANopen应用于汽车总装生产线输送设备ProfiNet与草棚CANopen质量检测系统
  • C++ —— vector 容器
  • 立创开发板入门ESP32C3第八课 修改AI大模型接口为deepseek3接口
  • Redis高阶3-缓存双写一致性
  • 【8】思科IOS AP升级操作
  • 【Flutter】旋转元素(Transform、RotatedBox )
  • 【EI会议推荐】人工智能、电子信息、智能制造、机器人、自动化、控制科学、机械制造等计算机领域多主题可选!
  • STM32 调试小问题记录
  • qsort和std::sort比较函数返回值的说明
  • 《CPython Internals》阅读笔记:p353-p355
  • 正点原子Linux 移植USB Wifi模块MT7601U驱动(上)
  • Android-UI自动化测试环境配置
  • 【C语言算法刷题】第2题 图论 dijkastra
  • PBFT算法
  • ESG报告流程参考
  • 【深度学习】搭建PyTorch神经网络进行气温预测
  • Qt 5.14.2 学习记录 —— 십구 事件